diff --git a/.bandit b/.bandit
index 5b31b5bc1..130bf859f 100644
--- a/.bandit
+++ b/.bandit
@@ -1,5 +1,6 @@
-[bandit]
-skips: B303
-
-# Disabled:
-# B303, Use of insecure MD2, MD4, MD5, or SHA1 hash function.
+[bandit]
+skips: B101, B303
+
+# Disabled:
+# B101, Use of assert detected. The enclosed code will be removed when compiling to optimised byte code
+# B303, Use of insecure MD2, MD4, MD5, or SHA1 hash function.
diff --git a/.bumpversion.cfg b/.bumpversion.cfg
index 4812a33b7..0c6eb7409 100644
--- a/.bumpversion.cfg
+++ b/.bumpversion.cfg
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 1.0.0
+current_version = 1.2.6
 files = setup.py straxen/__init__.py
 commit = True
 tag = True
diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 000000000..27c6416af
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,9 @@
+# https://www.aleksandrhovhannisyan.com/blog/crlf-vs-lf-normalizing-line-endings-in-git/#a-simple-gitattributes-config
+# We'll let Git's auto-detection algorithm infer if a file is text. If it is,
+# enforce LF line endings regardless of OS or git configurations.
+* text=auto eol=lf
+
+# Isolate binary files in case the auto-detection algorithm fails and
+# marks them as text files (which could brick them).
+*.{png,jpg,jpeg,gif,webp,woff,woff2} binary
+
diff --git a/.github/dependabot.yml b/.github/dependabot.yml
index dfda41feb..38ee7d87f 100644
--- a/.github/dependabot.yml
+++ b/.github/dependabot.yml
@@ -1,23 +1,23 @@
-# Set update schedule for GitHub Actions to check they are up to date
-# If one of the github actions is out of date, dependabot will open a
-# PR to update the version of that action
-
-version: 2
-updates:
-  # Maintain the requirements in the github actiuons
-  - package-ecosystem: "github-actions"
-    directory: "/"
-    schedule:
-      # Check for updates to GitHub Actions every weekday
-      interval: "weekly"
-    assignees:
-      - jorana
-  # Maintain the requirements requirements folder
-  - package-ecosystem: "pip"
-    directory: "/extra_requirements"
-    schedule:
-      # Check for updates to requirements every week
-      interval: "monthly"
-    open-pull-requests-limit: 15
-    assignees:
-      - jorana
+# Set update schedule for GitHub Actions to check they are up to date
+# If one of the github actions is out of date, dependabot will open a
+# PR to update the version of that action
+
+version: 2
+updates:
+  # Maintain the requirements in the github actiuons
+  - package-ecosystem: "github-actions"
+    directory: "/"
+    schedule:
+      # Check for updates to GitHub Actions every weekday
+      interval: "weekly"
+    assignees:
+      - jorana
+  # Maintain the requirements requirements folder
+  - package-ecosystem: "pip"
+    directory: "/extra_requirements"
+    schedule:
+      # Check for updates to requirements every week
+      interval: "monthly"
+    open-pull-requests-limit: 15
+    assignees:
+      - jorana
diff --git a/.github/scripts/create-utilix-config.sh b/.github/scripts/create-utilix-config.sh
deleted file mode 100644
index 47767cf6f..000000000
--- a/.github/scripts/create-utilix-config.sh
+++ /dev/null
@@ -1,18 +0,0 @@
-#!/bin/bash
-
-# Create config with write permissions!
-cat > $HOME/.xenon_config < requirements.txt -cat requirements.txt diff --git a/.github/scripts/install_straxen.sh b/.github/scripts/install_straxen.sh deleted file mode 100644 index 711741fdd..000000000 --- a/.github/scripts/install_straxen.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash - -start=`pwd` -cd straxen/straxen -python setup.py install -cd $start diff --git a/.github/scripts/update-context-collection.py b/.github/scripts/update-context-collection.py deleted file mode 100644 index bbbd85b50..000000000 --- a/.github/scripts/update-context-collection.py +++ /dev/null @@ -1,42 +0,0 @@ -import strax -import straxen -from straxen.contexts import * -from utilix import DB -import datetime - -db = DB() - -# list of contexts that gets tracked in runDB context collection -# needs to be maintained for each straxen release -context_list = ['xenonnt_led', - 'xenonnt_online', - ] - - -# returns the list of dtype, hashes for a given strax context -def get_hashes(st): - return set([(d, st.key_for('0', d).lineage_hash) - for p in st._plugin_class_registry.values() - for d in p.provides]) - - -def main(): - for context in context_list: - # get these from straxen.contexts.* - st = eval("%s()" % context) - hashes = get_hashes(st) - hash_dict = {dtype: h for dtype, h in hashes} - - doc = dict(name=context, - date_added=datetime.datetime.utcnow(), - hashes=hash_dict, - straxen_version=straxen.__version__, - strax_version=strax.__version__ - ) - - # update the context collection using utilix + runDB_api - db.update_context_collection(doc) - - -if __name__ == "__main__": - main() diff --git a/.github/workflows/code_style.yml b/.github/workflows/code_style.yml index d7c3862b5..fcab7068b 100644 --- a/.github/workflows/code_style.yml +++ b/.github/workflows/code_style.yml @@ -1,25 +1,25 @@ -name: Python style -on: - pull_request: - types: [opened] -jobs: - qa: - name: Quality check - runs-on: ubuntu-18.04 - steps: - - uses: actions/checkout@v2 - - name: Set up Python - uses: actions/setup-python@master - with: - python-version: 3.8 - - name: Change __all__ exports for pyflake - run: | - bash .github/scripts/pre_pyflakes.sh - - name: Wemake Python Stylguide - uses: wemake-services/wemake-python-styleguide@0.15.3 - continue-on-error: true - with: - reporter: 'github-pr-review' - env: - NUMBA_DISABLE_JIT: 1 - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} +name: Python style +on: + pull_request: + types: [opened] +jobs: + qa: + name: Quality check + runs-on: ubuntu-18.04 + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@master + with: + python-version: 3.8 + - name: Change __all__ exports for pyflake + run: | + bash .github/scripts/pre_pyflakes.sh + - name: Wemake Python Stylguide + uses: wemake-services/wemake-python-styleguide@0.16.0 + continue-on-error: true + with: + reporter: 'github-pr-review' + env: + NUMBA_DISABLE_JIT: 1 + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/contexts.yml b/.github/workflows/contexts.yml deleted file mode 100644 index 75839d396..000000000 --- a/.github/workflows/contexts.yml +++ /dev/null @@ -1,41 +0,0 @@ -# Automatically update the context collection in the runDB - -name: Update context collection - -# Controls when the action will run. 
- -# Trigger this code when a new release is published -on: - workflow_dispatch: - release: - types: [created] - -jobs: - update: - runs-on: ubuntu-latest - steps: - - name: Setup python - uses: actions/setup-python@v2 - with: - python-version: '3.8' - - name: Checkout repo - uses: actions/checkout@v2 - - name: Install python dependencies - uses: py-actions/py-dependency-install@v2 - - name: Install straxen - run: bash .github/scripts/install_straxen.sh - # writes a utilix configuration file. Uses the secret functionality of GitHub. - - name: Write utilix config - run: | - bash .github/scripts/create-utilix-config.sh - env: - RUNDB_API_URL: ${{ secrets.RUNDB_API_URL }} - RUNDB_API_USER: ${{ secrets.RUNDB_API_USER }} - RUNDB_API_PASSWORD: ${{ secrets.RUNDB_API_PASSWORD }} - PYMONGO_URL: ${{ secrets.PYMONGO_URL }} - PYMONGO_USER: ${{ secrets.PYMONGO_USER }} - PYMONGO_PASSWORD: ${{ secrets.PYMONGO_PASSWORD }} - PYMONGO_DATABASE: ${{ secrets.PYMONGO_DATABASE }} - - name: Update context - run: | - python .github/scripts/update-context-collection.py diff --git a/.github/workflows/pypi_install.yml b/.github/workflows/pypi_install.yml index 394dc3ed8..5d6335c3d 100644 --- a/.github/workflows/pypi_install.yml +++ b/.github/workflows/pypi_install.yml @@ -1,32 +1,36 @@ -# Pipy upload straxen after a release (or manually). -## Mostly based on https://github.com/marketplace/actions/pypi-publish -name: Pipy - -on: - workflow_dispatch: - release: - types: [created] - -jobs: - build: - runs-on: ubuntu-latest - steps: - # Setup steps - - name: Setup python - uses: actions/setup-python@v2 - with: - python-version: '3.8' - - name: Checkout repo - uses: actions/checkout@v2 - - name: Install dependencies - run: pip install wheel - - name: Build package - run: python setup.py sdist bdist_wheel - # Do the publish - - name: Publish a Python distribution to PyPI - # Might want to add but does not work on workflow_dispatch : - # if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') - uses: pypa/gh-action-pypi-publish@master - with: - user: ${{ secrets.token }} - password: ${{ secrets.pypi_password }} +# Pipy upload straxen after a release (or manually). +## Mostly based on https://github.com/marketplace/actions/pypi-publish +name: Pipy + +on: + workflow_dispatch: + release: + types: [created] + +jobs: + build: + runs-on: ubuntu-latest + steps: + # Setup steps + - name: Setup python + uses: actions/setup-python@v2 + with: + python-version: '3.8' + + - name: Checkout repo + uses: actions/checkout@v2 + + - name: Install dependencies + run: pip install wheel + + - name: Build package + run: python setup.py sdist bdist_wheel + + - name: Publish a Python distribution to PyPI + # Do the publishing + # Might want to add but does not work on workflow_dispatch : + # if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') + uses: pypa/gh-action-pypi-publish@master + with: + user: ${{ secrets.token }} + password: ${{ secrets.pypi_password }} diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 877e71928..07f8cbef4 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -1,116 +1,140 @@ -# Test straxen on each PR. -# We run three types of tests: -# - Pytest -> these are the "normal" tests and should be run for all -# python versions -# - Coveralls -> this is to see if we are coverering all our lines of -# code with our tests. 
The results get uploaded to -# coveralls.io/github/XENONnT/straxen -# - pytest_no_database -> we want to make sure we can run the tests even -# if we don't have access to our datebase since this will e.g. happen -# when someone is pushing a PR from their own fork as we don't -# propagate our secrets there. - -name: Test package - -# Trigger this code when a new release is published -on: - workflow_dispatch: - release: - types: [created] - pull_request: - branches: - - master - - stable - push: - branches: - - master - -jobs: - update: - name: "${{ matrix.test }}_py${{ matrix.python-version }}" - runs-on: ubuntu-latest - strategy: - fail-fast: False - matrix: - python-version: [3.7, 3.8, 3.9] - test: ['coveralls', 'pytest', 'pytest_no_database'] - # Only run coverage / no_database on py3.8 - exclude: - - python-version: 3.7 - test: coveralls - - python-version: 3.9 - test: coveralls - - python-version: 3.7 - test: pytest_no_database - - python-version: 3.9 - test: pytest_no_database - steps: - # Setup and installation - - name: Setup python - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Checkout repo - uses: actions/checkout@v2 - - name: Remove strax from reqs. - if: matrix.test == 'coveralls' && env.HAVE_ACCESS_TO_SECTETS != null - env: - HAVE_ACCESS_TO_SECTETS: ${{ secrets.RUNDB_API_URL }} - run: | - bash .github/scripts/filter_strax_from_requirements.sh - - name: Install requirements for tests and latest strax - run: | - pip install -r extra_requirements/requirements-tests.txt - pip install git+https://github.com/AxFoundation/strax.git - - # Secrets and required files - - name: patch utilix file - # Patch this file if we want to have access to the database - if: matrix.test != 'pytest_no_database' - run: bash .github/scripts/create_readonly_utilix_config.sh - env: - # RunDB - RUNDB_API_URL: ${{ secrets.RUNDB_API_URL }} - RUNDB_API_USER_READONLY: ${{ secrets.RUNDB_API_USER_READONLY }} - RUNDB_API_PASSWORD_READONLY: ${{ secrets.RUNDB_API_PASSWORD_READONLY}} - PYMONGO_URL: ${{ secrets.PYMONGO_URL }} - PYMONGO_USER: ${{ secrets.PYMONGO_USER }} - PYMONGO_PASSWORD: ${{ secrets.PYMONGO_PASSWORD }} - PYMONGO_DATABASE: ${{ secrets.PYMONGO_DATABASE }} - # SCADA - SCADA_URL: ${{ secrets.SCADA_URL }} - SCADA_VALUE_URL: ${{ secrets.SCADA_VALUE_URL }} - SCADA_USER: ${{ secrets.SCADA_USER }} - SCADA_LOGIN_URL: ${{ secrets.SCADA_LOGIN_URL }} - SCADA_PWD: ${{ secrets.SCADA_PWD }} - - name: Create pre-apply function file - # In case we do not have database. We need to make a local file for - # The pre_apply_function (see #559). - env: - HAVE_ACCESS_TO_SECTETS: ${{ secrets.RUNDB_API_URL }} - if: env.HAVE_ACCESS_TO_SECTETS == null || matrix.test == 'pytest_no_database' - run: | - bash .github/scripts/create_pre_apply_function.sh $HOME - - # Run tests - - name: Test package - # This is running a normal test - if: matrix.test == 'pytest_no_database' || matrix.test == 'pytest' - run: | - pytest -v - - name: Coveralls - # Make the coverage report and upload - env: - NUMBA_DISABLE_JIT: 1 - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - # We need to check if we had access to the secrets, otherwise coveralls - # will yield a low coverage because of the lack of interfacing with the - # database. 
- HAVE_ACCESS_TO_SECTETS: ${{ secrets.RUNDB_API_URL }} - if: matrix.test == 'coveralls' && env.HAVE_ACCESS_TO_SECTETS != null - run: | - coverage run --source=straxen setup.py test -v - coveralls --service=github - # Done - - name: goodbye - run: echo "tests done, bye bye" +# Test straxen on each PR. +# We run three types of tests: +# - Pytest -> these are the "normal" tests and should be run for all +# python versions +# - Coveralls -> this is to see if we are covering all our lines of +# code with our tests. The results get uploaded to +# coveralls.io/github/XENONnT/straxen +# - pytest_no_database -> we want to make sure we can run the tests even +# if we don't have access to our database since this will e.g. happen +# when someone is pushing a PR from their own fork as we don't +# propagate our secrets there. + +name: Test package + +# Trigger this code when a new release is published +on: + workflow_dispatch: + release: + types: [created] + pull_request: + push: + branches: + - master + - stable + - development + +jobs: + update: + name: "${{ matrix.test }}_py${{ matrix.python-version }}" + runs-on: ubuntu-latest + strategy: + fail-fast: False + matrix: + python-version: [3.7, 3.8, 3.9, "3.10"] + test: ['coveralls', 'pytest', 'pytest_no_database'] + # Only run coverage / no_database on py3.8 + exclude: + - python-version: 3.7 + test: coveralls + - python-version: 3.9 + test: coveralls + - python-version: "3.10" + test: coveralls + - python-version: 3.7 + test: pytest_no_database + - python-version: 3.9 + test: pytest_no_database + - python-version: "3.10" + test: pytest_no_database + + steps: + # Setup and installation + - name: Checkout repo + uses: actions/checkout@v2 + + - name: Setup python + uses: actions/setup-python@v2.3.0 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: 'extra_requirements/requirements-tests.txt' + + - name: Install requirements + run: pip install -r extra_requirements/requirements-tests.txt + + - name: Install strax + run: | + git clone https://github.com/AxFoundation/strax ../strax + pip install -e ../strax + + - name: Start MongoDB + uses: supercharge/mongodb-github-action@1.7.0 + with: + mongodb-version: 4.2 + + - name: patch utilix file + # Secrets and required files + # Patch this file if we want to have access to the database + if: matrix.test != 'pytest_no_database' + run: bash .github/scripts/create_readonly_utilix_config.sh + env: + # RunDB + RUNDB_API_URL: ${{ secrets.RUNDB_API_URL }} + RUNDB_API_USER_READONLY: ${{ secrets.RUNDB_API_USER_READONLY }} + RUNDB_API_PASSWORD_READONLY: ${{ secrets.RUNDB_API_PASSWORD_READONLY}} + PYMONGO_URL: ${{ secrets.PYMONGO_URL }} + PYMONGO_USER: ${{ secrets.PYMONGO_USER }} + PYMONGO_PASSWORD: ${{ secrets.PYMONGO_PASSWORD }} + PYMONGO_DATABASE: ${{ secrets.PYMONGO_DATABASE }} + # SCADA + SCADA_URL: ${{ secrets.SCADA_URL }} + SCADA_VALUE_URL: ${{ secrets.SCADA_VALUE_URL }} + SCADA_USER: ${{ secrets.SCADA_USER }} + SCADA_LOGIN_URL: ${{ secrets.SCADA_LOGIN_URL }} + SCADA_PWD: ${{ secrets.SCADA_PWD }} + + - name: Create pre-apply function file + # In case we do not have database. We need to make a local file for + # The pre_apply_function (see #559). 
+ env: + HAVE_ACCESS_TO_SECTETS: ${{ secrets.RUNDB_API_URL }} + if: env.HAVE_ACCESS_TO_SECTETS == null || matrix.test == 'pytest_no_database' + run: bash .github/scripts/create_pre_apply_function.sh $HOME + + - name: Test package python 3.7 - 3.9 + # This is running a normal test + if: (matrix.test == 'pytest_no_database' || matrix.test == 'pytest') && matrix.python-version != '3.10' + env: + ALLOW_WFSIM_TEST: 1 + TEST_MONGO_URI: 'mongodb://localhost:27017/' + run: | + pytest -rsxv --durations 0 + + - name: Test package python 3.10 + # We cannot test with WFSim (yet) so we run this test separately without ALLOW_WFSIM_TEST + if: (matrix.test == 'pytest_no_database' || matrix.test == 'pytest') && matrix.python-version == '3.10' + env: + TEST_MONGO_URI: 'mongodb://localhost:27017/' + run: | + pytest -rsxv --durations 0 + + - name: Coveralls + # Make the coverage report and upload + env: + TEST_MONGO_URI: 'mongodb://localhost:27017/' + ALLOW_WFSIM_TEST: 1 + NUMBA_DISABLE_JIT: 1 + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + # We need to check if we have access to the secrets, otherwise coveralls + # will yield a low coverage because of the lack of interfacing with the + # database. + HAVE_ACCESS_TO_SECTETS: ${{ secrets.RUNDB_API_URL }} + + if: matrix.test == 'coveralls' + run: | + coverage run --source=straxen setup.py test -v + coveralls --service=github + + - name: goodbye + run: echo "tests done, bye bye" diff --git a/.github/workflows/test_install.yml b/.github/workflows/test_install.yml new file mode 100644 index 000000000..b388db8dc --- /dev/null +++ b/.github/workflows/test_install.yml @@ -0,0 +1,38 @@ +# Test if we can actually install strax by installing +name: Installation test + +on: + workflow_dispatch: + release: + types: [created] + pull_request: + branches: + - master + - stable + push: + branches: + - master + +jobs: + update: + name: "py${{ matrix.python-version }}" + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: [3.8] + steps: + - name: Setup python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Checkout repo + uses: actions/checkout@v2 + - name: pre-install requirements + run: pip install -r requirements.txt + - name: Install straxen + run: python setup.py install + - name: Test import + run: python -c "import straxen; straxen.print_versions()" + - name: goodbye + run: echo goodbye diff --git a/.gitignore b/.gitignore index 064b5b3c3..1316e647b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ # XENON secrets: run database passwords, S3 keys, etc. # Though these are readonly, they should not be committed to a public repo. 
xenon_secrets.py +.xenon_config # Jupyter .ipynb_checkpoints @@ -11,6 +12,9 @@ xenon_secrets.py *.npy *.blosc *.h5 +*.pdf +live_data +daq_test_data strax_data strax_test_data from_fake_daq diff --git a/.pylintrc b/.pylintrc index e9a53e17c..3b9f13e6a 100644 --- a/.pylintrc +++ b/.pylintrc @@ -4,5 +4,6 @@ # - cyclic-import (we use this all the time in strax, see __init__.py) # - no-else-return (I think this makes sense for symmetric conditions, see https://dmerej.info/blog/post/else-after-return-yea-or-nay/) # - len-as-condition (if you do 'if data' on a numpy array it will crash) +# - fixme (Useful to keep in the code here and there) disable=all -enable=assert-on-tuple,astroid-error,bad-except-order,bad-inline-option,bad-option-value,bad-reversed-sequence,bare-except,binary-op-exception,boolean-datetime,catching-non-exception,cell-var-from-loop,confusing-with-statement,consider-merging-isinstance,consider-using-enumerate,consider-using-ternary,continue-in-finally,deprecated-pragma,django-not-available,duplicate-except,duplicate-key,eval-used,exec-used,expression-not-assigned,fatal,file-ignored,fixme,global-at-module-level,global-statement,global-variable-not-assigned,global-variable-undefined,http-response-with-content-type-json,http-response-with-json-dumps,invalid-all-object,invalid-characters-in-docstring,literal-comparison,locally-disabled,locally-enabled,lost-exception,lowercase-l-suffix,misplaced-bare-raise,missing-final-newline,missing-kwoa,mixed-line-endings,model-has-unicode,model-missing-unicode,model-no-explicit-unicode,model-unicode-not-callable,multiple-imports,multiple-statements,new-db-field-with-default,no-else-raise,non-ascii-bytes-literals,nonexistent-operator,not-an-iterable,not-in-loop,notimplemented-raised,overlapping-except,parse-error,pointless-statement,pointless-string-statement,raising-bad-type,raising-non-exception,raw-checker-failed,redefine-in-handler,redefined-argument-from-local,redefined-builtin,redundant-content-type-for-json-response,reimported,relative-import,return-outside-function,simplifiable-if-statement,singleton-comparison,syntax-error,trailing-comma-tuple,trailing-newlines,unbalanced-tuple-unpacking,undefined-all-variable,undefined-loop-variable,unexpected-line-ending-format,unidiomatic-typecheck,unnecessary-lambda,unnecessary-pass,unnecessary-semicolon,unneeded-not,unpacking-non-sequence,unreachable,unrecognized-inline-option,used-before-assignment,useless-else-on-loop,using-constant-test,wildcard-import,yield-outside-function,useless-return 
+enable=assert-on-tuple,astroid-error,bad-except-order,bad-inline-option,bad-option-value,bad-reversed-sequence,bare-except,binary-op-exception,boolean-datetime,catching-non-exception,cell-var-from-loop,confusing-with-statement,consider-merging-isinstance,consider-using-enumerate,consider-using-ternary,continue-in-finally,deprecated-pragma,django-not-available,duplicate-except,duplicate-key,eval-used,exec-used,expression-not-assigned,fatal,file-ignored,global-at-module-level,global-statement,global-variable-not-assigned,global-variable-undefined,http-response-with-content-type-json,http-response-with-json-dumps,invalid-all-object,invalid-characters-in-docstring,literal-comparison,locally-disabled,locally-enabled,lost-exception,lowercase-l-suffix,misplaced-bare-raise,missing-final-newline,missing-kwoa,mixed-line-endings,model-has-unicode,model-missing-unicode,model-no-explicit-unicode,model-unicode-not-callable,multiple-imports,multiple-statements,new-db-field-with-default,no-else-raise,non-ascii-bytes-literals,nonexistent-operator,not-an-iterable,not-in-loop,notimplemented-raised,overlapping-except,parse-error,pointless-statement,pointless-string-statement,raising-bad-type,raising-non-exception,raw-checker-failed,redefine-in-handler,redefined-argument-from-local,redefined-builtin,redundant-content-type-for-json-response,reimported,relative-import,return-outside-function,simplifiable-if-statement,singleton-comparison,syntax-error,trailing-comma-tuple,trailing-newlines,unbalanced-tuple-unpacking,undefined-all-variable,undefined-loop-variable,unexpected-line-ending-format,unidiomatic-typecheck,unnecessary-lambda,unnecessary-pass,unnecessary-semicolon,unneeded-not,unpacking-non-sequence,unreachable,unrecognized-inline-option,used-before-assignment,useless-else-on-loop,using-constant-test,wildcard-import,yield-outside-function,useless-return diff --git a/.readthedocs.yml b/.readthedocs.yml index 0f97cc1c3..6115686d6 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -7,14 +7,19 @@ sphinx: configuration: docs/source/conf.py build: - image: latest + image: latest python: - version: 3.6 - install: - - method: pip - path: . - extra_requirements: - - docs - - method: setuptools - path: . + version: "3.8" + install: + - requirements: extra_requirements/requirements-tests.txt + - method: pip + path: . + extra_requirements: + - docs + - method: setuptools + path: . 
+ +formats: + - pdf + - epub diff --git a/HISTORY.md b/HISTORY.md index b01fedfaf..0a23fc7e1 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,214 @@ +1.2.6 / 2022-01-18 +------------------ +fixes/tests: +- Fix online monitor test (#882) + +notes: +- No lineage changes + +1.2.5 / 2022-01-14 +------------------ +fixes/tests: + - test with py3.10 (#878) + - remove fixme error (e0e30d94ec8f5276c581da166787db72ba0eef4a) + - bump numba (#880) + - Tests for scada interface (#877) + +notes: + - No lineage changes + +1.2.4 / 2022-01-10 +------------------ +fixes/tests: + - Fixes for WFSim <-> CMT (#865) + - Tests for WFSim contexts (#855) + +notes: + - First 1.2.X version compatible with WFSim + - No lineage changes + +1.2.3 / 2022-01-10 +------------------ +- Bump numpy (#876) + +notes: + - Incompatible with WFSim + +1.2.2 / 2022-01-10 +------------------ +tests: + - Test for Mongo-down/uploader (#859) + - Test for rucio-documents in the rundb (#858) + - Test for bokeh_utils (#857) + - Tests for common.py fix #741 (#856) + +bugfix: + - Bump peaklets version (#873) + +notes: + - Lineage change for `peaklets` (#875) + + +1.2.1 / 2021-12-27 +------------------ +fixes/tests: + - Add cmt tests and fix bug in apply_cmt_version (#860) + - Pin documentation requirements (#862) + - Add read the docs config (#861) + - Pymongo requirement should be <4.0 (#852) + +notes: + - Bug for `peaklets-uhfusstvab` due to (#875) + - No lineage changes + - Incompatible with WFSim + + +1.2.0 / 2021-12-21 +------------------- +major: + +* Update CorrectedAreas (instead of EnergyEstimates) (#817) +* S2 pattern fit (#780) +* Exclude S1 as triggering peak (#779) +* Two manual boundaries (updated 11/24/2021) (#775) +* Add main peaks' shadow for event shadow (#770) +* Events synchronize (#761) +* Implement peak-level shadow and event-level shadow refactor (#753) +* use channel tight coincidence level (#745) + +minor / patches: + +* Normalized line endings (#833) +* Fix codefactor issues (#832) +* Another try at codefactor (#831) +* URLConfig take protocol for nested keys (#826) +* Rename tight coincidence (#825) +* Move URLConfig cache to global dictionary (#822) +* Remove codefactor (#818) +* Performance update for binomial test (#783) +* URLConfig not in strax (#781) +* Add refactor event building cut (#778) +* whipe online monitor data (#777) +* Cache dependencies (#772) +* Update definition array_valued (#757) + +fixes/tests: + +* Add test for filter_kwargs (#837) +* Fix nv testing data (#830) +* Unittest for DAQreader (#828) +* Fix broken matplotlib/minianalyses (#815) +* Itp test (#813) +* Loose packaging requirement (#810) +* can we disable codefactor please (#809) +* Fix #781 (#808) +* Matplotlib changed requirements (#805) +* Pin pymongo (#801) +* Bump wfsim tests (#773) +* Patch peaks merging (#767) + +notes: + - Bug for `peaklets-uhfusstvab` due to (#875) + - plugins changed (new lineage) everything >= 'peaklet_classification' + - offline CMT versions don't work in this release + - Incompatible with WFSim + + +1.1.3 / 2021-11-19 +------------------- +minor / patches: +- Add URL based configs (#758) +- Add perpendicular wires handling info and function (#756) +- Add a few special cases event_info_double (#740) +- Process afterpulses on ebs (#727) +- Add zenodo (#742) +- Set check_broken=False for RucioFrontend.find (#749) +- Explicitly set infer_dtype=False for all Options (#750) +- Use alt z for alternative s1 binomial test (#724) + +fixes/tests: +- update docs (#743) +- Remove RuntimeError in RucioFrontend (#719) 
+- cleanup bootstrax logic for target determination (#768) +- Test installation without extra requirements (#725) +- Adding code comments for corrected z position (#763) +- Reactivate scada test (#764) +- Added resource exception for Scada (#755) +- test_widgets is broken? (#726) +- Track bokeh (#759) +- Fix keras requirement (#748) +- Update requirements-tests.txt (#739) +- Fix deprecation warning (#723) +- Update test_misc.py (90f2fc30141704158a0e297ea05679515a62b397) + +notes: + - plugins changed (new lineage) are `event_info_double` and `event_pattern_fit` + + +1.1.2 / 2021-10-27 +------------------- +minor / patches: +- Plugin for afterpulse processing (#549) +- Veto online monitor (#707) +- Refactor straxen tests (#703) +- WFSim registry as argument for simulations context (#713) +- Update S1 AFT map in event pattern fit (#697) +- Refactor s2 correction (#704) + +fixes/tests: +- Set default drift time as nan (#700) +- Revert auto inclusion of rucio remote #688 (#701) +- fix bug in CMT (#710) +- Fix one year querries (#711) +- Test new numba (#702) +- Unify CMT call in contexts (#717) +- Small codefactor patch (#714) +- test nv with nv data (#709) +- Add small test for wfsim (#716) + +notes: + - plugins changed (new lineage) are: + - `afterpulses` + - `online_monitor_nv` + - `online_monitor_mv` + - `event_pattern_fit` + - `corrected_areas` + +1.1.1 / 2021-10-19 +------------------- + - Fix to test for RunDB frontend when no test DB is sourced (6da2233) + + +1.1.0 / 2021-10-18 +------------------- +major / minor: + +- Previous S2 Shadow Plugin draft (#664) +- Use admix in straxen (#688) +- Add posdiff plugin (#669) +- updated S2 corrected area (#686) +- Version bump of hitlets (#690) +- Add n saturated channels (#691) +- add small tool to extract run comments from database (#692) +- Update online_monitor_nv to v0.0.3 (#696) + + +patches and fixes: + +- Use read by index and check for NaNs (#661) +- Add small feature for printing versions of git (#665) +- Fix minianalyses from apply_selection (#666) +- fix some warnings from testing (#667) +- Add source to runs table (#673) +- Pbar patch for rundb query (#685) +- Implement SDSC as a local RSE for Expanse (#687) +- Skips superruns in rucio frontend (#689) +- Warn about non-loadable loggers (#693) +- Add RunDb read/write-test (#695) +- Fix bug in rucio frontend (#699) + + + 1.0.0 / 2021-09-01 ------------------- major / minor: diff --git a/README.md b/README.md index e93da1196..6ed1f37ce 100644 --- a/README.md +++ b/README.md @@ -2,10 +2,12 @@ Streaming analysis for XENON(nT) [![Test package](https://github.com/XENONnT/straxen/actions/workflows/pytest.yml/badge.svg?branch=master)](https://github.com/XENONnT/straxen/actions/workflows/pytest.yml) -[![CodeFactor](https://www.codefactor.io/repository/github/xenonnt/straxen/badge)](https://www.codefactor.io/repository/github/xenonnt/straxen) [![Coverage Status](https://coveralls.io/repos/github/XENONnT/straxen/badge.svg)](https://coveralls.io/github/XENONnT/straxen) [![PyPI version shields.io](https://img.shields.io/pypi/v/straxen.svg)](https://pypi.python.org/pypi/straxen/) [![Readthedocs Badge](https://readthedocs.org/projects/straxen/badge/?version=latest)](https://straxen.readthedocs.io/en/latest/?badge=latest) +[![CodeFactor](https://www.codefactor.io/repository/github/xenonnt/straxen/badge)](https://www.codefactor.io/repository/github/xenonnt/straxen) + +[![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5576262.svg)](https://doi.org/10.5281/zenodo.5576262) 
[Straxen](https://straxen.readthedocs.io) is the analysis framework for XENONnT, built on top of the generic [strax framework](https://github.com/AxFoundation/strax). Currently it is configured for analyzing XENONnT and XENON1T data. diff --git a/bin/ajax b/bin/ajax index f43544f13..ecd882fd2 100644 --- a/bin/ajax +++ b/bin/ajax @@ -1,915 +1,915 @@ -#!/usr/bin/env python -""" -AJAX: XENON-nT -Aggregate Junking Ancient Xenon-data -cleaning tool to remove old data from event builders. -============================================= -Joran Angevaare, 2020 - -This tool keeps the event builders clean by performing any of the cleaning modes -(try 'ajax --help' for a complete description). - -Some of these modes also affect the live_data rather than the processed data. As there can -be multiple instances of ajax running there is a dedicated machine 'eb_can_clean_ceph' -that may delete the live_data such that we prevent multiple attempts to perform a single -action. -""" - -__version__ = '0.3.2' - -import argparse -from datetime import datetime, timedelta -import logging -import os -import socket -import shutil -import pymongo -import pytz -import threading -import time -import re -import numpy as np -import sys -import straxen - -## -# Parameters -## -ajax_thresholds = { - # Remove the live data only if this many seconds old - 'remove_live_after': 24 * 3600, # s - # Remove do not remove successfully processed runs if they have finished less than - # this many seconds ago to e.g. allow for saving et cetera - 'wait_after_processing': 12 * 3600, # s - # Remove high level plugins if the run is this old - # TODO - # change the value to a different one - 'remove_high_level_data': 365 * 24 * 3600, # s - # Minimum time for restarting ajax to check for new stuff to clean (essentially the - # timeout) - 'nap_time': 3600, # s - # Short nap time - 'short_nap': 5, # s -} - -# The low level data is taken care of by admix so we don't have to delete this. All other -# data-types are deleted in the clean_high_level_data routine -low_level_data = ['raw_records*', 'records*', - 'veto_regions', 'pulse_counts', 'led_calibration', - 'peaklets'] - -# Open deletion threads with this prefix (and also look for these before quiting ajax) -thread_prefix = 'ajax_delete' - -# To prevent multiple ebs from cleaning ceph only this one can actually do it -eb_can_clean_ceph = 'eb0.xenon.local' - - -## -# Main functions -## - - -def main_ajax(): - """Main function""" - if args.number: - remove_run_from_host(args.number, delete_live=args.delete_live, - force=args.force, reason='manual') - - if not args.clean: - raise ValueError('either --number or --clean should be specified') - - _modes = {'ceph': clean_ceph, - 'high_level_data': clean_high_level_data, - 'non_latest': clean_non_latest, - 'unregistered': clean_unregistered, - 'abandoned': clean_abandoned, - 'old_hash': clean_old_hash, - 'database': clean_database} - - if args.clean not in _modes and args.clean != 'all': - raise ValueError(f'Unknown cleaning mode {args.clean}') - - if args.clean != 'all': - # Do this mode - _modes[args.clean]() - if hostname == eb_can_clean_ceph and (args.clean in - ['abandoned', 'database']): - _modes[args.clean](delete_live=True) - log.info(f'Done with {args.clean}') - return - - # Apparently we want to clean everything. 
Let's loop: - while True: - try: - if hostname == eb_can_clean_ceph: - clean_ceph() - clean_abandoned(delete_live=True) - - # Don't do clean_unregistered by default (only if specified - # as specific argument on host "eb_can_clean_ceph") as - # someone might be storing some test data in the /live_data - # folder. That is bad practise though - # clean_unregistered(delete_live=True) - - clean_abandoned() - clean_non_latest() - - # These two modes below shouldn't be done on autopilot at the - # time of writing. - # clean_unregistered() - # clean_database() - if not args.execute: - break - log.info(f'Loop finished, take a {ajax_thresholds["nap_time"]} s nap') - time.sleep(ajax_thresholds['nap_time']) - except (KeyboardInterrupt, SystemExit, NotImplementedError) as e: - log.info('\nStopping, wait a second for the delete threads\n') - wait_on_delete_thread() - raise e - except Exception as fatal_error: - log.error(f'Fatal warning:\tran into {fatal_error}. Try ' - f'logging error and restart ajax') - try: - log_warning(f'Fatal warning:\tran into {fatal_error}', - priority='error') - except Exception as warning_error: - log.error(f'Fatal warning:\tcould not log {warning_error}') - # This usually only takes a minute or two - time.sleep(60) - log.warning('Restarting main loop') - - -def clean_ceph(): - """ - Look for old data on ceph. If found, delete it. Recursively iterate until no old data - is found. This function only works on the the host that is eb_can_clean_ceph. - """ - set_state('clean_ceph') - if not hostname == eb_can_clean_ceph: - log.info(f'for cleaning ceph, go to {eb_can_clean_ceph}') - return - - rd = run_coll.find_one({'bootstrax.state': 'done', - 'status': 'transferred', - 'start': - {'$lt': now(-ajax_thresholds['remove_live_after'])}, - 'bootstrax.time': - {"$lt": now(-ajax_thresholds['wait_after_processing'])}, - 'data.type': 'live'}) - if rd is None: - return - else: - run_number = rd['number'] - log.info(f'remove data associated to run {run_number}') - remove_run_from_host(run_number, delete_live=True, force=args.force, - reason='clean ceph') - log.info(f'finished for {run_number}, take a' - f' {ajax_thresholds["short_nap"]} s nap') - time.sleep(ajax_thresholds['short_nap']) - - # Repeat. - if args.execute: - wait_on_delete_thread() - clean_ceph() - - # Finally, check that there is no old data on ceph that is also not in the rundoc. - # clean_unregistered(delete_live=True) - - -def clean_high_level_data(): - """ - Check the runs database for old data on this host and clean all non-low-level - data-types from this host. - """ - set_state('clean_high_level_data') - raise NotImplementedError('This is a drastic measure that we do not ' - 'intend using at the time of writing') - - # We need to grep all the rundocs at once as simply doing one at the time (find_one) - # might get stuck as the query may yield the same result the next time it is called. 
- rds = run_coll.find({'bootstrax.state': 'done', - 'status': 'transferred', - 'start': - {"$lt": now(-ajax_thresholds['remove_high_level_data'])}, - 'data.host': hostname}) - if not rds: - # Great, we are clean - return - - for rd in rds: - # First check that we actually have raw_records stored somewhere - _, have_raw_records = check_rundoc_for_live_and_raw(rd) - if not have_raw_records: - break - - # Okay, good to go let's see if it is high level data and if so, delete it - for ddoc in rd['data']: - # Only delete data on this host - if 'host' in ddoc and ddoc['host'] == hostname: - is_low_level = re.findall('|'.join(low_level_data), ddoc['type']) - if is_low_level: - continue - loc = ddoc['location'] - if 'raw_records' in ddoc['type']: - raise ValueError('' - 'your regex syntax fails!') - elif os.path.exists(loc): - log.info(f'delete data at {loc}') - delete_data(rd, loc, ddoc['type'], test=not args.execute, - reason='high level data') - else: - loc = loc + '_temp' - log.info(f'delete data at {loc}') - delete_data(rd, loc, ddoc['type'], test=not args.execute, - reason='high level data') - - -def clean_non_latest(): - """ - Remove data on this host if the boostrax.host field is not this host while - processing has finished - """ - set_state('clean_non_latest') - # We need to grep all the rundocs at once as simply doing one at the time (find_one) - # might get stuck as the query may yield the same result the next time it is called. - rds = run_coll.find({'bootstrax.state': 'done', - 'status': 'transferred', - 'bootstrax.time': - {"$lt": now(-ajax_thresholds['wait_after_processing'])}, - 'bootstrax.host': {'$ne': hostname}, - 'data.host': hostname}) - if not rds: - # Great, we are clean - return - - for rd in rds: - # First check that we actually have raw_records stored on one of the other ebs - have_raw_records = False - for dd in rd['data']: - if (dd['type'] == 'raw_records' and - dd['host'] != hostname): - have_raw_records = True - break - if not have_raw_records: - break - - # Okay, good to go let's see if it is high level data and if so, delete it - for ddoc in rd['data']: - # Only delete data on this host - if 'host' in ddoc and ddoc['host'] == hostname: - loc = ddoc['location'] - if os.path.exists(loc): - log.info(f'clean_non_latest::\tdelete data at {loc}') - delete_data(rd, loc, ddoc['type'], test=not args.execute, - reason='non latest') - else: - loc = loc + '_temp' - log.info(f'clean_non_latest::\tdelete data at {loc}') - delete_data(rd, loc, ddoc['type'], test=not args.execute, - reason='non latest') - - -def clean_unregistered(delete_live=False): - """ - Clean data that is not in the database. To do this check the output folder and remove - files if there is no corresponding entry in the rundoc. - :param delete_live: bool, if true delete unregistered live data. - """ - set_state('clean_unregistered') - folder_to_check = output_folder if delete_live is False else ceph_folder - all_data = os.listdir(folder_to_check) - run_ids = [] - for f in all_data: - run_ids.append(f.split('-')[0]) - run_ids = np.unique(run_ids) - log.info(f'clean_unregistered::\tfound {len(run_ids)} runs stored on' - f'{folder_to_check}. 
Checking that each is in the runs-database') - for run_id in run_ids: - run_number = int(run_id) - remove_if_unregistered(run_number, delete_live=delete_live) - - -def clean_old_hash(): - """ - Loop over the files on the host and check that the lineage is the - same as the current lineage hash we would get from straxen - """ - set_state('clean_old_hash') - files = os.listdir(output_folder) - for f in files: - run_id, data_type, lineage_hash = f.split('-') - if (bool(re.findall('|'.join(low_level_data), data_type)) or - '_temp' in lineage_hash): - continue - elif data_type not in st._plugin_class_registry: - log.warning(f'{data_type} is not registered!') - continue - - current_st_hash = st.key_for(run_id, data_type).lineage_hash - if current_st_hash != lineage_hash: - loc = os.path.join(output_folder, f) - log.info(f'clean_old_hash::\tLineage for run {run_id}, {data_type} is ' - f'{current_st_hash}. Removing old hash: {lineage_hash} from {loc}.') - rd = run_coll.find_one({'bootstrax.state': 'done', - 'status': 'transferred', - 'bootstrax.time': - {"$lt": now(-ajax_thresholds['wait_after_processing'])}, - 'data.type': data_type, - 'data.host': hostname, - 'number': int(run_id)}) - if 'raw_records' in data_type: - raise ValueError(f'You did some sloppy regex on the data-type, almost ' - f'deleted {data_type}!') - if rd: - delete_data(rd, loc, data_type, test=not args.execute, - reason='old lineage hash') - else: - log.info(f'{data_type} at {loc} not registered as done for {run_id}. ' - f'Perhaps the data still processing the run?') - - -def clean_abandoned(delete_live=False): - """ - Recursively delete data associated to abandoned runs. If deleting live data, submit - multiple threads at the same time. - :param delete_live: bool, if true also delete the live_data of these runs. - """ - set_state('clean_abandoned') - # Notice that we thread the deletion of live_data. As such we need to allow multiple - # runs to be deleted simultaneously. - if not delete_live: - rd = run_coll.find_one({'bootstrax.state': 'abandoned', - 'data.host': hostname}) - if rd is None: - log.info('clean_abandoned::\tNo more matches in rundoc') - return - else: - # Make it iterable for the loop below - rds = [rd] - else: - rds = run_coll.find({'bootstrax.state': 'abandoned', - 'bootstrax.time': - {"$lt": now(-ajax_thresholds['wait_after_processing'])}, - 'data.location': '/live_data/xenonnt', - 'data.host': 'daq'}) - - # Count the number of threads. Only allow one if we aks for user input. - i = 0 - i_max = 1 if args.ask_confirm else 5 - for rd in rds: - if i >= i_max: - break - run_number = rd['number'] - log.info(f'clean_abandoned::\tremove data associated to run {run_number}') - # Please note that we have to force these runs always since they should not be - # stored elsewhere. They are abandoned for a reason! - remove_run_from_host(run_number, delete_live=delete_live, force=True, - reason='run abandonned') - i += 1 - if not i: - log.info('clean_abandoned::\tNo more live_data matches in rundoc') - # Apparently there is no rd in rds - return - - # Repeat - if args.execute: - if delete_live: - wait_on_delete_thread() - return clean_abandoned(delete_live=delete_live) - - -def clean_database(delete_live=False): - """ - Remove entries from the database if the data is not on this host - """ - set_state('clean_database') - # We need to grep all the rundocs at once as simply doing one at the - # time (find_one) might get stuck as the query may yield the same - # result the next time it is called. 
- if not delete_live: - rds = run_coll.find({'bootstrax.state': 'done', - 'bootstrax.time': - {"$lt": now(-ajax_thresholds['wait_after_processing'])}, - 'data.host': hostname}) - else: - rds = run_coll.find({'bootstrax.state': 'done', - 'bootstrax.time': - {"$lt": now(-ajax_thresholds['wait_after_processing'])}, - 'data.host': 'daq'}) - if not rds: - # Great, we are clean - return - - for rd in rds: - # Okay, good to go let's see if it is high level data and if so, delete it - for ddoc in rd['data']: - # Only delete data on this host - if (('host' in ddoc and ddoc['host'] == hostname) or - ('host' in ddoc and ddoc['host'] == 'daq' and delete_live)): - loc = ddoc['location'] - if not os.path.exists(loc): - log.info(f'clean_database::\tdelete entry of data from ' - f'{rd["number"]} at {loc} as it does not exist') - delete_data(rd, loc, ddoc['type'], test=not args.execute, - reason='reason already removed from eb?!') - - -## -# Core functions -## - - -def _rmtree(path): - """ - Wrapper for shutil.rmtree. All deletion statements in this script go - through this function in order to make sure that the - args.execute statement is always double (tripple) checked - before deleting data. - :param path: path to delete - :return: - """ - if args.execute: - if confirm(f'delete {path}?'): - shutil.rmtree(path) - else: - log.info(f'TESTING:\tshutil.rmtree({path})') - if not os.path.exists(path): - raise ValueError(f'{path} does not exist') - - -def threaded_delete_data(rd, path, data_type, - test=True, ignore_low_data_check=False, - reason=''): - """ - Wrapper for delete_data to run in separate threads. - :param rd: rundoc - :param path: location of the folder to be deleted - :param data_type: type of data to be deleted - :param test: bool if we are testing or not. If true, nothing will be deleted. - :param ignore_low_data_check: ignore the fact that this might be the only copy of the - data. We can specify this e.g. if we know this is data associated to some abandoned - run. - """ - - thread_name = thread_prefix + path.split('/')[-1] - delete_thread = threading.Thread(name=thread_name, - target=delete_data, - args=(rd, path, data_type, - test, ignore_low_data_check), - kwargs={'reason': reason}) - log.info(f'Starting thread to delete {path} at {now()}') - # We rather not stop deleting the live_data if something else fails. Set the thread - # to daemon. - delete_thread.setDaemon(True) - delete_thread.start() - log.info(f'DeleteThread {path} should be running in parallel, continue MainThread ' - f'now: {now()}') - - -def delete_data(rd, path, data_type, - test=True, - ignore_low_data_check=False, - reason='', - ): - """ - Delete data and update the rundoc - :param rd: rundoc - :param path: location of the folder to be deleted - :param data_type: type of data to be deleted - :param test: bool if we are testing or not. If true, nothing will - be deleted. - :param ignore_low_data_check: ignore the fact that this might be - the only copy of the data. We can specify this e.g. if we know - this is data associated to some abandoned run. - """ - # First check that we are not deleting essential data (unless we are working - # with abandoned runs or so. - if not ignore_low_data_check: - n_live, n_rr = check_rundoc_for_live_and_raw(rd) - if not ( - n_live or - n_rr >= 1 + int('raw_records' in data_type)): - message = (f'Trying to delete {data_type} but we only have {n_live}' - f' live- and {n_rr} raw_record-files in the ' - f'runs-database. 
This might be an essential copy of the' - f' data!') - log_warning(message, priority='fatal', run_id=f'{rd["number"]:06}') - if not test: - raise ValueError(message) - - if os.path.exists(path): - log.info(f'Deleting data at {path}') - if not test: - _rmtree(path) - log.info(f'deleting {path} finished') - else: - log.info(f'There is no data on {path}! Just doing the rundoc.') - - # Remove the data location from the rundoc and append it to the 'deleted_data' entries - if not os.path.exists(path): - log.info('changing data field in rundoc') - for ddoc in rd['data']: - if ddoc['type'] == data_type and ddoc['host'] in ('daq', hostname): - break - for k in ddoc.copy().keys(): - if k in ['location', 'meta', 'protocol']: - ddoc.pop(k) - - ddoc.update({'at': now(), 'by': f'ajax.{hostname}', 'reason': reason}) - log.info(f'update with {ddoc}') - if args.execute and not test: - if confirm('update rundoc?'): - run_coll.update_one({'_id': rd['_id']}, - {"$addToSet": {'deleted_data': ddoc}, - "$pull": {"data": - {"type": data_type, - "host": {'$in': ['daq', hostname]}}}}) - else: - log.info(f'Update ddoc with : {ddoc}') - elif not test and not args.ask_confirm: - raise ValueError(f"Something went wrong we wanted to delete {path}!") - - -def check_rundoc_for_live_and_raw(rd): - """ - Count the number of files of live_data (cannot be >1) and raw_data (transfers can get - it greater than 1) - :param rd: rundoc - :return: length 2 tuple with n_live_data and n_raw_records being the number of files - in the rundoc for the live data and raw records respectively. - """ - n_live_data, n_raw_records = 0, 0 - for dd in rd['data']: - if dd['type'] == 'live': - n_live_data += 1 - if dd['type'] == 'raw_records': - n_raw_records += 1 - return n_live_data, n_raw_records - - -def remove_run_from_host(number, delete_live=False, force=False, reason=''): - """ - Save way of removing data from host if data registered elsewhere - :param number: run number (not ID!) - :param delete_live: bool, if true delete the live_data else the processed data - :param force: forcefully remove the data even if we don't have the right copies (e.g. - deleting /live_data when the raw_records are not stored. Be careful with this option! - Should only be used for the deletion of abandoned runs. - """ - # Query the database to remove data - rd = run_coll.find_one({'number': number, - 'data.host': hostname if not delete_live else 'daq'}) - if not rd: - log_warning(f'No registered data for {number} on {hostname}', - run_id=f'{number:06}', priority='info') - return - - have_live_data, have_raw_records = check_rundoc_for_live_and_raw(rd) - - for ddoc in rd['data']: - # This is processed data on the eventbuilders - if 'host' in ddoc and ddoc['host'] == hostname: - if delete_live: - # If you want to delete the live data you shouldn't consider this ddoc - continue - loc = ddoc['location'] - if not force and not have_live_data and 'raw_records' in ddoc['type']: - # If we do not have live_data, don't delete raw_records. However, if we - # --force deletion, do go to the next else statement - log.info(f'prevent {loc} from being deleted. 
The live_data has already' - f' been removed') - else: - log.info(f'delete data at {loc}') - delete_data(rd, loc, ddoc['type'], test=not args.execute, - ignore_low_data_check=force, reason=reason) - - loc = loc + '_temp' - if os.path.exists(loc): - log.info(f'delete data at {loc}') - delete_data(rd, loc, ddoc['type'], test=not args.execute, - ignore_low_data_check=force, reason=reason) - elif 'host' in ddoc and ddoc['host'] == 'daq': - # This is the live_data - if not delete_live: - log.info(f'prevent {ddoc["location"]} from being deleted. Do so with --delete_live') - # If you want to delete processed data you shouldn't consider this ddoc - continue - - run_id = '%06d' % number - loc = os.path.join(ddoc['location'], run_id) - if not force and not have_raw_records: - # If we do not have raw records, don't delete this data. However, if we - # --force deletion, do go to the next else statement - log_warning( - f'Unsafe to delete {loc}, no raw_records registered. Force with ' - f'--force', priority='info') - elif not force and have_raw_records: - log.info(f'Deleting {loc} since we have raw_records registered.') - threaded_delete_data(rd, loc, ddoc['type'], test=not args.execute, - ignore_low_data_check=False, reason=reason) - # Redundant elif but let's double check the force nonetheless. - elif force: - log.info(f'Forcefully delete {loc}, but no raw_records registered!') - threaded_delete_data(rd, loc, ddoc['type'], test=not args.execute, - ignore_low_data_check=True, reason=reason) - - -def remove_if_unregistered(number, delete_live=False): - """ - Given a run number delete data on this machine that matches that number - :param number: int! the run_number (not run_id) - :param delete_live: Bool, if True: Remove the live_data. Else remove processed data - """ - # Query the database to remove data - # WANRING! If we don't find something we WILL remove the data! Don't make the query - # specific! We only check if any of the data is on this host (i.e. not get None from - # the query below) - rd = run_coll.find_one({'number': number, - 'data.host': hostname if not delete_live else 'daq'}) - run_id = '%06d' % number - - if rd: - # Just for explicitness, this is where we want to end up. If we have a rundoc, - # the data is registered and we don't have to do anything. - return - else: - log_warning(f'remove_if_unregistered::\trun {number} is NOT registered ' - f'in the runDB but is stored on {hostname}', - run_id=run_id, - priority='error') - if not delete_live: - # Check the local ebs disk for data. - _remove_unregistered_run(output_folder, run_id, checked_db=True) - else: - # Check ceph for data associated to this run (which is apparently not in the - # runDB) - _remove_unregistered_run(ceph_folder, run_id, checked_db=True) - - -def _remove_unregistered_run(base_folder, run_id, checked_db=False): - """ - NB: The check that this run is not registered should be performed first! - Deletes any folder from base_folder that matches run_id. - :param base_folder: folder to check - :param run_id: run_id to remove from folder - :param checked_db: Bool if it is checked that this run in not in the database. - """ - if not checked_db: - log_warning(f'remove_if_unregistered::\trogue ajax operations! Trying ' - f'to delete {run_id} from {hostname}', - run_id=run_id, - priority='fatal') - raise ValueError("Only insert runs where for it is checked that it is " - "not registered in the runs database and double checked.") - log_warning(f'No data for {run_id} found! 
Double checking {base_folder}!', - run_id=run_id, priority='warning') - deleted_data = False - - for folder in os.listdir(base_folder): - if run_id in folder: - log.info(f'Cleaning {base_folder + folder}') - - # Do a final check if we are not deleting essential data! - # Do not disable this check! If you don't like it: make a smarter query - rd = run_coll.find_one({'number': int(run_id)}) - n_live, n_rr = check_rundoc_for_live_and_raw(rd) - - if not (n_live or n_rr >= 1 + int('raw_records' in folder)): - message = (f'Trying to delete {folder} but we only have ' - f'{n_live} live- and {n_rr} raw_record-files in the ' - f'runs-database. This might be an essential copy of ' - f'the data!') - log_warning(message, run_id=run_id, priority='fatal') - raise ValueError(message) - - # OK, we still have live_data somewhere or we have raw_records (elsewhere) - if args.execute: - # Double check returns True automatically if not args.ask_confirm - if confirm( - f'Should we really move {os.path.join(base_folder, folder)} ' - f'to {os.path.join(non_registered_folder, folder)}?'): - shutil.move(os.path.join(base_folder, folder), - os.path.join(non_registered_folder, folder)) - else: - log.info(f'TEST\tmoving {base_folder + folder}') - deleted_data = True - - if not deleted_data: - message = f'No data registered on {hostname} for {run_id}' - log_warning(message, priority='fatal') - raise FileNotFoundError(message) - - -## -# Helper functions -## - - -def now(plus=0): - """UTC timestamp""" - return datetime.now(pytz.utc) + timedelta(seconds=plus) - - -def confirm(question): - """ - If --ask_confirm is specified, ask user to confirm to proceed. - :return: bool - """ - if not args.ask_confirm: - return True - answer = str(input(question + ' (y/n): \n')).lower().strip() - if answer in ('y', 'n'): - return answer == 'y' - else: - confirm('please input (y/n)\n') - - -def wait_on_delete_thread(): - """Check that the threads with the thread_prefix are finished before continuing.""" - threads = threading.enumerate() - for thread in threads: - if thread_prefix in thread.name: - wait = True - while wait: - wait = False - if thread.isAlive(): - log.info(f'{thread.name} still running take a ' - f'{ajax_thresholds["short_nap"]} s nap') - time.sleep(ajax_thresholds['short_nap']) - wait = True - log.info(f'wait_on_delete_thread::\tChecked that all {thread_prefix}* finished') - - -def log_warning(message, priority='warning', run_id=None): - """Report a warning to the terminal (using the logging module) - and the DAQ log DB. - :param message: insert string into log_coll - :param priority: severity of warning. Can be: - info: 1, - warning: 2, - : 3 - :param run_id: optional run id. - """ - if not args.execute: - return - getattr(log, priority)(message) - # Log according to redax rules - # https://github.com/coderdj/redax/blob/master/MongoLog.hh#L22 - warning_message = { - 'message': message, - 'user': f'ajax_{hostname}', - 'priority': - dict(debug=0, - info=1, - warning=2, - error=3, - fatal=4, - ).get(priority.lower(), 3)} - if run_id is not None: - warning_message.update({'runid': run_id}) - if args.execute: - # Only upload warning if we would actually execute the script - log_coll.insert_one(warning_message) - log.info(message) - - -def set_state(state): - """Inform the bootstrax collection we're in a different state - - if state is None, leave state unchanged, just update heartbeat time - """ - ajax_state = dict( - host='ajax.' 
+ hostname, - pid=os.getpid(), - time=now(), - state=state, - mode=f'clean {args.clean}', - production_mode=args.execute - ) - bs_coll.insert_one(ajax_state) - - -## -# Main -## - - -if __name__ == '__main__': - print(f'---\n ajax version {__version__}\n---') - - parser = argparse.ArgumentParser( - description="XENONnT cleaning manager") - parser.add_argument('--force', action='store_true', - help="Forcefully remove stuff from this host") - parser.add_argument('--ask_confirm', action='store_true', - help="Always ask for confirmation before deleting data/updating" - " the rundoc") - parser.add_argument('--delete_live', action='store_true', - help="delete live data for this run") - parser.add_argument('--execute', action='store_true', - help="Execute the deletion commands. If not specified, ajax " - "assumes you want to test") - parser.add_argument('--logging', default='DEBUG', - help="logging level (DEBUG/INFO/WARNING)") - - actions = parser.add_mutually_exclusive_group() - actions.add_argument('--number', type=int, metavar='NUMBER', - help="Process a single run, regardless of its status.") - actions.add_argument('--clean', type=str, help= - 'Run ajax in any of the following modes: clean [ceph, ' # noqa - 'unregistered, abandoned, high_level_data, all]\n' - '"ceph": Remove successfully processed runs and abandoned runs from /live_data\n' - '"unregistered": remove all data from this host that is not registered in the rundb\n' - '"abandoned": remove all the data associated to abandoned runs\n' - '"high_level_data": remove all high level data on this host\n' - '"non_latest": remove data on this host if it was not the last to process a given run \n' - '"database": check if all the entries that the database claims are actually here \n' - '"old_hash": remove high level data if the hash doesnt equal the latest for this datatype \n' - '"all": Clean everything that AJAX can get its hands on: unregistered data, high level ' - 'data') - - args = parser.parse_args() - hostname = socket.getfqdn() - - if not hasattr(logging, args.logging): - raise AttributeError(f'Set --logging to a logging level like DEBUG or INFO') - logging_level = getattr(logging, args.logging) - try: - import daqnt - log_name = 'ajax_' + hostname + ('' if args.execute else '_TESTING') - log = daqnt.get_daq_logger('main', log_name, level=logging_level) - except ModuleNotFoundError: - logging.basicConfig( - level=logging_level, - format='%(asctime)s %(name)s %(levelname)-8s %(message)s', - datefmt='%m-%d %H:%M') - log = logging.getLogger() - - if args.delete_live: - if not args.number: - raise ValueError("Specify which number with --number") - if args.force: - log.warning(f'main::\tDANGER ZONE you are forcefully deleting data that may ' - f'result in an irrecoverable loss of data.') - log.info(f'main::\tPlease note that execute argument is {args.execute} which ' - f'means you are {"" if not args.execute else "!NOT!"} safe') - if not args.ask_confirm: - raise NotImplementedError( - f'main::\tI cannot let your forcefully delete data without asking for ' - f'confirmation. Add --ask_confirm. Bye, bye') - - if not input('Want to proceed willingly? 
[y]').lower() == 'y': - log.info(f'main::\tAlright no unsafe operations, bye bye') - exit(-1) - if args.clean != 'abandoned' and not args.number: - raise NotImplementedError( - "main::\tI don't want to have this option enabled (yet).") - - if args.clean in ['all', 'old_hash']: - # Need the context for the latest hash - st = straxen.contexts.xenonnt_online() - - # Set the folders - ceph_folder = '/live_data/xenonnt/' - output_folder = '/data/xenonnt_processed/' - non_registered_folder = '/data/xenonnt_unregistered/' - for f in (output_folder, non_registered_folder): - if os.access(f, os.W_OK) is not True: - log_warning(f'main::\tNo writing access to {f}', priority='fatal') - raise IOError(f'main::\tNo writing access to {f}') - - # Runs database - run_collname = 'runs' - run_dbname = straxen.uconfig.get('rundb_admin', 'mongo_rdb_database') - run_uri = straxen.get_mongo_uri(header='rundb_admin', - user_key='mongo_rdb_username', - pwd_key='mongo_rdb_password', - url_key='mongo_rdb_url') - run_client = pymongo.MongoClient(run_uri) - run_db = run_client[run_dbname] - run_coll = run_db[run_collname] - run_db.command('ping') - - # DAQ database - daq_db_name = 'daq' - daq_uri = straxen.get_mongo_uri(header='rundb_admin', - user_key='mongo_daq_username', - pwd_key='mongo_daq_password', - url_key='mongo_daq_url') - daq_client = pymongo.MongoClient(daq_uri) - daq_db = daq_client[daq_db_name] - log_coll = daq_db['log'] - bs_coll = daq_db['eb_monitor'] - daq_db.command('ping') - - try: - set_state('idle') - main_ajax() - except (KeyboardInterrupt, SystemExit) as e: - log.info('\nStopping, wait a second for the delete threads\n') - wait_on_delete_thread() - raise e - - wait_on_delete_thread() - log.info(f'main::\tAjax finished, bye bye') +#!/usr/bin/env python +""" +AJAX: XENON-nT +Aggregate Junking Ancient Xenon-data +cleaning tool to remove old data from event builders. +============================================= +Joran Angevaare, 2020 + +This tool keeps the event builders clean by performing any of the cleaning modes +(try 'ajax --help' for a complete description). + +Some of these modes also affect the live_data rather than the processed data. As there can +be multiple instances of ajax running there is a dedicated machine 'eb_can_clean_ceph' +that may delete the live_data such that we prevent multiple attempts to perform a single +action. +""" + +__version__ = '0.3.2' + +import argparse +from datetime import datetime, timedelta +import logging +import os +import socket +import shutil +import pymongo +import pytz +import threading +import time +import re +import numpy as np +import sys +import straxen + +## +# Parameters +## +ajax_thresholds = { + # Remove the live data only if this many seconds old + 'remove_live_after': 24 * 3600, # s + # Remove do not remove successfully processed runs if they have finished less than + # this many seconds ago to e.g. allow for saving et cetera + 'wait_after_processing': 12 * 3600, # s + # Remove high level plugins if the run is this old + # TODO + # change the value to a different one + 'remove_high_level_data': 365 * 24 * 3600, # s + # Minimum time for restarting ajax to check for new stuff to clean (essentially the + # timeout) + 'nap_time': 3600, # s + # Short nap time + 'short_nap': 5, # s +} + +# The low level data is taken care of by admix so we don't have to delete this. 
All other +# data-types are deleted in the clean_high_level_data routine +low_level_data = ['raw_records*', 'records*', + 'veto_regions', 'pulse_counts', 'led_calibration', + 'peaklets'] + +# Open deletion threads with this prefix (and also look for these before quiting ajax) +thread_prefix = 'ajax_delete' + +# To prevent multiple ebs from cleaning ceph only this one can actually do it +eb_can_clean_ceph = 'eb0.xenon.local' + + +## +# Main functions +## + + +def main_ajax(): + """Main function""" + if args.number: + remove_run_from_host(args.number, delete_live=args.delete_live, + force=args.force, reason='manual') + + if not args.clean: + raise ValueError('either --number or --clean should be specified') + + _modes = {'ceph': clean_ceph, + 'high_level_data': clean_high_level_data, + 'non_latest': clean_non_latest, + 'unregistered': clean_unregistered, + 'abandoned': clean_abandoned, + 'old_hash': clean_old_hash, + 'database': clean_database} + + if args.clean not in _modes and args.clean != 'all': + raise ValueError(f'Unknown cleaning mode {args.clean}') + + if args.clean != 'all': + # Do this mode + _modes[args.clean]() + if hostname == eb_can_clean_ceph and (args.clean in + ['abandoned', 'database']): + _modes[args.clean](delete_live=True) + log.info(f'Done with {args.clean}') + return + + # Apparently we want to clean everything. Let's loop: + while True: + try: + if hostname == eb_can_clean_ceph: + clean_ceph() + clean_abandoned(delete_live=True) + + # Don't do clean_unregistered by default (only if specified + # as specific argument on host "eb_can_clean_ceph") as + # someone might be storing some test data in the /live_data + # folder. That is bad practise though + # clean_unregistered(delete_live=True) + + clean_abandoned() + clean_non_latest() + + # These two modes below shouldn't be done on autopilot at the + # time of writing. + # clean_unregistered() + # clean_database() + if not args.execute: + break + log.info(f'Loop finished, take a {ajax_thresholds["nap_time"]} s nap') + time.sleep(ajax_thresholds['nap_time']) + except (KeyboardInterrupt, SystemExit, NotImplementedError) as e: + log.info('\nStopping, wait a second for the delete threads\n') + wait_on_delete_thread() + raise e + except Exception as fatal_error: + log.error(f'Fatal warning:\tran into {fatal_error}. Try ' + f'logging error and restart ajax') + try: + log_warning(f'Fatal warning:\tran into {fatal_error}', + priority='error') + except Exception as warning_error: + log.error(f'Fatal warning:\tcould not log {warning_error}') + # This usually only takes a minute or two + time.sleep(60) + log.warning('Restarting main loop') + + +def clean_ceph(): + """ + Look for old data on ceph. If found, delete it. Recursively iterate until no old data + is found. This function only works on the the host that is eb_can_clean_ceph. 
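+    Only runs that bootstrax marked as 'done', whose status is 'transferred',
+    that started more than 'remove_live_after' seconds ago and whose processing
+    finished more than 'wait_after_processing' seconds ago are considered, and
+    only if they still have 'live' data registered (see the query below).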
+ """ + set_state('clean_ceph') + if not hostname == eb_can_clean_ceph: + log.info(f'for cleaning ceph, go to {eb_can_clean_ceph}') + return + + rd = run_coll.find_one({'bootstrax.state': 'done', + 'status': 'transferred', + 'start': + {'$lt': now(-ajax_thresholds['remove_live_after'])}, + 'bootstrax.time': + {"$lt": now(-ajax_thresholds['wait_after_processing'])}, + 'data.type': 'live'}) + if rd is None: + return + else: + run_number = rd['number'] + log.info(f'remove data associated to run {run_number}') + remove_run_from_host(run_number, delete_live=True, force=args.force, + reason='clean ceph') + log.info(f'finished for {run_number}, take a' + f' {ajax_thresholds["short_nap"]} s nap') + time.sleep(ajax_thresholds['short_nap']) + + # Repeat. + if args.execute: + wait_on_delete_thread() + clean_ceph() + + # Finally, check that there is no old data on ceph that is also not in the rundoc. + # clean_unregistered(delete_live=True) + + +def clean_high_level_data(): + """ + Check the runs database for old data on this host and clean all non-low-level + data-types from this host. + """ + set_state('clean_high_level_data') + raise NotImplementedError('This is a drastic measure that we do not ' + 'intend using at the time of writing') + + # We need to grep all the rundocs at once as simply doing one at the time (find_one) + # might get stuck as the query may yield the same result the next time it is called. + rds = run_coll.find({'bootstrax.state': 'done', + 'status': 'transferred', + 'start': + {"$lt": now(-ajax_thresholds['remove_high_level_data'])}, + 'data.host': hostname}) + if not rds: + # Great, we are clean + return + + for rd in rds: + # First check that we actually have raw_records stored somewhere + _, have_raw_records = check_rundoc_for_live_and_raw(rd) + if not have_raw_records: + break + + # Okay, good to go let's see if it is high level data and if so, delete it + for ddoc in rd['data']: + # Only delete data on this host + if 'host' in ddoc and ddoc['host'] == hostname: + is_low_level = re.findall('|'.join(low_level_data), ddoc['type']) + if is_low_level: + continue + loc = ddoc['location'] + if 'raw_records' in ddoc['type']: + raise ValueError('' + 'your regex syntax fails!') + elif os.path.exists(loc): + log.info(f'delete data at {loc}') + delete_data(rd, loc, ddoc['type'], test=not args.execute, + reason='high level data') + else: + loc = loc + '_temp' + log.info(f'delete data at {loc}') + delete_data(rd, loc, ddoc['type'], test=not args.execute, + reason='high level data') + + +def clean_non_latest(): + """ + Remove data on this host if the boostrax.host field is not this host while + processing has finished + """ + set_state('clean_non_latest') + # We need to grep all the rundocs at once as simply doing one at the time (find_one) + # might get stuck as the query may yield the same result the next time it is called. 
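+    # The query below therefore selects runs that are 'done' and 'transferred',
+    # whose processing finished more than 'wait_after_processing' seconds ago,
+    # that were last processed by another eventbuilder (bootstrax.host != this
+    # host), but that still register data on this host.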
+ rds = run_coll.find({'bootstrax.state': 'done', + 'status': 'transferred', + 'bootstrax.time': + {"$lt": now(-ajax_thresholds['wait_after_processing'])}, + 'bootstrax.host': {'$ne': hostname}, + 'data.host': hostname}) + if not rds: + # Great, we are clean + return + + for rd in rds: + # First check that we actually have raw_records stored on one of the other ebs + have_raw_records = False + for dd in rd['data']: + if (dd['type'] == 'raw_records' and + dd['host'] != hostname): + have_raw_records = True + break + if not have_raw_records: + break + + # Okay, good to go let's see if it is high level data and if so, delete it + for ddoc in rd['data']: + # Only delete data on this host + if 'host' in ddoc and ddoc['host'] == hostname: + loc = ddoc['location'] + if os.path.exists(loc): + log.info(f'clean_non_latest::\tdelete data at {loc}') + delete_data(rd, loc, ddoc['type'], test=not args.execute, + reason='non latest') + else: + loc = loc + '_temp' + log.info(f'clean_non_latest::\tdelete data at {loc}') + delete_data(rd, loc, ddoc['type'], test=not args.execute, + reason='non latest') + + +def clean_unregistered(delete_live=False): + """ + Clean data that is not in the database. To do this check the output folder and remove + files if there is no corresponding entry in the rundoc. + :param delete_live: bool, if true delete unregistered live data. + """ + set_state('clean_unregistered') + folder_to_check = output_folder if delete_live is False else ceph_folder + all_data = os.listdir(folder_to_check) + run_ids = [] + for f in all_data: + run_ids.append(f.split('-')[0]) + run_ids = np.unique(run_ids) + log.info(f'clean_unregistered::\tfound {len(run_ids)} runs stored on' + f'{folder_to_check}. Checking that each is in the runs-database') + for run_id in run_ids: + run_number = int(run_id) + remove_if_unregistered(run_number, delete_live=delete_live) + + +def clean_old_hash(): + """ + Loop over the files on the host and check that the lineage is the + same as the current lineage hash we would get from straxen + """ + set_state('clean_old_hash') + files = os.listdir(output_folder) + for f in files: + run_id, data_type, lineage_hash = f.split('-') + if (bool(re.findall('|'.join(low_level_data), data_type)) or + '_temp' in lineage_hash): + continue + elif data_type not in st._plugin_class_registry: + log.warning(f'{data_type} is not registered!') + continue + + current_st_hash = st.key_for(run_id, data_type).lineage_hash + if current_st_hash != lineage_hash: + loc = os.path.join(output_folder, f) + log.info(f'clean_old_hash::\tLineage for run {run_id}, {data_type} is ' + f'{current_st_hash}. Removing old hash: {lineage_hash} from {loc}.') + rd = run_coll.find_one({'bootstrax.state': 'done', + 'status': 'transferred', + 'bootstrax.time': + {"$lt": now(-ajax_thresholds['wait_after_processing'])}, + 'data.type': data_type, + 'data.host': hostname, + 'number': int(run_id)}) + if 'raw_records' in data_type: + raise ValueError(f'You did some sloppy regex on the data-type, almost ' + f'deleted {data_type}!') + if rd: + delete_data(rd, loc, data_type, test=not args.execute, + reason='old lineage hash') + else: + log.info(f'{data_type} at {loc} not registered as done for {run_id}. ' + f'Perhaps the data still processing the run?') + + +def clean_abandoned(delete_live=False): + """ + Recursively delete data associated to abandoned runs. If deleting live data, submit + multiple threads at the same time. + :param delete_live: bool, if true also delete the live_data of these runs. 
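+    Abandoned runs are always removed with force=True since they should not be
+    stored anywhere else.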
+ """ + set_state('clean_abandoned') + # Notice that we thread the deletion of live_data. As such we need to allow multiple + # runs to be deleted simultaneously. + if not delete_live: + rd = run_coll.find_one({'bootstrax.state': 'abandoned', + 'data.host': hostname}) + if rd is None: + log.info('clean_abandoned::\tNo more matches in rundoc') + return + else: + # Make it iterable for the loop below + rds = [rd] + else: + rds = run_coll.find({'bootstrax.state': 'abandoned', + 'bootstrax.time': + {"$lt": now(-ajax_thresholds['wait_after_processing'])}, + 'data.location': '/live_data/xenonnt', + 'data.host': 'daq'}) + + # Count the number of threads. Only allow one if we aks for user input. + i = 0 + i_max = 1 if args.ask_confirm else 5 + for rd in rds: + if i >= i_max: + break + run_number = rd['number'] + log.info(f'clean_abandoned::\tremove data associated to run {run_number}') + # Please note that we have to force these runs always since they should not be + # stored elsewhere. They are abandoned for a reason! + remove_run_from_host(run_number, delete_live=delete_live, force=True, + reason='run abandonned') + i += 1 + if not i: + log.info('clean_abandoned::\tNo more live_data matches in rundoc') + # Apparently there is no rd in rds + return + + # Repeat + if args.execute: + if delete_live: + wait_on_delete_thread() + return clean_abandoned(delete_live=delete_live) + + +def clean_database(delete_live=False): + """ + Remove entries from the database if the data is not on this host + """ + set_state('clean_database') + # We need to grep all the rundocs at once as simply doing one at the + # time (find_one) might get stuck as the query may yield the same + # result the next time it is called. + if not delete_live: + rds = run_coll.find({'bootstrax.state': 'done', + 'bootstrax.time': + {"$lt": now(-ajax_thresholds['wait_after_processing'])}, + 'data.host': hostname}) + else: + rds = run_coll.find({'bootstrax.state': 'done', + 'bootstrax.time': + {"$lt": now(-ajax_thresholds['wait_after_processing'])}, + 'data.host': 'daq'}) + if not rds: + # Great, we are clean + return + + for rd in rds: + # Okay, good to go let's see if it is high level data and if so, delete it + for ddoc in rd['data']: + # Only delete data on this host + if (('host' in ddoc and ddoc['host'] == hostname) or + ('host' in ddoc and ddoc['host'] == 'daq' and delete_live)): + loc = ddoc['location'] + if not os.path.exists(loc): + log.info(f'clean_database::\tdelete entry of data from ' + f'{rd["number"]} at {loc} as it does not exist') + delete_data(rd, loc, ddoc['type'], test=not args.execute, + reason='reason already removed from eb?!') + + +## +# Core functions +## + + +def _rmtree(path): + """ + Wrapper for shutil.rmtree. All deletion statements in this script go + through this function in order to make sure that the + args.execute statement is always double (tripple) checked + before deleting data. + :param path: path to delete + :return: + """ + if args.execute: + if confirm(f'delete {path}?'): + shutil.rmtree(path) + else: + log.info(f'TESTING:\tshutil.rmtree({path})') + if not os.path.exists(path): + raise ValueError(f'{path} does not exist') + + +def threaded_delete_data(rd, path, data_type, + test=True, ignore_low_data_check=False, + reason=''): + """ + Wrapper for delete_data to run in separate threads. + :param rd: rundoc + :param path: location of the folder to be deleted + :param data_type: type of data to be deleted + :param test: bool if we are testing or not. If true, nothing will be deleted. 
+ :param ignore_low_data_check: ignore the fact that this might be the only copy of the + data. We can specify this e.g. if we know this is data associated to some abandoned + run. + """ + + thread_name = thread_prefix + path.split('/')[-1] + delete_thread = threading.Thread(name=thread_name, + target=delete_data, + args=(rd, path, data_type, + test, ignore_low_data_check), + kwargs={'reason': reason}) + log.info(f'Starting thread to delete {path} at {now()}') + # We rather not stop deleting the live_data if something else fails. Set the thread + # to daemon. + delete_thread.setDaemon(True) + delete_thread.start() + log.info(f'DeleteThread {path} should be running in parallel, continue MainThread ' + f'now: {now()}') + + +def delete_data(rd, path, data_type, + test=True, + ignore_low_data_check=False, + reason='', + ): + """ + Delete data and update the rundoc + :param rd: rundoc + :param path: location of the folder to be deleted + :param data_type: type of data to be deleted + :param test: bool if we are testing or not. If true, nothing will + be deleted. + :param ignore_low_data_check: ignore the fact that this might be + the only copy of the data. We can specify this e.g. if we know + this is data associated to some abandoned run. + """ + # First check that we are not deleting essential data (unless we are working + # with abandoned runs or so. + if not ignore_low_data_check: + n_live, n_rr = check_rundoc_for_live_and_raw(rd) + if not ( + n_live or + n_rr >= 1 + int('raw_records' in data_type)): + message = (f'Trying to delete {data_type} but we only have {n_live}' + f' live- and {n_rr} raw_record-files in the ' + f'runs-database. This might be an essential copy of the' + f' data!') + log_warning(message, priority='fatal', run_id=f'{rd["number"]:06}') + if not test: + raise ValueError(message) + + if os.path.exists(path): + log.info(f'Deleting data at {path}') + if not test: + _rmtree(path) + log.info(f'deleting {path} finished') + else: + log.info(f'There is no data on {path}! Just doing the rundoc.') + + # Remove the data location from the rundoc and append it to the 'deleted_data' entries + if not os.path.exists(path): + log.info('changing data field in rundoc') + for ddoc in rd['data']: + if ddoc['type'] == data_type and ddoc['host'] in ('daq', hostname): + break + for k in ddoc.copy().keys(): + if k in ['location', 'meta', 'protocol']: + ddoc.pop(k) + + ddoc.update({'at': now(), 'by': f'ajax.{hostname}', 'reason': reason}) + log.info(f'update with {ddoc}') + if args.execute and not test: + if confirm('update rundoc?'): + run_coll.update_one({'_id': rd['_id']}, + {"$addToSet": {'deleted_data': ddoc}, + "$pull": {"data": + {"type": data_type, + "host": {'$in': ['daq', hostname]}}}}) + else: + log.info(f'Update ddoc with : {ddoc}') + elif not test and not args.ask_confirm: + raise ValueError(f"Something went wrong we wanted to delete {path}!") + + +def check_rundoc_for_live_and_raw(rd): + """ + Count the number of files of live_data (cannot be >1) and raw_data (transfers can get + it greater than 1) + :param rd: rundoc + :return: length 2 tuple with n_live_data and n_raw_records being the number of files + in the rundoc for the live data and raw records respectively. 
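+    For example, a rundoc whose 'data' field holds one 'live' entry and two
+    'raw_records' entries yields (1, 2).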
+ """ + n_live_data, n_raw_records = 0, 0 + for dd in rd['data']: + if dd['type'] == 'live': + n_live_data += 1 + if dd['type'] == 'raw_records': + n_raw_records += 1 + return n_live_data, n_raw_records + + +def remove_run_from_host(number, delete_live=False, force=False, reason=''): + """ + Save way of removing data from host if data registered elsewhere + :param number: run number (not ID!) + :param delete_live: bool, if true delete the live_data else the processed data + :param force: forcefully remove the data even if we don't have the right copies (e.g. + deleting /live_data when the raw_records are not stored. Be careful with this option! + Should only be used for the deletion of abandoned runs. + """ + # Query the database to remove data + rd = run_coll.find_one({'number': number, + 'data.host': hostname if not delete_live else 'daq'}) + if not rd: + log_warning(f'No registered data for {number} on {hostname}', + run_id=f'{number:06}', priority='info') + return + + have_live_data, have_raw_records = check_rundoc_for_live_and_raw(rd) + + for ddoc in rd['data']: + # This is processed data on the eventbuilders + if 'host' in ddoc and ddoc['host'] == hostname: + if delete_live: + # If you want to delete the live data you shouldn't consider this ddoc + continue + loc = ddoc['location'] + if not force and not have_live_data and 'raw_records' in ddoc['type']: + # If we do not have live_data, don't delete raw_records. However, if we + # --force deletion, do go to the next else statement + log.info(f'prevent {loc} from being deleted. The live_data has already' + f' been removed') + else: + log.info(f'delete data at {loc}') + delete_data(rd, loc, ddoc['type'], test=not args.execute, + ignore_low_data_check=force, reason=reason) + + loc = loc + '_temp' + if os.path.exists(loc): + log.info(f'delete data at {loc}') + delete_data(rd, loc, ddoc['type'], test=not args.execute, + ignore_low_data_check=force, reason=reason) + elif 'host' in ddoc and ddoc['host'] == 'daq': + # This is the live_data + if not delete_live: + log.info(f'prevent {ddoc["location"]} from being deleted. Do so with --delete_live') + # If you want to delete processed data you shouldn't consider this ddoc + continue + + run_id = '%06d' % number + loc = os.path.join(ddoc['location'], run_id) + if not force and not have_raw_records: + # If we do not have raw records, don't delete this data. However, if we + # --force deletion, do go to the next else statement + log_warning( + f'Unsafe to delete {loc}, no raw_records registered. Force with ' + f'--force', priority='info') + elif not force and have_raw_records: + log.info(f'Deleting {loc} since we have raw_records registered.') + threaded_delete_data(rd, loc, ddoc['type'], test=not args.execute, + ignore_low_data_check=False, reason=reason) + # Redundant elif but let's double check the force nonetheless. + elif force: + log.info(f'Forcefully delete {loc}, but no raw_records registered!') + threaded_delete_data(rd, loc, ddoc['type'], test=not args.execute, + ignore_low_data_check=True, reason=reason) + + +def remove_if_unregistered(number, delete_live=False): + """ + Given a run number delete data on this machine that matches that number + :param number: int! the run_number (not run_id) + :param delete_live: Bool, if True: Remove the live_data. Else remove processed data + """ + # Query the database to remove data + # WANRING! If we don't find something we WILL remove the data! Don't make the query + # specific! We only check if any of the data is on this host (i.e. 
not get None from + # the query below) + rd = run_coll.find_one({'number': number, + 'data.host': hostname if not delete_live else 'daq'}) + run_id = '%06d' % number + + if rd: + # Just for explicitness, this is where we want to end up. If we have a rundoc, + # the data is registered and we don't have to do anything. + return + else: + log_warning(f'remove_if_unregistered::\trun {number} is NOT registered ' + f'in the runDB but is stored on {hostname}', + run_id=run_id, + priority='error') + if not delete_live: + # Check the local ebs disk for data. + _remove_unregistered_run(output_folder, run_id, checked_db=True) + else: + # Check ceph for data associated to this run (which is apparently not in the + # runDB) + _remove_unregistered_run(ceph_folder, run_id, checked_db=True) + + +def _remove_unregistered_run(base_folder, run_id, checked_db=False): + """ + NB: The check that this run is not registered should be performed first! + Deletes any folder from base_folder that matches run_id. + :param base_folder: folder to check + :param run_id: run_id to remove from folder + :param checked_db: Bool if it is checked that this run in not in the database. + """ + if not checked_db: + log_warning(f'remove_if_unregistered::\trogue ajax operations! Trying ' + f'to delete {run_id} from {hostname}', + run_id=run_id, + priority='fatal') + raise ValueError("Only insert runs where for it is checked that it is " + "not registered in the runs database and double checked.") + log_warning(f'No data for {run_id} found! Double checking {base_folder}!', + run_id=run_id, priority='warning') + deleted_data = False + + for folder in os.listdir(base_folder): + if run_id in folder: + log.info(f'Cleaning {base_folder + folder}') + + # Do a final check if we are not deleting essential data! + # Do not disable this check! If you don't like it: make a smarter query + rd = run_coll.find_one({'number': int(run_id)}) + n_live, n_rr = check_rundoc_for_live_and_raw(rd) + + if not (n_live or n_rr >= 1 + int('raw_records' in folder)): + message = (f'Trying to delete {folder} but we only have ' + f'{n_live} live- and {n_rr} raw_record-files in the ' + f'runs-database. This might be an essential copy of ' + f'the data!') + log_warning(message, run_id=run_id, priority='fatal') + raise ValueError(message) + + # OK, we still have live_data somewhere or we have raw_records (elsewhere) + if args.execute: + # Double check returns True automatically if not args.ask_confirm + if confirm( + f'Should we really move {os.path.join(base_folder, folder)} ' + f'to {os.path.join(non_registered_folder, folder)}?'): + shutil.move(os.path.join(base_folder, folder), + os.path.join(non_registered_folder, folder)) + else: + log.info(f'TEST\tmoving {base_folder + folder}') + deleted_data = True + + if not deleted_data: + message = f'No data registered on {hostname} for {run_id}' + log_warning(message, priority='fatal') + raise FileNotFoundError(message) + + +## +# Helper functions +## + + +def now(plus=0): + """UTC timestamp""" + return datetime.now(pytz.utc) + timedelta(seconds=plus) + + +def confirm(question): + """ + If --ask_confirm is specified, ask user to confirm to proceed. 
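+    :param question: question to prompt the user with.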
+ :return: bool + """ + if not args.ask_confirm: + return True + answer = str(input(question + ' (y/n): \n')).lower().strip() + if answer in ('y', 'n'): + return answer == 'y' + else: + confirm('please input (y/n)\n') + + +def wait_on_delete_thread(): + """Check that the threads with the thread_prefix are finished before continuing.""" + threads = threading.enumerate() + for thread in threads: + if thread_prefix in thread.name: + wait = True + while wait: + wait = False + if thread.isAlive(): + log.info(f'{thread.name} still running take a ' + f'{ajax_thresholds["short_nap"]} s nap') + time.sleep(ajax_thresholds['short_nap']) + wait = True + log.info(f'wait_on_delete_thread::\tChecked that all {thread_prefix}* finished') + + +def log_warning(message, priority='warning', run_id=None): + """Report a warning to the terminal (using the logging module) + and the DAQ log DB. + :param message: insert string into log_coll + :param priority: severity of warning. Can be: + info: 1, + warning: 2, + : 3 + :param run_id: optional run id. + """ + if not args.execute: + return + getattr(log, priority)(message) + # Log according to redax rules + # https://github.com/coderdj/redax/blob/master/MongoLog.hh#L22 + warning_message = { + 'message': message, + 'user': f'ajax_{hostname}', + 'priority': + dict(debug=0, + info=1, + warning=2, + error=3, + fatal=4, + ).get(priority.lower(), 3)} + if run_id is not None: + warning_message.update({'runid': run_id}) + if args.execute: + # Only upload warning if we would actually execute the script + log_coll.insert_one(warning_message) + log.info(message) + + +def set_state(state): + """Inform the bootstrax collection we're in a different state + + if state is None, leave state unchanged, just update heartbeat time + """ + ajax_state = dict( + host='ajax.' + hostname, + pid=os.getpid(), + time=now(), + state=state, + mode=f'clean {args.clean}', + production_mode=args.execute + ) + bs_coll.insert_one(ajax_state) + + +## +# Main +## + + +if __name__ == '__main__': + print(f'---\n ajax version {__version__}\n---') + + parser = argparse.ArgumentParser( + description="XENONnT cleaning manager") + parser.add_argument('--force', action='store_true', + help="Forcefully remove stuff from this host") + parser.add_argument('--ask_confirm', action='store_true', + help="Always ask for confirmation before deleting data/updating" + " the rundoc") + parser.add_argument('--delete_live', action='store_true', + help="delete live data for this run") + parser.add_argument('--execute', action='store_true', + help="Execute the deletion commands. 
If not specified, ajax " + "assumes you want to test") + parser.add_argument('--logging', default='DEBUG', + help="logging level (DEBUG/INFO/WARNING)") + + actions = parser.add_mutually_exclusive_group() + actions.add_argument('--number', type=int, metavar='NUMBER', + help="Process a single run, regardless of its status.") + actions.add_argument('--clean', type=str, help= + 'Run ajax in any of the following modes: clean [ceph, ' # noqa + 'unregistered, abandoned, high_level_data, all]\n' + '"ceph": Remove successfully processed runs and abandoned runs from /live_data\n' + '"unregistered": remove all data from this host that is not registered in the rundb\n' + '"abandoned": remove all the data associated to abandoned runs\n' + '"high_level_data": remove all high level data on this host\n' + '"non_latest": remove data on this host if it was not the last to process a given run \n' + '"database": check if all the entries that the database claims are actually here \n' + '"old_hash": remove high level data if the hash doesnt equal the latest for this datatype \n' + '"all": Clean everything that AJAX can get its hands on: unregistered data, high level ' + 'data') + + args = parser.parse_args() + hostname = socket.getfqdn() + + if not hasattr(logging, args.logging): + raise AttributeError(f'Set --logging to a logging level like DEBUG or INFO') + logging_level = getattr(logging, args.logging) + try: + import daqnt + log_name = 'ajax_' + hostname + ('' if args.execute else '_TESTING') + log = daqnt.get_daq_logger('main', log_name, level=logging_level) + except ModuleNotFoundError: + logging.basicConfig( + level=logging_level, + format='%(asctime)s %(name)s %(levelname)-8s %(message)s', + datefmt='%m-%d %H:%M') + log = logging.getLogger() + + if args.delete_live: + if not args.number: + raise ValueError("Specify which number with --number") + if args.force: + log.warning(f'main::\tDANGER ZONE you are forcefully deleting data that may ' + f'result in an irrecoverable loss of data.') + log.info(f'main::\tPlease note that execute argument is {args.execute} which ' + f'means you are {"" if not args.execute else "!NOT!"} safe') + if not args.ask_confirm: + raise NotImplementedError( + f'main::\tI cannot let your forcefully delete data without asking for ' + f'confirmation. Add --ask_confirm. Bye, bye') + + if not input('Want to proceed willingly? 
[y]').lower() == 'y': + log.info(f'main::\tAlright no unsafe operations, bye bye') + exit(-1) + if args.clean != 'abandoned' and not args.number: + raise NotImplementedError( + "main::\tI don't want to have this option enabled (yet).") + + if args.clean in ['all', 'old_hash']: + # Need the context for the latest hash + st = straxen.contexts.xenonnt_online() + + # Set the folders + ceph_folder = '/live_data/xenonnt/' + output_folder = '/data/xenonnt_processed/' + non_registered_folder = '/data/xenonnt_unregistered/' + for f in (output_folder, non_registered_folder): + if os.access(f, os.W_OK) is not True: + log_warning(f'main::\tNo writing access to {f}', priority='fatal') + raise IOError(f'main::\tNo writing access to {f}') + + # Runs database + run_collname = 'runs' + run_dbname = straxen.uconfig.get('rundb_admin', 'mongo_rdb_database') + run_uri = straxen.get_mongo_uri(header='rundb_admin', + user_key='mongo_rdb_username', + pwd_key='mongo_rdb_password', + url_key='mongo_rdb_url') + run_client = pymongo.MongoClient(run_uri) + run_db = run_client[run_dbname] + run_coll = run_db[run_collname] + run_db.command('ping') + + # DAQ database + daq_db_name = 'daq' + daq_uri = straxen.get_mongo_uri(header='rundb_admin', + user_key='mongo_daq_username', + pwd_key='mongo_daq_password', + url_key='mongo_daq_url') + daq_client = pymongo.MongoClient(daq_uri) + daq_db = daq_client[daq_db_name] + log_coll = daq_db['log'] + bs_coll = daq_db['eb_monitor'] + daq_db.command('ping') + + try: + set_state('idle') + main_ajax() + except (KeyboardInterrupt, SystemExit) as e: + log.info('\nStopping, wait a second for the delete threads\n') + wait_on_delete_thread() + raise e + + wait_on_delete_thread() + log.info(f'main::\tAjax finished, bye bye') diff --git a/bin/bootstrax b/bin/bootstrax index cae457f36..26f976ff6 100755 --- a/bin/bootstrax +++ b/bin/bootstrax @@ -15,7 +15,7 @@ How to use For more info, see the documentation: https://straxen.readthedocs.io/en/latest/bootstrax.html """ -__version__ = '1.1.8' +__version__ = '1.2.1' import argparse from datetime import datetime, timedelta, timezone @@ -38,6 +38,8 @@ import straxen import threading import pandas as pd import typing as ty +import daqnt +import fnmatch parser = argparse.ArgumentParser( description="XENONnT online processing manager") @@ -81,6 +83,7 @@ parser.add_argument( '--max_messages', type=int, default=10, help="number of max mailbox messages") + actions = parser.add_mutually_exclusive_group() actions.add_argument( '--process', type=int, metavar='NUMBER', @@ -202,32 +205,35 @@ max_queue_new_runs = 2 # Remove any targets or post processing targets after the run failed # this many times. If high level data is hitting some edge case, we # might want to be able to keep the intermediate level data. -# NB: string match so events applies e.g. to event_basics, peak to e.g. peaklets -remove_target_after_fails = dict(events=2, - hitlets=2, - peaks=4, - peak_basics=5, - peaklets=6, - online=5, - ALL=7) +# NB: fnmatch so event* applies e.g. to event_basics, peak* to e.g. peaklets +remove_target_after_fails = { + 'event*': 2, + 'hitlets*': 2, + 'online_monitor_*v': 3, + 'online_peak': 5, + 'peaks*': 4, + 'peak_basics': 5, + 'peaklets': 6, + 'veto_*': 4, + '*': 7, +} ## # Initialize globals (e.g. 
rundb connection) ## hostname = socket.getfqdn() -try: - import daqnt - log_name = 'bootstrax_' + hostname + ('' if args.production else '_TESTING') - log = daqnt.get_daq_logger(log_name, log_name, level=logging.DEBUG) -except ModuleNotFoundError: - logging.basicConfig( - level=logging.INFO, - format='%(asctime)s %(name)s %(levelname)-8s %(message)s', - datefmt='%m-%d %H:%M') - log = logging.getLogger() +log_name = 'bootstrax_' + hostname + ('' if args.production else '_TESTING') +log = daqnt.get_daq_logger(log_name, log_name, level=logging.DEBUG) log.info(f'---\n bootstrax version {__version__}\n---') +log.info( + straxen.print_versions( + modules='strax straxen utilix daqnt numpy tensorflow numba'.split(), + include_git=True, + return_string=True, + )) + # Set the output folder output_folder = '/data/xenonnt_processed/' if args.production else test_data_folder @@ -255,7 +261,7 @@ else: if os.access(output_folder, os.W_OK) is not True: message = f'No writing access to {output_folder}' - log.warning(message, priority='fatal') + log.warning(message) raise IOError(message) @@ -273,7 +279,6 @@ def new_context(cores=args.cores, allow_multiprocess=cores > 1, allow_shm=cores > 1, allow_lazy=False, - use_rucio=False, max_messages=max_messages, timeout=timeout, _rucio_path=None, @@ -465,12 +470,12 @@ def _remove_veto_from_t(targets: ty.Union[str, list, tuple], remove: ty.Union[str, list, tuple] = ('_mv', '_nv'), _flip: bool = False) -> ty.Union[str, list, tuple, None]: """Remove veto(s) from targets""" - start = targets + start = strax.to_str_tuple(targets) remove = strax.to_str_tuple(remove) if targets is None: return None for r in remove: - targets = [t for t in strax.to_str_tuple(targets) if (r not in t)] + targets = keep_target(targets, {f'*{r}': 0}, 1) if _flip: targets = [t for i, t in enumerate(start) if not np.in1d(start, targets)[i]] return strax.to_str_tuple(targets) @@ -486,6 +491,24 @@ def _keep_veto_from_t(targets: ty.Union[str, list, tuple], return targets +def keep_target(targets, compare_with, n_fails): + kept_targets = [] + delete_after = -1 # just to make logging never fail below + for target_name in strax.to_str_tuple(targets): + for delete_target, delete_after in compare_with.items(): + failed_too_much = n_fails > delete_after + name_matches = fnmatch.fnmatch(target_name, delete_target) + if failed_too_much and name_matches: + log.warning(f'remove {target_name} ({n_fails}>{delete_after})') + break + else: + log.debug(f'keep {target_name} ({n_fails}!>{delete_after})') + kept_targets.append(target_name) + if not len(kept_targets): + kept_targets = ['raw_records'] + return kept_targets + + def infer_target(rd: dict) -> dict: """ Check if the target should be overridden based on the mode of the DAQ for this run @@ -503,28 +526,15 @@ def infer_target(rd: dict) -> dict: if n_fails: log.debug(f'Deleting targets') - for rm_target, rm_after_n in remove_target_after_fails.items(): - for target_i, targ in enumerate(targets): - if rm_target in targ and n_fails > rm_after_n: - log.debug(f'Remove {targ}') - del targets[target_i] - else: - log.debug(f'Keep {targ} for not in {rm_target} ({n_fails}<{rm_after_n})') - for post_i, post_targ in enumerate(post_process): - if post_targ in rm_target and n_fails > rm_after_n: - log.debug(f'Remove {post_targ}') - del post_process[post_i] + targets = keep_target(targets, remove_target_after_fails, n_fails) + post_process = keep_target(post_process, remove_target_after_fails, n_fails) log.debug(f'{targets} and {post_process} remaining') - if n_fails >= 
remove_target_after_fails['ALL']: - log.debug(f'Overwriting to raw-records only') - return {'targets': strax.to_str_tuple('raw_records'), - 'post_processing': strax.to_str_tuple('raw_records')} - # Special modes override target for these led_modes = ['pmtgain'] - diagnostic_modes = ['exttrig', 'noise', 'pmtap'] + diagnostic_modes = ['exttrig', 'noise'] + ap_modes = ['pmtap'] mode = str(rd.get('mode')) detectors = list(rd.get('detectors')) @@ -533,15 +543,22 @@ def infer_target(rd: dict) -> dict: log.debug('led-mode') targets = 'led_calibration' post_process = 'raw_records' + elif np.any([m in mode for m in ap_modes]): + log.debug('afterpulse mode') + targets = 'afterpulses' + post_process = 'raw_records' elif np.any([m in mode for m in diagnostic_modes]): log.debug('diagnostic-mode') targets = 'raw_records' post_process = 'raw_records' - elif 'kr83m' in mode and len(targets): + elif 'kr83m' in mode and (len(targets) or len(post_process)): # Override the first (highest level) plugin for Kr runs (could # also use source field, outcome is the same) - if targets[0] == 'event_info': - targets[0] = 'event_info_double' + if 'event_info' in targets or 'event_info' in post_process: + targets = list(targets) + ['event_info_double'] + + targets = strax.to_str_tuple(targets) + post_process = strax.to_str_tuple(post_process) if 'tpc' not in detectors: keep = [] @@ -574,7 +591,7 @@ def infer_target(rd: dict) -> dict: for check in (targets, post_process): if not len(set(check)) == len(check): log_warning(f'Duplicates in (post) targets {check}', - priority='Fatal') + priority='fatal') raise ValueError(f'Duplicates in (post) targets {check}') return {'targets': targets, 'post_processing': post_process} @@ -638,7 +655,7 @@ def log_warning(message, priority='warning', run_id=None): """ if not args.production: return - getattr(log, priority)(message) + getattr(log, priority.lower())(message) # Log according to redax rules # https://github.com/coderdj/redax/blob/master/MongoLog.hh#L22 warning_message = { @@ -1560,6 +1577,9 @@ def clean_run(*, mongo_id=None, number=None, force=False): log.info(f'delete data at {loc}') _delete_data(rd, loc, ddoc['type']) + # Also wipe the online_monitor if there is any + run_db['online_monitor'].delete_many({'number': int(rd['number'])}) + def clean_run_test_data(run_id): """ @@ -1642,7 +1662,7 @@ def cleanup_db(): 'bootstrax.state': 'done', 'start': {'$gt': now(-timeouts['abandoning_allowed'])}}, "Run has an 'abandon' tag"), - ({'tags.name': 'abandon', + ({'tags.name': 'abandon', 'bootstrax.state': 'failed'}, "Run has an 'abandon' tag and was failing"), ] diff --git a/docs/make.bat b/docs/make.bat index 8ac3ffd08..9cbabefd5 100644 --- a/docs/make.bat +++ b/docs/make.bat @@ -1,36 +1,36 @@ -@ECHO OFF - -pushd %~dp0 - -REM Command file for Sphinx documentation - -if "%SPHINXBUILD%" == "" ( - set SPHINXBUILD=sphinx-build -) -set SOURCEDIR=source -set BUILDDIR=build -set SPHINXPROJ=straxen - -if "%1" == "" goto help - -%SPHINXBUILD% >NUL 2>NUL -if errorlevel 9009 ( - echo. - echo.The 'sphinx-build' command was not found. Make sure you have Sphinx - echo.installed, then set the SPHINXBUILD environment variable to point - echo.to the full path of the 'sphinx-build' executable. Alternatively you - echo.may add the Sphinx directory to PATH. - echo. 
- echo.If you don't have Sphinx installed, grab it from - echo.http://sphinx-doc.org/ - exit /b 1 -) - -%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% -goto end - -:help -%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% - -:end -popd +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build +set SPHINXPROJ=straxen + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% + +:end +popd diff --git a/docs/source/bootstrax.rst b/docs/source/bootstrax.rst index f028e5c78..4bb152bcf 100644 --- a/docs/source/bootstrax.rst +++ b/docs/source/bootstrax.rst @@ -1,71 +1,71 @@ -Bootstrax: XENONnT online processing manager -============================================= -The ``bootstrax`` script watches for new runs to appear from the DAQ, then starts a -strax process to process them. If a run fails, it will retry it with -exponential backoff, each time waiting a little longer before retying. -After 10 failures, ``bootstrax`` stops trying to reprocess a run. -Additionally, every new time it is restarted it tries to process fewer plugins. -After a certain number of tries, it only reprocesses the raw-records. -Therefore a run that may fail at first may successfully be processed later. For example, if - -You can run more than one ``bootstrax`` instance, but only one per machine. -If you start a second one on the same machine, it will try to kill the -first one. - - -Philosophy ----------------- -Bootstrax has a crash-only / recovery first philosophy. Any error in -the core code causes a crash; there is no nice exit or mandatory -cleanup. Bootstrax focuses on recovery after restarts: before starting -work, we look for and fix any mess left by crashes. - -This ensures that hangs and hard crashes do not require expert tinkering -to repair databases. Plus, you can just stop the program with ctrl-c -(or, in principle, pulling the machine's power plug) at any time. - -Errors during run processing are assumed to be retry-able. We track the -number of failures per run to decide how long to wait until we retry; -only if a user marks a run as 'abandoned' (using an external system, -e.g. the website) do we stop retrying. - - -Mongo documents ----------------- -Bootstrax records its status in a document in the '``bootstrax``' collection -in the runs db. These documents contain: - - - **host**: socket.getfqdn() - - **time**: last time this ``bootstrax`` showed life signs - - **state**: one of the following: - - **busy**: doing something - - **idle**: NOT doing something; available for processing new runs - -Additionally, ``bootstrax`` tracks information with each run in the -'``bootstrax``' field of the run doc. We could also put this elsewhere, but -it seemed convenient. 
This field contains the following subfields: - - - **state**: one of the following: - - **considering**: a ``bootstrax`` is deciding what to do with it - - **busy**: a strax process is working on it - - **failed**: something is wrong, but we will retry after some amount of time. - - **abandoned**: ``bootstrax`` will ignore this run - - **reason**: reason for last failure, if there ever was one (otherwise this field - does not exists). Thus, it's quite possible for this field to exist (and - show an exception) when the state is ``'done'``: that just means it failed - at least once but succeeded later. Tracking failure history is primarily - the DAQ log's responsibility; this message is only provided for convenience. - - **n_failures**: number of failures on this run, if there ever was one - (otherwise this field does not exist). - - **next_retry**: time after which ``bootstrax`` might retry processing this run. - Like 'reason', this will refer to the last failure. - -Finally, ``bootstrax`` outputs the load on the eventbuilder machine(s) -whereon it is running to a collection in the DAQ database into the -capped collection 'eb_monitor'. This collection contains information on -what ``bootstrax`` is thinking of at the moment. - - - **disk_used**: used part of the disk whereto this ``bootstrax`` instance - is writing to (in percent). - +Bootstrax: XENONnT online processing manager +============================================= +The ``bootstrax`` script watches for new runs to appear from the DAQ, then starts a +strax process to process them. If a run fails, it will retry it with +exponential backoff, each time waiting a little longer before retying. +After 10 failures, ``bootstrax`` stops trying to reprocess a run. +Additionally, every new time it is restarted it tries to process fewer plugins. +After a certain number of tries, it only reprocesses the raw-records. +Therefore a run that may fail at first may successfully be processed later. For example, if + +You can run more than one ``bootstrax`` instance, but only one per machine. +If you start a second one on the same machine, it will try to kill the +first one. + + +Philosophy +---------------- +Bootstrax has a crash-only / recovery first philosophy. Any error in +the core code causes a crash; there is no nice exit or mandatory +cleanup. Bootstrax focuses on recovery after restarts: before starting +work, we look for and fix any mess left by crashes. + +This ensures that hangs and hard crashes do not require expert tinkering +to repair databases. Plus, you can just stop the program with ctrl-c +(or, in principle, pulling the machine's power plug) at any time. + +Errors during run processing are assumed to be retry-able. We track the +number of failures per run to decide how long to wait until we retry; +only if a user marks a run as 'abandoned' (using an external system, +e.g. the website) do we stop retrying. + + +Mongo documents +---------------- +Bootstrax records its status in a document in the '``bootstrax``' collection +in the runs db. These documents contain: + + - **host**: socket.getfqdn() + - **time**: last time this ``bootstrax`` showed life signs + - **state**: one of the following: + - **busy**: doing something + - **idle**: NOT doing something; available for processing new runs + +Additionally, ``bootstrax`` tracks information with each run in the +'``bootstrax``' field of the run doc. We could also put this elsewhere, but +it seemed convenient. 
This field contains the following subfields: + + - **state**: one of the following: + - **considering**: a ``bootstrax`` is deciding what to do with it + - **busy**: a strax process is working on it + - **failed**: something is wrong, but we will retry after some amount of time. + - **abandoned**: ``bootstrax`` will ignore this run + - **reason**: reason for last failure, if there ever was one (otherwise this field + does not exists). Thus, it's quite possible for this field to exist (and + show an exception) when the state is ``'done'``: that just means it failed + at least once but succeeded later. Tracking failure history is primarily + the DAQ log's responsibility; this message is only provided for convenience. + - **n_failures**: number of failures on this run, if there ever was one + (otherwise this field does not exist). + - **next_retry**: time after which ``bootstrax`` might retry processing this run. + Like 'reason', this will refer to the last failure. + +Finally, ``bootstrax`` outputs the load on the eventbuilder machine(s) +whereon it is running to a collection in the DAQ database into the +capped collection 'eb_monitor'. This collection contains information on +what ``bootstrax`` is thinking of at the moment. + + - **disk_used**: used part of the disk whereto this ``bootstrax`` instance + is writing to (in percent). + *Last updated 2021-05-07. Joran Angevaare* \ No newline at end of file diff --git a/docs/source/build_context_doc.py b/docs/source/build_context_doc.py new file mode 100644 index 000000000..943a26da7 --- /dev/null +++ b/docs/source/build_context_doc.py @@ -0,0 +1,35 @@ +import os +this_dir = os.path.dirname(os.path.realpath(__file__)) + + +base_doc = """ +======== +Contexts +======== +The contexts are a class from strax and used everywhere in straxen + +Below, all of the contexts functions are shown including the +`minianalyses `_ + +Contexts documentation +---------------------- +Auto generated documention of all the context functions including minianalyses + + +.. automodule:: strax.context + :members: + :undoc-members: + :show-inheritance: + +""" + + +def main(): + """Maybe we one day want to expend this, but for now, let's start with this""" + out = base_doc + with open(this_dir + f'/reference/context.rst', mode='w') as f: + f.write(out) + + +if __name__ == '__main__': + main() diff --git a/docs/source/cmt.rst b/docs/source/cmt.rst index 032f072c9..22671e3e1 100644 --- a/docs/source/cmt.rst +++ b/docs/source/cmt.rst @@ -1,20 +1,20 @@ -Corrections Management Tool (CMT) -================================= -Corrections Management Tool (CMT) is a centralized tool that allows to store, query and retrieve information about detector effects (corrections) where later the information is used at the event building process to remove (correct) such effects for a given data type. -In specific CMT is a class within `strax `_, the information is stored in MongoDB as document with a ``pandas.DataFrame()`` format and with a ``pandas.DatetimeIndex()`` this allows track time-dependent information as often detector conditions change over time. CMT also adds the functionality to differentiate between ONLINE and OFFLINE versioning, where ONLINE corrections are used during online processing and ,therefore, changes in the past are not allowed and OFFLINE version meant to be used for re-processing where changes in the past are allow. 
- - -CMT in straxen --------------- -A customized CMT can be implemented given the experiment software, in the case of straxen, experiment specifics can be added to CMT. To set CMT accordingly to straxen a class `CorrectionsManagementService() `_ allows the user to query and retrieve information. This class uses the start time of a given run to find the corresponding information and version for a given correction. For every correction user must set the proper configuration in order to retrieve the information, the syntax is the following ``my_configuration = (“my_correction”, “version”, True)`` the first part correspond to the string of the correction, then the version, it can be either an ONLINE version or OFFLINE version and finally the boolean correspond to the detector configuration (1T or nT). -In the case of straxen there are several plug-ins that call CMT to retrieve information, in that case, the configuration option is set by the ``strax.option()`` and the information is retrieve in `set()` via the function `straxen.get_correction_from_cmt()` and example is shown below where the electron life time is retrieve for a particular run ID, using the ONLINE version for the detector configuration nT=True. - - -.. code-block:: python - - import straxen - elife_conf = ("elife", "ONLINE", True) - elife = straxen.get_correction_from_cmt(run_id, elife_conf) - - -An experiment specific option that provide the ability to do bookkeeping for the different versions is the introduction of the concept of global versions, global version means a unique set of corrections, e.g. ``global_v3={elife[v2], s2_map[v3], s1_map[v3], etc}``. This is specially useful for the creation of different context where the user has can set all the corresponding configuration using a global version via ``apply_cmt_context()``. However the user must be aware that only local version are allow for individual configurations from straxen prior to ``0.19.0`` the user had only the option to use global version. +Corrections Management Tool (CMT) +================================= +Corrections Management Tool (CMT) is a centralized tool that allows to store, query and retrieve information about detector effects (corrections) where later the information is used at the event building process to remove (correct) such effects for a given data type. +In specific CMT is a class within `strax `_, the information is stored in MongoDB as document with a ``pandas.DataFrame()`` format and with a ``pandas.DatetimeIndex()`` this allows track time-dependent information as often detector conditions change over time. CMT also adds the functionality to differentiate between ONLINE and OFFLINE versioning, where ONLINE corrections are used during online processing and ,therefore, changes in the past are not allowed and OFFLINE version meant to be used for re-processing where changes in the past are allow. + + +CMT in straxen +-------------- +A customized CMT can be implemented given the experiment software, in the case of straxen, experiment specifics can be added to CMT. To set CMT accordingly to straxen a class `CorrectionsManagementService() `_ allows the user to query and retrieve information. This class uses the start time of a given run to find the corresponding information and version for a given correction. 
For every correction user must set the proper configuration in order to retrieve the information, the syntax is the following ``my_configuration = (“my_correction”, “version”, True)`` the first part correspond to the string of the correction, then the version, it can be either an ONLINE version or OFFLINE version and finally the boolean correspond to the detector configuration (1T or nT). +In the case of straxen there are several plug-ins that call CMT to retrieve information, in that case, the configuration option is set by the ``strax.option()`` and the information is retrieve in `set()` via the function `straxen.get_correction_from_cmt()` and example is shown below where the electron life time is retrieve for a particular run ID, using the ONLINE version for the detector configuration nT=True. + + +.. code-block:: python + + import straxen + elife_conf = ("elife", "ONLINE", True) + elife = straxen.get_correction_from_cmt(run_id, elife_conf) + + +An experiment specific option that provide the ability to do bookkeeping for the different versions is the introduction of the concept of global versions, global version means a unique set of corrections, e.g. ``global_v3={elife[v2], s2_map[v3], s1_map[v3], etc}``. This is specially useful for the creation of different context where the user has can set all the corresponding configuration using a global version via ``apply_cmt_context()``. However the user must be aware that only local version are allow for individual configurations from straxen prior to ``0.19.0`` the user had only the option to use global version. diff --git a/docs/source/conf.py b/docs/source/conf.py index 05b4405d5..0ec480732 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -196,3 +196,5 @@ def setup(app): import build_datastructure_doc build_datastructure_doc.build_datastructure_doc(True) build_datastructure_doc.build_datastructure_doc(False) + import build_context_doc + build_context_doc.main() diff --git a/docs/source/config_storage.rst b/docs/source/config_storage.rst index afc5e7cb6..21a59b1b4 100644 --- a/docs/source/config_storage.rst +++ b/docs/source/config_storage.rst @@ -1,119 +1,119 @@ -Storing configuration files in straxen -====================================== - -Most of the configuration files are stored in a mongo database that require a -password to access. Below the different methods of opening files such as -neural-nets in straxen are explained. Further down below a scheme is included -what is done behind the scenes. - - -Downloading XENONnT files from the database -------------------------------------------- -Most generically one downloads files using the :py:class:`straxen.MongoDownloader` -function. For example, one can download a file: - -.. code-block:: python - - import straxen - downloader = straxen.MongoDownloader() - # The downloader allows one to download files from the mongo database by - # looking for the requested name in the files database. The downloader - #returns the path of the downloaded file. - requested_file_name = 'fax_config.json' - config_path = downloader.download_single(requested_file_name) - # We can now open the file using get_resource - simulation_config = straxen.get_resource(config_path, fmt='json') - - -Alternatively, one can rely on the loading of :py:func:`straxen.get_resource` as below: - -.. 
code-block:: python - - import straxen - simulation_config = straxen.get_resource(requested_file_name, fmt='json') - - -Downloading public placeholder files ------------------------------------- -It is also possible to load any of our `placeholder files -`_. This is -for example used for testing purposes of the software using continuous -integration. This can be done as per the example below. The advantage of -loading files in this manner is that it does not require any password. -However, this kind of access is restricted to very few files and there are -also no checks in place to make sure if the requested file is the latest file. -Therefore, this manner of loading data is intended only for testing purposes. - -.. code-block:: python - - import straxen - requested_url = ( - 'https://github.com/XENONnT/strax_auxiliary_files/blob/' - '3548132b55f81a43654dba5141366041e1daaf01/strax_files/fax_config.json') - simulation_config = straxen.common.resource_from_url(requested_url, fmt='json') - - -How does the downloading work? --------------------------------------- -In :py:mod:`straxen/mongo_storage.py` there are two classes that take care of the -downloading and the uploading of files to the `files` database. In this -database we store configuration files under a :py:obj:`config_identifier` i.e. the -:py:obj:`'file_name'`. This is the label that is used to find the document one is -requesting. - -Scheme -^^^^^^^^^ -The scheme below illustrates how the different components work together to make -files available in the database to be loaded by a user. Starting on the left, -an admin user (with the credentials to upload files to the database) uploads a -file to the `files`- database (not shown) such that it can be downloaded later -by any user. The admin user can upload a file using the command -:py:obj:`MongoUploader.upload_from_dict({'file_name', '/path/to/file'})`. -This command will use the :py:class:`straxen.MongoUploader` class to put the file -:py:obj:`'file_name'` in the `files` database. The :py:class:`straxen.MongoUploader` will -communicate with the database via `GridFs -`_. -The GridFs interface communicates with two mongo-collections; :py:obj:`'fs.files'` and -:py:obj:`'fs.chunks'`, where the former is used for bookkeeping and the latter for -storing pieces of data (not to be confused with :py:class:`strax.Chunks`). - - -Uploading -^^^^^^^^^ -When the admin user issues the command to upload the :py:obj:`'file_name'`-file. The -:py:class:`straxen.MongoUploader` will check that the file is not already stored in the -database. To this end, the :py:class:`straxen.MongoUploader` computes the :py:obj:`md5-hash` of -the file stored under the :py:obj:`'/path/to/file'`. If this is the first time a file -with this :py:obj:`md5-hash` is uploaded, :py:class:`straxen.MongoUploader` will upload it to -:py:obj:`GridFs`. If there is already an existing file with the :py:obj:`md5-hash`, there is no -need to upload. This however does mean that if there is already a file :py:obj:`'file_name'` -stored and you modify the :py:obj:`'file_name'`-file, it will be uploaded again! This is -a feature, not a bug. When a user requests the :py:obj:`'file_name'`-file, the -:py:class:`straxen.MongoDownloader` will fetch the :py:obj:`'file_name'`-file that was uploaded -last. - - -Downloading -^^^^^^^^^^^ -Assuming that an admin user uploaded the :py:obj:`'file_name'`-file, any user (no -required admin rights) can now download the :py:obj:`'file_name'`-file (see above for the -example). 
When the user executes :py:obj:`MongoUploader.download_single('file_name')`, -the :py:class:`straxen.MongoDownloader` will check if the file is downloaded already. If -this is the case it will simply return the path of the file. Otherwise, it will -start downloading the file. It is important to notice that the files are saved -under their :py:obj:`md5-hash`-name. This means that wherever the files are stored, -it's unreadable what the file (or extension is). The reason to do it in this -way is that it will make sure that the file is never downloaded when it is -already stored but it would be if the file has been changed as explained above. - - -.. image:: figures/mongo_file_storage.svg - - -Straxen Mongo config loader classes -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Both the :py:class:`straxen.MongoUploader` and :py:class:`straxen.MongoDownloader` share a common -parent class, the :py:class:`straxen.GridFsInterface` that provides the appropriate -shared functionality and connection to the database. The important difference -is the :py:obj:`readonly` argument that naturally has to be :py:obj:`False` for the -:py:class:`straxen.MongoUploader` but :py:obj:`True` for the :py:class:`straxen.MongoDownloader`. +Storing configuration files in straxen +====================================== + +Most of the configuration files are stored in a mongo database that require a +password to access. Below the different methods of opening files such as +neural-nets in straxen are explained. Further down below a scheme is included +what is done behind the scenes. + + +Downloading XENONnT files from the database +------------------------------------------- +Most generically one downloads files using the :py:class:`straxen.MongoDownloader` +function. For example, one can download a file: + +.. code-block:: python + + import straxen + downloader = straxen.MongoDownloader() + # The downloader allows one to download files from the mongo database by + # looking for the requested name in the files database. The downloader + #returns the path of the downloaded file. + requested_file_name = 'fax_config.json' + config_path = downloader.download_single(requested_file_name) + # We can now open the file using get_resource + simulation_config = straxen.get_resource(config_path, fmt='json') + + +Alternatively, one can rely on the loading of :py:func:`straxen.get_resource` as below: + +.. code-block:: python + + import straxen + simulation_config = straxen.get_resource(requested_file_name, fmt='json') + + +Downloading public placeholder files +------------------------------------ +It is also possible to load any of our `placeholder files +`_. This is +for example used for testing purposes of the software using continuous +integration. This can be done as per the example below. The advantage of +loading files in this manner is that it does not require any password. +However, this kind of access is restricted to very few files and there are +also no checks in place to make sure if the requested file is the latest file. +Therefore, this manner of loading data is intended only for testing purposes. + +.. code-block:: python + + import straxen + requested_url = ( + 'https://github.com/XENONnT/strax_auxiliary_files/blob/' + '3548132b55f81a43654dba5141366041e1daaf01/strax_files/fax_config.json') + simulation_config = straxen.common.resource_from_url(requested_url, fmt='json') + + +How does the downloading work? 
+--------------------------------------
+In :py:mod:`straxen/mongo_storage.py` there are two classes that take care of the
+downloading and the uploading of files to the `files` database. In this
+database we store configuration files under a :py:obj:`config_identifier`, i.e. the
+:py:obj:`'file_name'`. This is the label that is used to find the document one is
+requesting.
+
+Scheme
+^^^^^^^^^
+The scheme below illustrates how the different components work together to make
+files available in the database to be loaded by a user. Starting on the left,
+an admin user (with the credentials to upload files to the database) uploads a
+file to the `files`-database (not shown) such that it can be downloaded later
+by any user. The admin user can upload a file using the command
+:py:obj:`MongoUploader.upload_from_dict({'file_name', '/path/to/file'})`.
+This command will use the :py:class:`straxen.MongoUploader` class to put the file
+:py:obj:`'file_name'` in the `files` database. The :py:class:`straxen.MongoUploader` will
+communicate with the database via `GridFs
+`_.
+The GridFs interface communicates with two mongo collections: :py:obj:`'fs.files'` and
+:py:obj:`'fs.chunks'`, where the former is used for bookkeeping and the latter for
+storing pieces of data (not to be confused with :py:class:`strax.Chunks`).
+
+
+Uploading
+^^^^^^^^^
+When the admin user issues the command to upload the :py:obj:`'file_name'`-file, the
+:py:class:`straxen.MongoUploader` will check that the file is not already stored in the
+database. To this end, the :py:class:`straxen.MongoUploader` computes the :py:obj:`md5-hash` of
+the file stored under the :py:obj:`'/path/to/file'`. If this is the first time a file
+with this :py:obj:`md5-hash` is uploaded, :py:class:`straxen.MongoUploader` will upload it to
+:py:obj:`GridFs`. If a file with this :py:obj:`md5-hash` already exists, there is no
+need to upload it again. This does mean, however, that if a file :py:obj:`'file_name'` is already
+stored and you modify the :py:obj:`'file_name'`-file, it will be uploaded again! This is
+a feature, not a bug. When a user requests the :py:obj:`'file_name'`-file, the
+:py:class:`straxen.MongoDownloader` will fetch the :py:obj:`'file_name'`-file that was uploaded
+last.
+
+
+Downloading
+^^^^^^^^^^^
+Assuming that an admin user uploaded the :py:obj:`'file_name'`-file, any user (no
+admin rights required) can now download the :py:obj:`'file_name'`-file (see above for the
+example). When the user executes :py:obj:`MongoDownloader.download_single('file_name')`,
+the :py:class:`straxen.MongoDownloader` will check if the file is downloaded already. If
+this is the case, it will simply return the path of the file. Otherwise, it will
+start downloading the file. It is important to notice that the files are saved
+under their :py:obj:`md5-hash`-name. This means that, wherever the files are stored,
+it is not obvious from the stored name what the file is (or what its extension is).
+The reason for doing it this way is to make sure that a file is never downloaded again
+when it is already stored, but that it is re-downloaded when the file has been changed,
+as explained above.
+
+
+.. image:: figures/mongo_file_storage.svg
+
+
+Straxen Mongo config loader classes
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Both the :py:class:`straxen.MongoUploader` and :py:class:`straxen.MongoDownloader` share a common
+parent class, the :py:class:`straxen.GridFsInterface`, which provides the appropriate
+shared functionality and connection to the database.
The important difference +is the :py:obj:`readonly` argument that naturally has to be :py:obj:`False` for the +:py:class:`straxen.MongoUploader` but :py:obj:`True` for the :py:class:`straxen.MongoDownloader`. diff --git a/docs/source/index.rst b/docs/source/index.rst index 1e24b0d4b..4f37a6767 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -54,6 +54,7 @@ Straxen is the analysis framework for XENONnT, built on top of the generic `stra reference/datastructure_1T + .. toctree:: :maxdepth: 2 :caption: scripts @@ -63,11 +64,19 @@ Straxen is the analysis framework for XENONnT, built on top of the generic `stra .. toctree:: :maxdepth: 2 - :caption: Auxilliary tools + :caption: Auxiliary tools scada_interface tutorials/ScadaInterfaceExample.ipynb + +.. toctree:: + :maxdepth: 1 + :caption: Context and minianalyses + + reference/context + + .. toctree:: :maxdepth: 1 :caption: Reference diff --git a/docs/source/online_monitor.rst b/docs/source/online_monitor.rst index b0c709e6b..5563e7541 100644 --- a/docs/source/online_monitor.rst +++ b/docs/source/online_monitor.rst @@ -1,123 +1,123 @@ -XENONnT online monitor -====================== -Using strax, it is possible to live-process data while acquiring it. -This allows for fast monitoring. To further allow this, straxen has an -online monitor frontend. This allows a portion of the data to be -shipped of to the Mongo database while collecting the data at the DAQ. -This means that analysers can have fast feedback on what is going on inside the -TPC. - - -Loading data via the online monitor ------------------------------------ -In order to load this data in straxen, one can use the following setup -and start developing live-displays! - - -.. code-block:: python - - import straxen - st = straxen.contexts.xenonnt_online(_add_online_monitor_frontend=True) - - # Allow unfinished runs to be loaded, even before the DAQ has finished processing this run! - st.set_context_config({'allow_incomplete': True}) - st.get_df(latest_run_id, 'event_basics') - -This command adds the online-monitor frontend to the context. If data is -now requested by the user strax will fetch the data via this frontend -if it is not available in any of the other storage frontends. Usually the data -is available within ~30 second after a pulse was detected by a PMT. - - -Machinery ---------- -Using the strax online monitor frontend, each chunk of data being processed -on the DAQ can be shipped out to via the mongo database. Schematically, -this looks as in the following schematic. For data that is stored in the -online-monitor collection of the database, each chunk of data is stored twice. -The data that is written to the DAQ local storage is transferred by -`admix `_ to the shared analysis cluster -(`dali`). This transfer can only start once a run has been finished and also -the transfer takes time. To make data access almost instantaneous, this data is -also stored online. - - -.. image:: figures/online_monitor.svg - -The user will retrieve the data from the mongo database just as if the -data were stored locally. It takes slightly longer to store the data than if -it was stored on disk because each chunk is saved online individually. -However, with a decent internet connection, loading one run of any data -should only take ~10 s. - - -How long and what data is stored online? 
----------------------------------------- -The online storage cannot hold data for extended periods of time, and, since -data is shipped to the analysis sites, there is no need to keep it around -forever. -As such, data will be available up to 7 days after writing the data to the -database. After that, the online data will be deleted automatically. - -Depending on the current settings, selected datatypes are stored in the database. -At the time of writing, these were: - - - ``online_peak_monitor`` - - ``event_basics`` - - ``veto_regions`` - -For the most up-to-date information, one can check the registration in the -``straxen.contexts.xenonnt_online`` context: -`here `_. - - -Caching the results of the online monitor ------------------------------------------ -For some applications, it's worth to keep a local copy of the data from the -online monitor. If one is interested in multiple runs, this is usually a good option. - -To this end one can use the context function ``copy_to_frontend``. By setting -``rechunk=True``, we are combining the many small files (one per chunk) into -a few bigger files which makes it much faster to load next time. - - -.. code-block:: python - - import straxen - st = straxen.contexts.xenonnt_online(_add_online_monitor_frontend=True) - st.copy_to_frontend(latest_run_id, 'event_basics', rechunk=True) - -One can look now where this run is stored: - -.. code-block:: python - - for storage_frontend in st.storage: - is_stored = st._is_stored_in_sf(latest_run_id, 'event_basics', storage_frontend) - print(f'{storage_frontend.__class__.__name__} has a copy: {is_stored}') - -which prints - -.. code-block:: rst - - RunDB has a copy: False - DataDirectory has a copy: False - DataDirectory has a copy: False - DataDirectory has a copy: True - OnlineMonitor has a copy: True - -You can also ``print(st.storage)`` to see which directories these refer to. -The ``DataDirectory``-storage frontends that do not have a copy are readonly -folders and not accessible to the user for writing. - -For more information on this, checkout the -`strax documentation on copying data `_. - - -Pre-configured monitoring tools -------------------------------- -For XENONnT we have the private monitor called `olmo `_ -which is only visible for XENONnT members. - - -*Last updated 2021-05-07. Joran Angevaare* - +XENONnT online monitor +====================== +Using strax, it is possible to live-process data while acquiring it. +This allows for fast monitoring. To further allow this, straxen has an +online monitor frontend. This allows a portion of the data to be +shipped of to the Mongo database while collecting the data at the DAQ. +This means that analysers can have fast feedback on what is going on inside the +TPC. + + +Loading data via the online monitor +----------------------------------- +In order to load this data in straxen, one can use the following setup +and start developing live-displays! + + +.. code-block:: python + + import straxen + st = straxen.contexts.xenonnt_online(_add_online_monitor_frontend=True) + + # Allow unfinished runs to be loaded, even before the DAQ has finished processing this run! + st.set_context_config({'allow_incomplete': True}) + st.get_df(latest_run_id, 'event_basics') + +This command adds the online-monitor frontend to the context. If data is +now requested by the user strax will fetch the data via this frontend +if it is not available in any of the other storage frontends. Usually the data +is available within ~30 second after a pulse was detected by a PMT. 
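+
+Before requesting data, it can be convenient to first check whether the
+data type is already retrievable from any of the storage frontends. A
+minimal sketch (reusing the ``st`` context from above and an illustrative
+``latest_run_id``) could look like:
+
+.. code-block:: python
+
+    # is_stored consults all storage frontends registered in the context,
+    # including the online-monitor frontend that was added above.
+    if st.is_stored(latest_run_id, 'event_basics'):
+        events = st.get_df(latest_run_id, 'event_basics')
+    else:
+        print(f'event_basics is not (yet) available for run {latest_run_id}')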
+ + +Machinery +--------- +Using the strax online monitor frontend, each chunk of data being processed +on the DAQ can be shipped out to via the mongo database. Schematically, +this looks as in the following schematic. For data that is stored in the +online-monitor collection of the database, each chunk of data is stored twice. +The data that is written to the DAQ local storage is transferred by +`admix `_ to the shared analysis cluster +(`dali`). This transfer can only start once a run has been finished and also +the transfer takes time. To make data access almost instantaneous, this data is +also stored online. + + +.. image:: figures/online_monitor.svg + +The user will retrieve the data from the mongo database just as if the +data were stored locally. It takes slightly longer to store the data than if +it was stored on disk because each chunk is saved online individually. +However, with a decent internet connection, loading one run of any data +should only take ~10 s. + + +How long and what data is stored online? +---------------------------------------- +The online storage cannot hold data for extended periods of time, and, since +data is shipped to the analysis sites, there is no need to keep it around +forever. +As such, data will be available up to 7 days after writing the data to the +database. After that, the online data will be deleted automatically. + +Depending on the current settings, selected datatypes are stored in the database. +At the time of writing, these were: + + - ``online_peak_monitor`` + - ``event_basics`` + - ``veto_regions`` + +For the most up-to-date information, one can check the registration in the +``straxen.contexts.xenonnt_online`` context: +`here `_. + + +Caching the results of the online monitor +----------------------------------------- +For some applications, it's worth to keep a local copy of the data from the +online monitor. If one is interested in multiple runs, this is usually a good option. + +To this end one can use the context function ``copy_to_frontend``. By setting +``rechunk=True``, we are combining the many small files (one per chunk) into +a few bigger files which makes it much faster to load next time. + + +.. code-block:: python + + import straxen + st = straxen.contexts.xenonnt_online(_add_online_monitor_frontend=True) + st.copy_to_frontend(latest_run_id, 'event_basics', rechunk=True) + +One can look now where this run is stored: + +.. code-block:: python + + for storage_frontend in st.storage: + is_stored = st._is_stored_in_sf(latest_run_id, 'event_basics', storage_frontend) + print(f'{storage_frontend.__class__.__name__} has a copy: {is_stored}') + +which prints + +.. code-block:: rst + + RunDB has a copy: False + DataDirectory has a copy: False + DataDirectory has a copy: False + DataDirectory has a copy: True + OnlineMonitor has a copy: True + +You can also ``print(st.storage)`` to see which directories these refer to. +The ``DataDirectory``-storage frontends that do not have a copy are readonly +folders and not accessible to the user for writing. + +For more information on this, checkout the +`strax documentation on copying data `_. + + +Pre-configured monitoring tools +------------------------------- +For XENONnT we have the private monitor called `olmo `_ +which is only visible for XENONnT members. + + +*Last updated 2021-05-07. 
Joran Angevaare* + diff --git a/docs/source/reference/context.rst b/docs/source/reference/context.rst new file mode 100644 index 000000000..8ba24c313 --- /dev/null +++ b/docs/source/reference/context.rst @@ -0,0 +1,18 @@ + +======== +Contexts +======== +The contexts are a class from strax and used everywhere in straxen + +Below, all of the contexts functions are shown including the +`minianalyses `_ -that allow common uses of straxen. Some of these scripts are designed -to run on the DAQ whereas others are for common use cases. Each of the -scripts will be briefly discussed below: - -straxer -------- -``straxer`` is the most useful straxen script for regular users. Allows data to be -generated in a script format. Especially useful for reprocessing data -in batch jobs. - -For example a user can reprocess the data of run ``012100`` using the -following command up to ``event_info_double``. - -.. code-block:: bash - - straxer 012100 --target event_info_double - -For more information on the options, please refer to the help: - -.. code-block:: bash - - straxer --help - - -ajax [DAQ-only] ----------------- -The DAQ-cleaning script. Data is stored on the DAQ such that other tools -like `admix `_ may ship the data to -distributed storage. A portion of the high level data is stored on the DAQ -for diagnostic purposes for longer periods of time. ``ajax`` removes this -data if needed. -The ``ajax`` script looks for data on the eventbuilders -that can be deleted because at least one of the following reasons: - - - A run has been "abandoned", this means that there is no further use - for this data, e.g. a board failed during a run, there is no point in - keeping a run where part of the data on the DAQ. - - The live-data (intermediate DAQ format, even more raw than raw-records) has - been successfully processed. Therefore remove this intermediate datakind from - daq. - - A run has been abandoned but there is live-data still on the DAQ-bugger. - - Data is "unregistered" (not in the runsdatabase), - this only occurs if DAQ-experts perform tests on the DAQ. - - Since bootstrax runs on multiple hosts, some of the data may appear to be - stored more than once since a given bootstrax instance could crash during it's processing. - The data of unsucessful processings should be removed by ``ajax``. - - Finally ``ajax`` also checks if all the entries that are in the database are also on the host still - This sanity check catches any potential issues in the data handling by admix. - - -bootstrax [DAQ-only] --------------------- -As the main DAQ processing script. This is discussed separately. It is only used for XENONnT. - - -fake_daq ------------------- -Script that allows mimiming DAQ-processing by opening raw-records data. - - -microstrax ------------------- -Mini strax interface that allows strax-data to be retrieved using HTTP requests -on a given port. This is at the time of writing used on the DAQ as a pulse viewer. - - -refresh_raw_records -------------------- -Updates raw-records from old strax versions. This data is of a different -format and needs to be refreshed before it can be opened with more recent -versions of strax. - -*Last updated 2021-05-07. Joran Angevaare* +Straxen scripts +=================== +Straxen comes with +`several scripts `_ +that allow common uses of straxen. Some of these scripts are designed +to run on the DAQ whereas others are for common use cases. 
Each of the
+scripts will be briefly discussed below:
+
+straxer
+-------
+``straxer`` is the most useful straxen script for regular users. It allows data to be
+generated from the command line and is especially useful for reprocessing data
+in batch jobs.
+
+For example, a user can reprocess the data of run ``012100`` up to
+``event_info_double`` using the following command.
+
+.. code-block:: bash
+
+    straxer 012100 --target event_info_double
+
+For more information on the options, please refer to the help:
+
+.. code-block:: bash
+
+    straxer --help
+
+
+ajax [DAQ-only]
+----------------
+The DAQ-cleaning script. Data is stored on the DAQ such that other tools
+like `admix `_ may ship the data to
+distributed storage. A portion of the high-level data is stored on the DAQ
+for longer periods of time for diagnostic purposes. ``ajax`` removes this
+data if needed.
+The ``ajax`` script looks for data on the eventbuilders
+that can be deleted for at least one of the following reasons:
+
+ - A run has been "abandoned"; this means that there is no further use
+   for this data (e.g. a board failed during the run), so there is no point
+   in keeping partial data for this run on the DAQ.
+ - The live-data (intermediate DAQ format, even more raw than raw-records) has
+   been successfully processed. This intermediate datakind can therefore be
+   removed from the DAQ.
+ - A run has been abandoned but there is live-data still on the DAQ-buffer.
+ - Data is "unregistered" (not in the runs database);
+   this only occurs if DAQ experts perform tests on the DAQ.
+ - Since bootstrax runs on multiple hosts, some of the data may appear to be
+   stored more than once since a given bootstrax instance could crash during its processing.
+   The data of unsuccessful processings should be removed by ``ajax``.
+ - Finally, ``ajax`` also checks if all the entries that are in the database are still present on the host.
+   This sanity check catches any potential issues in the data handling by admix.
+
+
+bootstrax [DAQ-only]
+--------------------
+The main DAQ processing script. This is discussed separately. It is only used for XENONnT.
+
+
+fake_daq
+------------------
+Script that allows mimicking DAQ-processing by opening raw-records data.
+
+
+microstrax
+------------------
+Mini strax interface that allows strax-data to be retrieved using HTTP requests
+on a given port. At the time of writing, this is used on the DAQ as a pulse viewer.
+
+
+refresh_raw_records
+-------------------
+Updates raw-records from old strax versions. This data is of a different
+format and needs to be refreshed before it can be opened with more recent
+versions of strax.
+
+*Last updated 2021-05-07. Joran Angevaare*
diff --git a/docs/source/url_configs.rst b/docs/source/url_configs.rst
new file mode 100644
index 000000000..36180c4a8
--- /dev/null
+++ b/docs/source/url_configs.rst
@@ -0,0 +1,156 @@
+
+URLConfig options
+=================
+The URLConfig class was designed to make it easier to have complex plugin configurations.
+A plugin may require a rich object (such as a TensorFlow model), a file to be loaded, or a run-dependent value for its calculation.
+While it is perfectly reasonable to perform all of these operations in the plugin's `setup()` method,
+some operations, such as loading files and looking up CMT values, tend to repeat themselves in many plugins, leading to code duplication.
+Having the same code duplicated in many plugins makes it very difficult to maintain or improve,
+with the added annoyance that changing this behavior requires editing the plugin code.
+The URLConfig provides a consistent way to define such behaviors at runtime via a URL string. +The URL is like a recipe for how the config value should be loaded when it is needed by the plugin. +Small snippets of code for loading a configuration can be registered as protocols and can be used by all plugins. +This allows you to keep the plugin code clean and focused on the processing itself, +without mixing in details of how to load the configuration data which tends to change more frequently. + + +The main goals of the URLConfig: +- More flexibility in switching between CMT, get_resource, and static configuration values. +- Remove logic of how to fetch and construct configuration objects from the plugin to improve purity (computational logic only) and maintainability of the plugins. +- Make unit testing easier by separating the logic that uses the configuration from the logic that fetches its current value. +- Increase the expressivity of the CMT values (descriptive string instead of opaque tuple) +- Remove need for hardcoding of special treatment for each correction in CMT when reading values. + +A concrete plugin example +------------------------- + +**The old way loading a TF model** + +.. code-block:: python + + @export + @strax.takes_config( + strax.Option('min_reconstruction_area', + help='Skip reconstruction if area (PE) is less than this', + default=10), + strax.Option('n_top_pmts', default=straxen.n_top_pmts, + help="Number of top PMTs") + ) + class PeakPositionsBaseNT(strax.Plugin): + + def setup(self): + self.model_file = self._get_model_file_name() + if self.model_file is None: + warn(f'No file provided for {self.algorithm}. Setting all values ' + f'for {self.provides} to None.') + # No further setup required + return + + # Load the tensorflow model + import tensorflow as tf + if os.path.exists(self.model_file): + print(f"Path is local. Loading {self.algorithm} TF model locally " + f"from disk.") + else: + downloader = straxen.MongoDownloader() + try: + self.model_file = downloader.download_single(self.model_file) + except straxen.mongo_storage.CouldNotLoadError as e: + raise RuntimeError(f'Model files {self.model_file} is not found') from e + with tempfile.TemporaryDirectory() as tmpdirname: + tar = tarfile.open(self.model_file, mode="r:gz") + tar.extractall(path=tmpdirname) + self.model = tf.keras.models.load_model(tmpdirname) + + def _get_model_file_name(self): + config_file = f'{self.algorithm}_model' + model_from_config = self.config.get(config_file, 'No file') + if model_from_config == 'No file': + raise ValueError(f'{__class__.__name__} should have {config_file} ' + f'provided as an option.') + if isinstance(model_from_config, str) and os.path.exists(model_from_config): + # Allow direct path specification + return model_from_config + if model_from_config is None: + # Allow None to be specified (disables processing for given posrec) + return model_from_config + + # Use CMT + model_file = straxen.get_correction_from_cmt(self.run_id, model_from_config) + return model_file + + +Notice how all the details on how to fetch the model file and convert it to a python object that is actually needed, is all hardcoded into the plugin. This is not desirable, the plugin should contain processing logic only. + +**How this could be refactored using `strax.URLConfig`:** + +.. 
code-block:: python
+
+    class PeakPositionsMLP(PeakPositionsBaseNT):
+        tf_model_mlp = straxen.URLConfig(
+            default=f'tf://'
+                    f'resource://'
+                    f'cmt://{algorithm}_model'
+                    f'?version=ONLINE'
+                    f'&run_id=plugin.run_id'
+                    f'&fmt=abs_path',
+            help='MLP model. Should be opened using the "tf" descriptor. '
+                 'Set to "None" to skip computation',
+            cache=3,
+        )
+
+The details of where the model object is taken from can be determined by setting the model key of the context config.
+The URL is the object that gets hashed, so it is important to only use pure URLs, i.e. the same URL should always refer to the same resource.
+
+The URL is evaluated recursively in the following order:
+ 1) **?version=ONLINE&run_id=plugin.run_id&fmt=abs_path** - The query is parsed and substituted (plugin.* values are replaced with plugin attributes as evaluated at runtime); the values are then passed as keyword arguments to any protocols that include them in their signature. Everything after the rightmost `?` character is considered the keyword arguments for the protocols.
+ 2) **cmt://** - Loads a value from CMT; in this case it loads the name of the resource encoding the Keras model.
+ 3) **resource://** - Loads a XENON resource by name (can also load web URLs); in this case it returns a path to the file.
+ 4) **tf://** - Loads a TF model from a path.
+
+**Important:** The URL arguments are sorted before they are passed to the plugin so that hashing is not sensitive to the order of the arguments.
+This is important to remember when performing tests.
+All of the actual code snippets for these protocols are shared among all plugins.
+
+Adding new protocols
+--------------------
+
+As an example, let's look at some actual protocols in `url_config.py`:
+
+
+.. code-block:: python
+
+    @URLConfig.register('format')
+    def format_arg(arg: str, **kwargs):
+        """apply pythons builtin format function to a string"""
+        return arg.format(**kwargs)
+
+
+    @URLConfig.register('itp_map')
+    def load_map(some_map, method='WeightedNearestNeighbors', **kwargs):
+        """Make an InterpolatingMap"""
+        return straxen.InterpolatingMap(some_map, method=method, **kwargs)
+
+
+    @URLConfig.register('bodega')
+    def load_value(name: str, bodega_version=None):
+        """Load a number from BODEGA file"""
+        if bodega_version is None:
+            raise ValueError('Provide version see e.g. tests/test_url_config.py')
+        nt_numbers = straxen.get_resource("XENONnT_numbers.json", fmt="json")
+        return nt_numbers[name][bodega_version]["value"]
+
+
+    @URLConfig.register('tf')
+    def open_neural_net(model_path: str, **kwargs):
+        # Nested import to reduce the loading time of `import straxen`,
+        # since tensorflow is not a base requirement
+        import tensorflow as tf
+        if not os.path.exists(model_path):
+            raise FileNotFoundError(f'No file at {model_path}')
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            tar = tarfile.open(model_path, mode="r:gz")
+            tar.extractall(path=tmpdirname)
+            return tf.keras.models.load_model(tmpdirname)
+
+As you can see, it is very easy to define new protocols; once a protocol is defined, it can be used in any URL!
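+
+As a sketch of how a custom protocol could be chained with the existing ones,
+consider the hypothetical ``scale`` protocol below. The protocol name, the
+``some_number`` BODEGA key and the plugin itself are purely illustrative and
+are not part of straxen:
+
+.. code-block:: python
+
+    import strax
+    import straxen
+
+    @straxen.URLConfig.register('scale')
+    def scale_value(value, factor=1):
+        """Multiply the value resolved by the nested protocol by a factor"""
+        # 'factor' arrives as a keyword argument taken from the URL query
+        return value * float(factor)
+
+    class MyPlugin(strax.Plugin):
+        # Evaluated right to left: 'bodega' loads the number,
+        # 'scale' then multiplies it by the 'factor' query argument.
+        scaled_number = straxen.URLConfig(
+            default='scale://bodega://some_number?bodega_version=v1&factor=2',
+            help='Illustrative chained-protocol configuration',
+        )
+
+Since the full URL string is what gets hashed, changing ``factor`` (or any
+other part of the URL) changes the lineage just like changing any other
+option would.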
diff --git a/extra_requirements/requirements-docs.txt b/extra_requirements/requirements-docs.txt index 4257d2a78..91a9e8ba1 100644 --- a/extra_requirements/requirements-docs.txt +++ b/extra_requirements/requirements-docs.txt @@ -1,6 +1,7 @@ -# File for the requirements of the documentation -sphinx -sphinx_rtd_theme -nbsphinx -recommonmark -graphviz \ No newline at end of file +# File for the requirements of the documentation +commonmark==0.9.1 +graphviz==0.19.1 +nbsphinx==0.8.8 +recommonmark==0.7.1 +sphinx==4.3.2 +sphinx_rtd_theme==1.0.0 diff --git a/extra_requirements/requirements-tests.txt b/extra_requirements/requirements-tests.txt index 1d3fb8d43..f7fe75038 100644 --- a/extra_requirements/requirements-tests.txt +++ b/extra_requirements/requirements-tests.txt @@ -1,30 +1,43 @@ -# File for the requirements of straxen with the automated tests -blosc==1.10.4 # Strax dependency -boltons==21.0.0 -datashader==0.13.0 -dask==2021.7.2 -dill==0.3.4 # Strax dependency -coveralls==3.2.0 -commentjson==0.9.0 -coverage==5.5 -flake8==3.9.2 -holoviews==1.14.5 -ipywidgets==7.6.3 -hypothesis==6.14.5 -jupyter-client==6.1.12 # for ipywidgets -matplotlib==3.4.2 -multihist==0.6.4 -npshmex==0.2.1 # Strax dependency -numba==0.53.1 # Strax dependency -numpy==1.19.5 -pandas==1.2.5 # Strax dependency -psutil==5.8.0 # Strax dependency -pytest==6.2.4 -pytest-cov==2.12.1 -scikit-learn==0.24.2 -scipy==1.7.0 # Strax dependency -tensorflow==2.5.1 -tqdm==4.62.0 -xarray==0.19.0 -utilix==0.6.1 -zstd==1.5.0.2 # Strax dependency +# File for the requirements of straxen with the automated tests +blosc==1.10.6 # Strax dependency +bokeh==2.4.2 +boltons==21.0.0 +commentjson==0.9.0 +coverage==6.2 +coveralls==3.3.1 +dask==2021.12.0 +datashader==0.13.0 +dill==0.3.4 # Strax dependency +flake8==4.0.1 +gitpython==3.1.26 +holoviews==1.14.7; python_version<="3.9" +holoviews==1.14.7; python_version=="3.10" +hypothesis==6.34.1 +ipywidgets==7.6.5 +jupyter-client==7.1.0 # for ipywidgets +keras==2.7.0; python_version<="3.9" # Tensorflow dependency +keras==2.8.0rc0; python_version=="3.10" # Tensorflow dependency +matplotlib==3.5.1 +multihist==0.6.4 +nestpy==1.5.0; python_version<="3.9" # WFSim dependency, doesn't work in py3.10 +npshmex==0.2.1 # Strax dependency +numba==0.55.0 +numexpr==2.8.1 +numpy==1.21.5 +packaging==21.3 +pandas==1.3.5 +panel==0.12.6 # Bokeh dependency +psutil==5.9.0 # Strax dependency +pymongo==3.12.0 # Strax dependency +pytest==6.2.5 +pytest-cov==3.0.0 +scikit-learn==1.0.2 +scipy==1.7.3 # Strax dependency +tensorflow==2.7.0; python_version<="3.9" +tensorflow==2.8.0rc0; python_version=="3.10" +tqdm==4.62.2 +typing_extensions==4.0.1 # Tensorflow/bokeh dependency +utilix==0.6.5 +wfsim==0.5.12; python_version<="3.9" # nestpy doesn't work in py3.10 +xarray==0.20.2 +zstd==1.5.1.0 # Strax dependency diff --git a/requirements.txt b/requirements.txt index fbc80bcbf..c6f69faf8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,16 @@ -strax>=1.0.0rc -utilix>=0.5.3 -# tensorflow>=2.3.0 # Optional, to (re)do posrec -# holoviews # Optional, to enable wf display # datashader # Optional, to enable wf display +# holoviews # Optional, to enable wf display +# tensorflow>=2.3.0 # Optional, to (re)do posrec bokeh>=2.2.3 -multihist>=0.6.3 -packaging +commentjson +gitpython immutabledict +matplotlib +multihist>=0.6.3 numba>=0.50.0 +numpy +packaging +pymongo<4.0.0 requests -commentjson -matplotlib +strax>=1.1.2 +utilix>=0.5.3 diff --git a/setup.py b/setup.py index 7cc64c3fe..4a09d5408 100644 --- a/setup.py +++ b/setup.py @@ -21,7 
+21,7 @@ def open_requirements(path): history = file.read() setuptools.setup(name='straxen', - version='1.0.0', + version='1.2.6', description='Streaming analysis for XENON', author='Straxen contributors, the XENON collaboration', url='https://github.com/XENONnT/straxen', @@ -56,6 +56,7 @@ def open_requirements(path): 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', 'Intended Audience :: Science/Research', 'Programming Language :: Python :: Implementation :: CPython', 'Topic :: Scientific/Engineering :: Physics', diff --git a/straxen/__init__.py b/straxen/__init__.py index 207ae60a2..cac276026 100644 --- a/straxen/__init__.py +++ b/straxen/__init__.py @@ -1,4 +1,4 @@ -__version__ = '1.0.0' +__version__ = '1.2.6' from utilix import uconfig from .common import * @@ -16,6 +16,7 @@ from .scada import * from .bokeh_utils import * from .rucio import * +from .url_config import * from . import plugins from .plugins import * @@ -26,6 +27,9 @@ # Otherwise we have straxen.demo() etc. from . import contexts +from . import test_utils +from .test_utils import * + try: from . import holoviews_utils from .holoviews_utils import * diff --git a/straxen/analyses/__init__.py b/straxen/analyses/__init__.py index ba01cad6b..fe8c5f89e 100644 --- a/straxen/analyses/__init__.py +++ b/straxen/analyses/__init__.py @@ -1,9 +1,9 @@ -from . import quick_checks -from . import records_matrix -from . import waveform_plot -from . import holoviews_waveform_display from . import bokeh_waveform_plot +from . import daq_waveforms from . import event_display -from . import pulse_plots +from . import holoviews_waveform_display from . import posrec_comparison -from . import daq_waveforms +from . import pulse_plots +from . import quick_checks +from . import records_matrix +from . import waveform_plot diff --git a/straxen/analyses/bokeh_waveform_plot.py b/straxen/analyses/bokeh_waveform_plot.py index 5c43b950e..ba1c0a141 100644 --- a/straxen/analyses/bokeh_waveform_plot.py +++ b/straxen/analyses/bokeh_waveform_plot.py @@ -1,15 +1,12 @@ +import warnings + import bokeh import bokeh.plotting as bklt - -from straxen.analyses.holoviews_waveform_display import _hvdisp_plot_records_2d, hook, \ - plot_record_polygons, get_records_matrix_in_window - -import numpy as np import numba +import numpy as np import strax import straxen - -import warnings +from straxen.analyses.holoviews_waveform_display import _hvdisp_plot_records_2d, hook, plot_record_polygons, get_records_matrix_in_window # noqa # Default legend, unknow, S1 and S2 LEGENDS = ('Unknown', 'S1', 'S2') @@ -186,8 +183,9 @@ def event_display_interactive(events, r = st.get_array(run_id, 'raw_records', time_range=(events[0]['time'], events[0]['endtime'])) r = p.compute(r, events[0]['time'], events[0]['endtime'])['records'] else: - warnings.warn(f'Can neither find records nor raw_records for run {run_id}, proceed without record ' - f'matrix.') + warnings.warn( + f'Can neither find records nor raw_records for run {run_id}, proceed without record ' + f'matrix.') plot_record_matrix = False if plot_record_matrix: @@ -718,7 +716,7 @@ def _make_event_title(event, run_id, width=1600): sizing_mode='scale_both', width=width, default_size=width, - orientation='vertical', + # orientation='vertical', width_policy='fit', margin=(0, 0, -30, 50) ) @@ -752,6 +750,7 @@ class DataSelectionHist: """ Class for an interactive data selection plot. 
""" + def __init__(self, name, size=600): """ Class for an interactive data selection plot. diff --git a/straxen/analyses/daq_waveforms.py b/straxen/analyses/daq_waveforms.py index fae7656ca..8a20781d9 100644 --- a/straxen/analyses/daq_waveforms.py +++ b/straxen/analyses/daq_waveforms.py @@ -1,137 +1,135 @@ -import numba -import pandas -import straxen -import numpy as np -import strax -import pymongo -import typing -import matplotlib.pyplot as plt - - -@straxen.mini_analysis() -def daq_plot(context, - figsize=(14, 15), - lower_panel_height=6, - group_by='link', - vmin=None, - vmax=None, - **kwargs): - """ - Plot with peak, records and records sorted by "link" or "ADC ID" - (other items are also possible as long as it is in the channel map). - """ - - f, axes = plt.subplots(3, 1, - figsize=figsize, - gridspec_kw={'height_ratios': [1, 1, lower_panel_height]}) - - # Panel 1, the peaks - plt.sca(axes[0]) - plt.title('Peaks') - context.plot_peaks(**kwargs, - single_figure=False) - xlim = plt.xlim() - plt.xticks(rotation=0) - plt.grid('y') - - # Panel 2, the records where we keep the order of the records/channel number - plt.sca(axes[1]) - plt.title('Records (by channel number)') - context.plot_records_matrix(**kwargs, - vmin=vmin, - vmax=vmax, - single_figure=False) - plt.xticks(rotation=0) - plt.grid('x') - plt.xlim(*xlim) - - # Use a grouping argument to group the channels by. - plt.sca(axes[2]) - plt.title(f'Records (by {group_by})') - context.plot_records_matrix(**kwargs, - vmin=vmin, - vmax=vmax, - group_by=group_by, - single_figure=False) - plt.xlim(*xlim) - plt.grid() - - -def _get_daq_config( - context: strax.Context, - run_id: str, - config_name: str = 'daq_config', - run_collection: typing.Optional[pymongo.collection.Collection] = None) -> dict: - """ - Query the runs database for the config of the daq during this run. - Either use the context of the runs collection. - """ - if not context.storage[0].__class__.__name__ == 'RunDB' and run_collection is None: - raise NotImplementedError('Only works with the runs-database') - if run_collection is None: - run_collection = context.storage[0].collection - daq_doc = run_collection.find_one({"number": int(run_id)}, - projection={config_name: 1}) - if daq_doc is None or config_name not in daq_doc: - raise ValueError(f'Requested {config_name} does not exist') - return daq_doc[config_name] - - -def _board_to_host_link(daq_config: dict, board: int, add_crate=True) -> str: - """Parse the daq-config to get the host, link and crate""" - for bdoc in daq_config['boards']: - try: - if int(bdoc['board']) == board: - res = f"{bdoc['host']}_link{bdoc['link']}" - if add_crate: - res += f"_crate{bdoc['crate']}" - return res - except KeyError: - raise ValueError(f'Invalid DAQ config {daq_config} or board {board}') - # This happens if the board is not in the channel map which might - # happen for very old runs. - return 'unknown' - - -def _get_cable_map(name: str = 'xenonnt_cable_map.csv') -> pandas.DataFrame: - """Download the cable map and return as a pandas dataframe""" - down = straxen.MongoDownloader() - cable_map = down.download_single(name) - cable_map = pandas.read_csv(cable_map) - return cable_map - - -def _group_channels_by_index(cable_map: pandas.DataFrame, - group_by: str = 'ADC ID', - ) -> typing.Tuple[np.ndarray, np.ndarray]: - """ - Parse the cable map, return the labels where each of the channels is - mapped to as well as an array that can be used to map each of the - channels maps to the labels. 
- """ - idx = np.arange(straxen.n_tpc_pmts) - idx_seen = 0 - labels = [] - for selection in np.unique(cable_map[group_by].values): - selected_channels = cable_map[cable_map[group_by] == selection]['PMT Location'] - selected_channels = np.array(selected_channels) - n_sel = len(selected_channels) - - idx[idx_seen:idx_seen + n_sel] = selected_channels - labels += [selection] * n_sel - idx_seen += n_sel - return np.array(labels), idx - - -def group_by_daq(context, run_id, group_by: str): - """From the channel map, get the mapping of channel number -> group by""" - cable_map = _get_cable_map() - if group_by == 'link': - labels, idx = _group_channels_by_index(cable_map, group_by='ADC ID') - daq_config = _get_daq_config(context, run_id) - labels = [_board_to_host_link(daq_config, l) for l in labels] - labels = np.array(labels) - order = np.argsort(labels) - return labels[order], idx[order] - else: - return _group_channels_by_index(cable_map, group_by=group_by) +import typing + +import matplotlib.pyplot as plt +import numpy as np +import pandas +import pymongo +import straxen +import utilix + + +@straxen.mini_analysis() +def daq_plot(context, + figsize=(14, 15), + lower_panel_height=6, + group_by='link', + vmin=None, + vmax=None, + **kwargs): + """ + Plot with peak, records and records sorted by "link" or "ADC ID" + (other items are also possible as long as it is in the channel map). + """ + f, axes = plt.subplots(3, 1, + figsize=figsize, + gridspec_kw={'height_ratios': [1, 1, lower_panel_height]}) + + # Panel 1, the peaks + plt.sca(axes[0]) + plt.title('Peaks') + context.plot_peaks(**kwargs, + single_figure=False) + xlim = plt.xlim() + plt.xticks(rotation=0) + plt.grid('y') + + # Panel 2, the records where we keep the order of the records/channel number + plt.sca(axes[1]) + plt.title('Records (by channel number)') + context.plot_records_matrix(**kwargs, + vmin=vmin, + vmax=vmax, + single_figure=False) + plt.xticks(rotation=0) + plt.grid('x') + plt.xlim(*xlim) + + # Use a grouping argument to group the channels by. + plt.sca(axes[2]) + plt.title(f'Records (by {group_by})') + context.plot_records_matrix(**kwargs, + vmin=vmin, + vmax=vmax, + group_by=group_by, + single_figure=False) + plt.xlim(*xlim) + plt.grid() + + +def _get_daq_config( + run_id: str, + config_name: str = 'daq_config', + run_collection: typing.Optional[pymongo.collection.Collection] = None) -> dict: + """ + Query the runs database for the config of the daq during this run. + Either use the context of the runs collection. + """ + if run_collection is None: + if not straxen.utilix_is_configured(): + raise NotImplementedError('Only works with the runs-database') + run_collection = utilix.rundb.xent_collection() + daq_doc = run_collection.find_one({"number": int(run_id)}, + projection={config_name: 1}) + if daq_doc is None or config_name not in daq_doc: + raise ValueError(f'Requested {config_name} does not exist') + return daq_doc[config_name] + + +def _board_to_host_link(daq_config: dict, board: int, add_crate=True) -> str: + """Parse the daq-config to get the host, link and crate""" + for bdoc in daq_config['boards']: + try: + if int(bdoc['board']) == board: + res = f"{bdoc['host']}_link{bdoc['link']}" + if add_crate: + res += f"_crate{bdoc['crate']}" + return res + except KeyError: + raise ValueError(f'Invalid DAQ config {daq_config} or board {board}') + # This happens if the board is not in the channel map which might + # happen for very old runs. 
+ return 'unknown' + + +def _get_cable_map(name: str = 'xenonnt_cable_map.csv') -> pandas.DataFrame: + """Download the cable map and return as a pandas dataframe""" + down = straxen.MongoDownloader() + cable_map = down.download_single(name) + cable_map = pandas.read_csv(cable_map) + return cable_map + + +def _group_channels_by_index(cable_map: pandas.DataFrame, + group_by: str = 'ADC ID', + ) -> typing.Tuple[np.ndarray, np.ndarray]: + """ + Parse the cable map, return the labels where each of the channels is + mapped to as well as an array that can be used to map each of the + channels maps to the labels. + """ + idx = np.arange(straxen.n_tpc_pmts) + idx_seen = 0 + labels = [] + for selection in np.unique(cable_map[group_by].values): + selected_channels = cable_map[cable_map[group_by] == selection]['PMT Location'] + selected_channels = np.array(selected_channels) + n_sel = len(selected_channels) + + idx[idx_seen:idx_seen + n_sel] = selected_channels + labels += [selection] * n_sel + idx_seen += n_sel + return np.array(labels), idx + + +def group_by_daq(run_id, group_by: str): + """From the channel map, get the mapping of channel number -> group by""" + cable_map = _get_cable_map() + if group_by == 'link': + labels, idx = _group_channels_by_index(cable_map, group_by='ADC ID') + daq_config = _get_daq_config(run_id) + labels = [_board_to_host_link(daq_config, l) for l in labels] + labels = np.array(labels) + order = np.argsort(labels) + return labels[order], idx[order] + else: + return _group_channels_by_index(cable_map, group_by=group_by) diff --git a/straxen/analyses/event_display.py b/straxen/analyses/event_display.py index 22101ff06..388b5713e 100644 --- a/straxen/analyses/event_display.py +++ b/straxen/analyses/event_display.py @@ -1,480 +1,460 @@ -import numpy as np -import matplotlib.pyplot as plt -import matplotlib.gridspec as gridspec -import strax -import straxen -from datetime import datetime -import pytz - -export, __all__ = strax.exporter() - -# Default attributes to display in the event_display (looks little -# complicated but just repeats same fields for S1 S1) -# Should be of form as below where {v} wil be filled with the value of -# event['key']: -# (('key', '{v} UNIT'), ..) 
-PEAK_DISPLAY_DEFAULT_INFO = sum([[(k.format(i=s_i), u) for k, u in - (('cs{i}', '{v:.2f} PE'), - ('s{i}_area', '{v:.2f} PE'), - ('alt_cs{i}', '{v:.2f} PE'), - ('s{i}_n_channels', '{v}'), - ('s{i}_area_fraction_top', '{v:.2f}'), - ('s{i}_range_50p_area', '{v:.1f}'), - )] for s_i in (1, 2)], []) -EVENT_DISPLAY_DEFAULT_INFO = (('time', '{v} ns'), - ('endtime', '{v} ns'), - ('event_number', '{v}'), - ('x', '{v:.2f} cm'), - ('y', '{v:.2f} cm'), - ('z', '{v:.2f} cm'), - ('r', '{v:.2f} cm'), - ('theta', '{v:.2f} rad'), - ('drift_time', '{v} ns'), - ('alt_s1_interaction_drift_time', '{v} ns'), - ('alt_s2_interaction_drift_time', '{v} ns') - ) - - -# Don't be smart with the arguments, since it is a minianalyses we -# need to have all the arguments -@straxen.mini_analysis(requires=('event_info',)) -def event_display_simple(context, - run_id, - events, - to_pe, - records_matrix=True, - s2_fuzz=50, - s1_fuzz=0, - max_peaks=500, - xenon1t=False, - display_peak_info=PEAK_DISPLAY_DEFAULT_INFO, - display_event_info=EVENT_DISPLAY_DEFAULT_INFO, - s1_hp_kwargs=None, - s2_hp_kwargs=None, - event_time_limit=None, - plot_all_positions=True, - ): - """ - {event_docs} - {event_returns} - """ - fig = plt.figure(figsize=(12, 8), facecolor="white") - grid = plt.GridSpec(2, 3, hspace=0.5) - axes = dict() - axes["ax_s1"] = fig.add_subplot(grid[0, 0]) - axes["ax_s2"] = fig.add_subplot(grid[0, 1]) - axes["ax_s2_hp_t"] = fig.add_subplot(grid[0, 2]) - axes["ax_ev"] = fig.add_subplot(grid[1, :]) - - return _event_display(context, - run_id, - events, - to_pe, - axes=axes, - records_matrix=records_matrix, - s2_fuzz=s2_fuzz, - s1_fuzz=s1_fuzz, - max_peaks=max_peaks, - xenon1t=xenon1t, - display_peak_info=display_peak_info, - display_event_info=display_event_info, - s1_hp_kwargs=s1_hp_kwargs, - s2_hp_kwargs=s2_hp_kwargs, - event_time_limit=event_time_limit, - plot_all_positions=plot_all_positions, - ) - - -# Don't be smart with the arguments, since it is a minianalyses we -# need to have all the arguments -@straxen.mini_analysis(requires=('event_info',)) -def event_display(context, - run_id, - events, - to_pe, - records_matrix=True, - s2_fuzz=50, - s1_fuzz=0, - max_peaks=500, - xenon1t=False, - display_peak_info=PEAK_DISPLAY_DEFAULT_INFO, - display_event_info=EVENT_DISPLAY_DEFAULT_INFO, - s1_hp_kwargs=None, - s2_hp_kwargs=None, - event_time_limit=None, - plot_all_positions=True, - ): - """ - {event_docs} - {event_returns} - """ - if records_matrix not in ('raw', True, False): - raise ValueError('Choose either "raw", True or False for records_matrix') - if ((records_matrix == 'raw' and not context.is_stored(run_id, 'raw_records')) or - (isinstance(records_matrix, bool) and not context.is_stored(run_id, - 'records'))): # noqa - print("(raw)records not stored! 
Not showing records_matrix") - records_matrix = False - # Convert string to int to allow plots to be enlarged for extra panel - _rr_resize_int = int(bool(records_matrix)) - - fig = plt.figure(figsize=(25, 21 if _rr_resize_int else 16), - facecolor='white') - grid = plt.GridSpec((2 + _rr_resize_int), 1, hspace=0.1 + 0.1 * _rr_resize_int, - height_ratios=[1.5, 0.5, 0.5][:2 + _rr_resize_int] - ) - - # S1, S2, hitpatterns - gss_0 = gridspec.GridSpecFromSubplotSpec(2, 4, subplot_spec=grid[0], wspace=0.25, hspace=0.4) - ax_s1 = fig.add_subplot(gss_0[0]) - ax_s2 = fig.add_subplot(gss_0[1]) - ax_s1_hp_t = fig.add_subplot(gss_0[2]) - ax_s1_hp_b = fig.add_subplot(gss_0[3]) - ax_s2_hp_t = fig.add_subplot(gss_0[6]) - ax_s2_hp_b = fig.add_subplot(gss_0[7]) - - # Peak & event info - ax_event_info = fig.add_subplot(gss_0[4]) - ax_peak_info = fig.add_subplot(gss_0[5]) - - # All peaks in event - gss_1 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=grid[1]) - ax_ev = fig.add_subplot(gss_1[0]) - ax_rec = None - - # (raw)records matrix (optional) - if records_matrix and ax_rec is not None: - gss_2 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=grid[2]) - ax_rec = fig.add_subplot(gss_2[0]) - axes = dict( - ax_s1=ax_s1, - ax_s2=ax_s2, - ax_s1_hp_t=ax_s1_hp_t, - ax_s1_hp_b=ax_s1_hp_b, - ax_event_info=ax_event_info, - ax_peak_info=ax_peak_info, - ax_s2_hp_t=ax_s2_hp_t, - ax_s2_hp_b=ax_s2_hp_b, - ax_ev=ax_ev, - ax_rec=ax_rec) - - return _event_display(context, - run_id, - events, - to_pe, - axes=axes, - records_matrix=records_matrix, - s2_fuzz=s2_fuzz, - s1_fuzz=s1_fuzz, - max_peaks=max_peaks, - xenon1t=xenon1t, - display_peak_info=display_peak_info, - display_event_info=display_event_info, - s1_hp_kwargs=s1_hp_kwargs, - s2_hp_kwargs=s2_hp_kwargs, - event_time_limit=event_time_limit, - plot_all_positions=plot_all_positions, - ) - - -def _event_display(context, - run_id, - events, - to_pe, - axes=None, - records_matrix=True, - s2_fuzz=50, - s1_fuzz=0, - max_peaks=500, - xenon1t=False, - display_peak_info=PEAK_DISPLAY_DEFAULT_INFO, - display_event_info=EVENT_DISPLAY_DEFAULT_INFO, - s1_hp_kwargs=None, - s2_hp_kwargs=None, - event_time_limit=None, - plot_all_positions=True, - ): - """{event_docs} - :param axes: if a dict of matplotlib axes (w/ same keys as below, - and empty/None for panels not filled) - {event_returns} - """ - if len(events) != 1: - raise ValueError(f'Found {len(events)} only request one') - event = events[0] - - if not context.is_stored(run_id, 'peaklets'): - raise strax.DataNotAvailable(f'peaklets not available for {run_id}') - - if axes is None: - raise ValueError(f'No axes provided') - ax_s1 = axes.get("ax_s1", None) - ax_s2 = axes.get("ax_s2", None) - ax_s1_hp_t = axes.get("ax_s1_hp_t", None) - ax_s1_hp_b = axes.get("ax_s1_hp_b", None) - ax_s2_hp_t = axes.get("ax_s2_hp_t", None) - ax_s2_hp_b = axes.get("ax_s2_hp_b", None) - ax_event_info = axes.get("ax_event_info", None) - ax_peak_info = axes.get("ax_peak_info", None) - ax_ev = axes.get("ax_ev", None) - ax_rec = axes.get("ax_rec", None) - - # titles - for ax, title in zip([ax_s1, ax_s1_hp_t, ax_s1_hp_b, - ax_s2, ax_s2_hp_t, ax_s2_hp_b, - ax_event_info, ax_peak_info], - ["Main S1", "S1 top", "S1 bottom", - "Main S2", "S2 top", "S2 bottom", - "Event info", "Peak info"]): - if ax is not None: - ax.set_title(title) - - # Parse the hit pattern options - # Convert to dict (not at function definition because of mutable defaults) - if s1_hp_kwargs is None: - s1_hp_kwargs = {} - if s2_hp_kwargs is None: - s2_hp_kwargs = {} - - # Hit 
patterns options: - for hp_opt, color_map in ((s1_hp_kwargs, "Blues"), (s2_hp_kwargs, "Greens")): - _common_opt = dict(xenon1t=xenon1t, - pmt_label_color='lightgrey', - log_scale=True, - vmin=0.1, - s=(250 if records_matrix else 220), - pmt_label_size=7, - edgecolor='grey', - dead_pmts=np.argwhere(to_pe == 0), - cmap=color_map) - # update s1 & S2 hit pattern kwargs with _common_opt if not - # specified by the user - for k, v in _common_opt.items(): - if k not in hp_opt: - hp_opt[k] = v - - # S1 - if events['s1_area'] != 0: - if ax_s1 is not None: - plt.sca(ax_s1) - context.plot_peaks(run_id, - time_range=(events['s1_time'] - s1_fuzz, - events['s1_endtime'] + s1_fuzz), - single_figure=False) - - # Hit pattern plots - area = context.get_array(run_id, 'peaklets', - time_range=(events['s1_time'], - events['s1_endtime']), - keep_columns=('area_per_channel', 'time', 'dt', 'length'), - progress_bar=False, - ) - for ax, array in ((ax_s1_hp_t, 'top'), (ax_s1_hp_b, 'bottom')): - if ax is not None: - plt.sca(ax) - straxen.plot_on_single_pmt_array(c=np.sum(area['area_per_channel'], axis=0), - array_name=array, - **s1_hp_kwargs) - # Mark reconstructed position - plt.scatter(event['x'], event['y'], marker='X', s=100, c='k') - - # S2 - if event['s2_area'] != 0: - if ax_s2 is not None: - plt.sca(ax_s2) - context.plot_peaks(run_id, - time_range=(events['s2_time'] - s2_fuzz, - events['s2_endtime'] + s2_fuzz), - single_figure=False) - - # Hit pattern plots - area = context.get_array(run_id, 'peaklets', - time_range=(events['s2_time'], - events['s2_endtime']), - keep_columns=('area_per_channel', 'time', 'dt', 'length'), - progress_bar=False, - ) - for axi, (ax, array) in enumerate([(ax_s2_hp_t, 'top'), (ax_s2_hp_b, 'bottom')]): - if ax is not None: - plt.sca(ax) - straxen.plot_on_single_pmt_array(c=np.sum(area['area_per_channel'], axis=0), - array_name=array, - **s2_hp_kwargs) - # Mark reconstructed position (corrected) - plt.scatter(event['x'], event['y'], marker='X', s=100, c='k') - if not xenon1t and axi == 0 and plot_all_positions: - _scatter_rec(event) - - # Fill panels with peak/event info - for it, (ax, labels_and_unit) in enumerate([(ax_event_info, display_event_info), - (ax_peak_info, display_peak_info)]): - if ax is not None: - for i, (_lab, _unit) in enumerate(labels_and_unit): - coord = 0.01, 0.9 - 0.9 * i / len(labels_and_unit) - ax.text(*coord, _lab[:24], va='top', zorder=-10) - ax.text(coord[0] + 0.5, coord[1], - _unit.format(v=event[_lab]), va='top', zorder=-10) - # Remove axes and labels from panel - ax.set_xticks([]) - ax.set_yticks([]) - _ = [s.set_visible(False) for s in ax.spines.values()] - - # Plot peaks in event - ev_range = None - if ax_ev is not None: - plt.sca(ax_ev) - if event_time_limit is None: - time_range = (events['time'], events['endtime']) - else: - time_range = event_time_limit - - context.plot_peaks(run_id, - time_range=time_range, - show_largest=max_peaks, - single_figure=False) - ev_range = plt.xlim() - - if records_matrix and ax_rec is not None: - plt.sca(ax_rec) - context.plot_records_matrix(run_id, - raw=records_matrix == 'raw', - time_range=(events['time'], - events['endtime']), - single_figure=False) - ax_rec.tick_params(axis='x', rotation=0) - if not xenon1t: - # Top vs bottom division - ax_rec.axhline(straxen.n_top_pmts, c='k') - if ev_range is not None: - plt.xlim(*ev_range) - - # Final tweaks - if ax_s2 is not None: - ax_s1.tick_params(axis='x', rotation=45) - if ax_s2 is not None: - ax_s1.tick_params(axis='x', rotation=45) - if ax_ev is not None: - 
ax_ev.tick_params(axis='x', rotation=0) - title = (f'Run {run_id}. Time ' - f'{str(events["time"])[:-9]}.{str(events["time"])[-9:]}\n' - f'{datetime.fromtimestamp(event["time"] / 1e9, tz=pytz.utc)}') - plt.suptitle(title, y=0.95) - # NB: reflects panels order - return (ax_s1, ax_s2, ax_s1_hp_t, ax_s1_hp_b, - ax_event_info, ax_peak_info, ax_s2_hp_t, ax_s2_hp_b, - ax_ev, - ax_rec) - - -@export -def plot_single_event(context: strax.Context, - run_id, - events, - event_number=None, - **kwargs): - """ - Wrapper for event_display - - :param context: strax.context - :param run_id: run id - :param events: dataframe / numpy array of events. Should either be - length 1 or the event_number argument should be provided - :param event_number: (optional) int, if provided, only show this - event number - :param kwargs: kwargs for events_display - :return: see events_display - """ - if event_number is not None: - events = events[events['event_number'] == event_number] - if len(events) > 1 or len(events) == 0: - raise ValueError(f'Make sure to provide an event number or a single ' - f'event. Got {len(events)} events') - - return context.event_display(run_id, - time_range=(events[0]['time'], - events[0]['endtime']), - **kwargs) - - -def _scatter_rec(_event, - recs=None, - scatter_kwargs=None, - ): - """Convenient wrapper to show posrec of three algorithms for xenonnt""" - if recs is None: - recs = ('mlp', 'cnn', 'gcn') - elif len(recs) > 5: - raise ValueError("I only got five markers/colors") - if scatter_kwargs is None: - scatter_kwargs = {} - scatter_kwargs.setdefault('s', 100) - scatter_kwargs.setdefault('alpha', 0.8) - shapes = ('v', '^', '>', '<', '*', 'D', "P") - colors = ('brown', 'orange', 'lightcoral', 'gold', 'lime', 'crimson') - for _i, _r in enumerate(recs): - x, y = _event[f's2_x_{_r}'], _event[f's2_y_{_r}'] - if np.isnan(x) or np.isnan(y): - continue - plt.scatter(x, y, - marker=shapes[_i], - c=colors[_i], - label=_r.upper(), - **scatter_kwargs, - ) - plt.legend(loc='best', fontsize="x-small", markerscale=0.5) - - -# Event display docstrings. -# Let's add them to the corresponding functions - -event_docs = """ -Make a waveform-display of a given event. Requires events, peaks and - peaklets (optionally: records). NB: time selection should return - only one event! 
-
-:param context: strax.Context provided by the minianalysis wrapper
-:param run_id: run-id of the event
-:param events: events, provided by the minianalysis wrapper
-:param to_pe: gains, provided by the minianalysis wrapper
-:param records_matrix: False (no record matrix), True, or "raw"
-    (show raw-record matrix)
-:param s2_fuzz: extra time around main S2 [ns]
-:param s1_fuzz: extra time around main S1 [ns]
-:param max_peaks: max peaks for plotting in the wf plot
-:param xenon1t: True: is 1T, False: is nT
-:param display_peak_info: tuple, items that will be extracted from
-    event and displayed in the event info panel see above for format
-:param display_event_info: tuple, items that will be extracted from
-    event and displayed in the peak info panel see above for format
-:param s1_hp_kwargs: dict, optional kwargs for S1 hitpatterns
-:param s2_hp_kwargs: dict, optional kwargs for S2 hitpatterns
-:param event_time_limit = overrides x-axis limits of event
-    plot
-:param plot_all_positions if True, plot best-fit positions
-    from all posrec algorithms
-"""
-event_returns = """
-:return: axes used for plotting:
-    ax_s1, ax_s2, ax_s1_hp_t, ax_s1_hp_b,
-    ax_event_info, ax_peak_info, ax_s2_hp_t, ax_s2_hp_b,
-    ax_ev,
-    ax_rec
-    Where those panels (axes) are:
-     - ax_s1, main S1 peak
-     - ax_s2, main S2 peak
-     - ax_s1_hp_t, S1 top hit pattern
-     - ax_s1_hp_b, S1 bottom hit pattern
-     - ax_s2_hp_t, S2 top hit pattern
-     - ax_s2_hp_b, S2 bottom hit pattern
-     - ax_event_info, text info on the event
-     - ax_peak_info, text info on the main S1 and S2
-     - ax_ev, waveform of the entire event
-     - ax_rec, (raw)record matrix (if any otherwise None)
-"""
-
-# Add the same docstring to each of these functions
-for event_function in (event_display, event_display_simple, _event_display):
-    doc = event_function.__doc__
-    if doc is not None:
-        event_function.__doc__ = doc.format(event_docs=event_docs,
-                                            event_returns=event_returns)
+from datetime import datetime
+
+import matplotlib.gridspec as gridspec
+import matplotlib.pyplot as plt
+import numpy as np
+import pytz
+import strax
+import straxen
+
+export, __all__ = strax.exporter()
+
+# Default attributes to display in the event_display (looks a little
+# complicated but just repeats the same fields for S1 and S2).
+# Should be of the form below, where {v} will be filled with the value of
+# event['key']:
+# (('key', '{v} UNIT'), ..)
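# Editor's note: an illustrative sketch (not part of the patch) of how these
# (field, format) pairs are consumed further down. The {i} placeholder is
# expanded once per signal and {v} is later filled with the event value:
#
#     label, fmt = ('cs{i}'.format(i=1), '{v:.2f} PE')   # -> ('cs1', '{v:.2f} PE')
#     text = fmt.format(v=event[label])                  # e.g. '153.27 PE' for a
#                                                        #      hypothetical event
#
# PEAK_DISPLAY_DEFAULT_INFO below therefore flattens into entries such as
# ('cs1', '{v:.2f} PE'), ('s1_area', '{v:.2f} PE'), ..., ('s2_range_50p_area', '{v:.1f}').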
+PEAK_DISPLAY_DEFAULT_INFO = sum([[(k.format(i=s_i), u) for k, u in + (('cs{i}', '{v:.2f} PE'), + ('s{i}_area', '{v:.2f} PE'), + ('alt_cs{i}', '{v:.2f} PE'), + ('s{i}_n_channels', '{v}'), + ('s{i}_area_fraction_top', '{v:.2f}'), + ('s{i}_range_50p_area', '{v:.1f}'), + )] for s_i in (1, 2)], []) +EVENT_DISPLAY_DEFAULT_INFO = (('time', '{v} ns'), + ('endtime', '{v} ns'), + ('event_number', '{v}'), + ('x', '{v:.2f} cm'), + ('y', '{v:.2f} cm'), + ('z', '{v:.2f} cm'), + ('r', '{v:.2f} cm'), + ('theta', '{v:.2f} rad'), + ('drift_time', '{v} ns'), + ('alt_s1_interaction_drift_time', '{v} ns'), + ('alt_s2_interaction_drift_time', '{v} ns') + ) + + +# Don't be smart with the arguments, since it is a minianalyses we +# need to have all the arguments +@straxen.mini_analysis(requires=('event_info',)) +def event_display(context, + run_id, + events, + to_pe, + records_matrix=True, + s2_fuzz=50, + s1_fuzz=0, + max_peaks=500, + xenon1t=False, + s1_hp_kwargs=None, + s2_hp_kwargs=None, + event_time_limit=None, + plot_all_positions=True, + display_peak_info=PEAK_DISPLAY_DEFAULT_INFO, + display_event_info=EVENT_DISPLAY_DEFAULT_INFO, + simple_layout=False, + ): + """ + {event_docs} + {event_returns} + """ + if records_matrix not in ('raw', True, False): + raise ValueError('Choose either "raw", True or False for records_matrix') + if ((records_matrix == 'raw' and not context.is_stored(run_id, 'raw_records')) or + (isinstance(records_matrix, bool) and not context.is_stored(run_id, + 'records'))): # noqa + print("(raw)records not stored! Not showing records_matrix") + records_matrix = False + # Convert string to int to allow plots to be enlarged for extra panel + _rr_resize_int = int(bool(records_matrix)) + + if simple_layout: + axes = _event_display_simple_layout() + else: + axes = _event_display_full_layout(_rr_resize_int, records_matrix) + + return _event_display(context, + run_id, + events, + to_pe, + axes=axes, + records_matrix=records_matrix, + s2_fuzz=s2_fuzz, + s1_fuzz=s1_fuzz, + max_peaks=max_peaks, + xenon1t=xenon1t, + display_peak_info=display_peak_info, + display_event_info=display_event_info, + s1_hp_kwargs=s1_hp_kwargs, + s2_hp_kwargs=s2_hp_kwargs, + event_time_limit=event_time_limit, + plot_all_positions=plot_all_positions, + ) + + +def _event_display(context, + run_id, + events, + to_pe, + axes=None, + records_matrix=True, + s2_fuzz=50, + s1_fuzz=0, + max_peaks=500, + xenon1t=False, + display_peak_info=PEAK_DISPLAY_DEFAULT_INFO, + display_event_info=EVENT_DISPLAY_DEFAULT_INFO, + s1_hp_kwargs=None, + s2_hp_kwargs=None, + event_time_limit=None, + plot_all_positions=True, + ): + """{event_docs} + :param axes: if a dict of matplotlib axes (w/ same keys as below, + and empty/None for panels not filled) + {event_returns} + """ + if len(events) != 1: + raise ValueError(f'Found {len(events)} only request one') + event = events[0] + + if axes is None: + raise ValueError(f'No axes provided') + ax_s1 = axes.get("ax_s1", None) + ax_s2 = axes.get("ax_s2", None) + ax_s1_hp_t = axes.get("ax_s1_hp_t", None) + ax_s1_hp_b = axes.get("ax_s1_hp_b", None) + ax_s2_hp_t = axes.get("ax_s2_hp_t", None) + ax_s2_hp_b = axes.get("ax_s2_hp_b", None) + ax_event_info = axes.get("ax_event_info", None) + ax_peak_info = axes.get("ax_peak_info", None) + ax_ev = axes.get("ax_ev", None) + ax_rec = axes.get("ax_rec", None) + + # titles + for ax, title in zip([ax_s1, ax_s1_hp_t, ax_s1_hp_b, + ax_s2, ax_s2_hp_t, ax_s2_hp_b, + ax_event_info, ax_peak_info], + ["Main S1", "S1 top", "S1 bottom", + "Main S2", "S2 top", "S2 bottom", + "Event 
info", "Peak info"]): + if ax is not None: + ax.set_title(title) + + # Parse the hit pattern options + # Convert to dict (not at function definition because of mutable defaults) + if s1_hp_kwargs is None: + s1_hp_kwargs = {} + if s2_hp_kwargs is None: + s2_hp_kwargs = {} + + # Hit patterns options: + for hp_opt, color_map in ((s1_hp_kwargs, "Blues"), (s2_hp_kwargs, "Greens")): + _common_opt = dict(xenon1t=xenon1t, + pmt_label_color='lightgrey', + log_scale=True, + vmin=0.1, + s=(250 if records_matrix else 220), + pmt_label_size=7, + edgecolor='grey', + dead_pmts=np.argwhere(to_pe == 0), + cmap=color_map) + # update s1 & S2 hit pattern kwargs with _common_opt if not + # specified by the user + for k, v in _common_opt.items(): + if k not in hp_opt: + hp_opt[k] = v + + # S1 + if events['s1_area'] != 0: + if ax_s1 is not None: + plt.sca(ax_s1) + context.plot_peaks(run_id, + time_range=(events['s1_time'] - s1_fuzz, + events['s1_endtime'] + s1_fuzz), + single_figure=False) + + # Hit pattern plots + area = context.get_array(run_id, 'peaklets', + time_range=(events['s1_time'], + events['s1_endtime']), + keep_columns=('area_per_channel', 'time', 'dt', 'length'), + progress_bar=False, + ) + for ax, array in ((ax_s1_hp_t, 'top'), (ax_s1_hp_b, 'bottom')): + if ax is not None: + plt.sca(ax) + straxen.plot_on_single_pmt_array(c=np.sum(area['area_per_channel'], axis=0), + array_name=array, + **s1_hp_kwargs) + # Mark reconstructed position + plt.scatter(event['x'], event['y'], marker='X', s=100, c='k') + + # S2 + if event['s2_area'] != 0: + if ax_s2 is not None: + plt.sca(ax_s2) + context.plot_peaks(run_id, + time_range=(events['s2_time'] - s2_fuzz, + events['s2_endtime'] + s2_fuzz), + single_figure=False) + + # Hit pattern plots + area = context.get_array(run_id, 'peaklets', + time_range=(events['s2_time'], + events['s2_endtime']), + keep_columns=('area_per_channel', 'time', 'dt', 'length'), + progress_bar=False, + ) + for axi, (ax, array) in enumerate([(ax_s2_hp_t, 'top'), (ax_s2_hp_b, 'bottom')]): + if ax is not None: + plt.sca(ax) + straxen.plot_on_single_pmt_array(c=np.sum(area['area_per_channel'], axis=0), + array_name=array, + **s2_hp_kwargs) + # Mark reconstructed position (corrected) + plt.scatter(event['x'], event['y'], marker='X', s=100, c='k') + if not xenon1t and axi == 0 and plot_all_positions: + _scatter_rec(event) + + # Fill panels with peak/event info + for it, (ax, labels_and_unit) in enumerate([(ax_event_info, display_event_info), + (ax_peak_info, display_peak_info)]): + if ax is not None: + for i, (_lab, _unit) in enumerate(labels_and_unit): + coord = 0.01, 0.9 - 0.9 * i / len(labels_and_unit) + ax.text(*coord, _lab[:24], va='top', zorder=-10) + ax.text(coord[0] + 0.5, coord[1], + _unit.format(v=event[_lab]), va='top', zorder=-10) + # Remove axes and labels from panel + ax.set_xticks([]) + ax.set_yticks([]) + _ = [s.set_visible(False) for s in ax.spines.values()] + + # Plot peaks in event + ev_range = None + if ax_ev is not None: + plt.sca(ax_ev) + if event_time_limit is None: + time_range = (events['time'], events['endtime']) + else: + time_range = event_time_limit + + context.plot_peaks(run_id, + time_range=time_range, + show_largest=max_peaks, + single_figure=False) + ev_range = plt.xlim() + + if records_matrix and ax_rec is not None: + plt.sca(ax_rec) + context.plot_records_matrix(run_id, + raw=records_matrix == 'raw', + time_range=(events['time'], + events['endtime']), + single_figure=False) + ax_rec.tick_params(axis='x', rotation=0) + if not xenon1t: + # Top vs bottom division + 
ax_rec.axhline(straxen.n_top_pmts, c='k') + if ev_range is not None: + plt.xlim(*ev_range) + + # Final tweaks + if ax_s2 is not None: + ax_s1.tick_params(axis='x', rotation=45) + if ax_s2 is not None: + ax_s1.tick_params(axis='x', rotation=45) + if ax_ev is not None: + ax_ev.tick_params(axis='x', rotation=0) + title = (f'Run {run_id}. Time ' + f'{str(events["time"])[:-9]}.{str(events["time"])[-9:]}\n' + f'{datetime.fromtimestamp(event["time"] / 1e9, tz=pytz.utc)}') + plt.suptitle(title, y=0.95) + # NB: reflects panels order + return (ax_s1, ax_s2, ax_s1_hp_t, ax_s1_hp_b, + ax_event_info, ax_peak_info, ax_s2_hp_t, ax_s2_hp_b, + ax_ev, + ax_rec) + + +@straxen.mini_analysis(requires=('event_info',)) +def event_display_simple(context, run_id, events, **kwargs): + raise NotImplementedError('Pass st.event_display(.., simple_layout=True)') + + +def _event_display_simple_layout() -> dict: + """Setup a simple gidspec for the event display""" + fig = plt.figure(figsize=(12, 8), facecolor="white") + grid = plt.GridSpec(2, 3, hspace=0.5) + axes = dict() + axes["ax_s1"] = fig.add_subplot(grid[0, 0]) + axes["ax_s2"] = fig.add_subplot(grid[0, 1]) + axes["ax_s2_hp_t"] = fig.add_subplot(grid[0, 2]) + axes["ax_ev"] = fig.add_subplot(grid[1, :]) + return axes + + +def _event_display_full_layout(_rr_resize_int, records_matrix) -> dict: + """Setup the full gidspec for the event display""" + fig = plt.figure(figsize=(25, 21 if _rr_resize_int else 16), + facecolor='white') + grid = plt.GridSpec((2 + _rr_resize_int), + 1, + hspace=0.1 + 0.1 * _rr_resize_int, + height_ratios=[1.5, 0.5, 0.5][:2 + _rr_resize_int] + ) + + # S1, S2, hitpatterns + gss_0 = gridspec.GridSpecFromSubplotSpec(2, 4, + subplot_spec=grid[0], + wspace=0.25, + hspace=0.4) + ax_s1 = fig.add_subplot(gss_0[0]) + ax_s2 = fig.add_subplot(gss_0[1]) + ax_s1_hp_t = fig.add_subplot(gss_0[2]) + ax_s1_hp_b = fig.add_subplot(gss_0[3]) + ax_s2_hp_t = fig.add_subplot(gss_0[6]) + ax_s2_hp_b = fig.add_subplot(gss_0[7]) + + # Peak & event info + ax_event_info = fig.add_subplot(gss_0[4]) + ax_peak_info = fig.add_subplot(gss_0[5]) + + # All peaks in event + gss_1 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=grid[1]) + ax_ev = fig.add_subplot(gss_1[0]) + ax_rec = None + + # (raw)records matrix (optional) + if records_matrix: + gss_2 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=grid[2]) + ax_rec = fig.add_subplot(gss_2[0]) + axes = dict( + ax_s1=ax_s1, + ax_s2=ax_s2, + ax_s1_hp_t=ax_s1_hp_t, + ax_s1_hp_b=ax_s1_hp_b, + ax_event_info=ax_event_info, + ax_peak_info=ax_peak_info, + ax_s2_hp_t=ax_s2_hp_t, + ax_s2_hp_b=ax_s2_hp_b, + ax_ev=ax_ev, + ax_rec=ax_rec) + return axes + + +@export +def plot_single_event(context: strax.Context, + run_id, + events, + event_number=None, + **kwargs): + """ + Wrapper for event_display + + :param context: strax.context + :param run_id: run id + :param events: dataframe / numpy array of events. Should either be + length 1 or the event_number argument should be provided + :param event_number: (optional) int, if provided, only show this + event number + :param kwargs: kwargs for events_display + :return: see events_display + """ + if event_number is not None: + events = events[events['event_number'] == event_number] + if len(events) > 1 or len(events) == 0: + raise ValueError(f'Make sure to provide an event number or a single ' + f'event. 
Got {len(events)} events') + + return context.event_display(run_id, + time_range=(events[0]['time'], + events[0]['endtime']), + **kwargs) + + +def _scatter_rec(_event, + recs=None, + scatter_kwargs=None, + ): + """Convenient wrapper to show posrec of three algorithms for xenonnt""" + if recs is None: + recs = ('mlp', 'cnn', 'gcn') + elif len(recs) > 5: + raise ValueError("I only got five markers/colors") + if scatter_kwargs is None: + scatter_kwargs = {} + scatter_kwargs.setdefault('s', 100) + scatter_kwargs.setdefault('alpha', 0.8) + shapes = ('v', '^', '>', '<', '*', 'D', "P") + colors = ('brown', 'orange', 'lightcoral', 'gold', 'lime', 'crimson') + for _i, _r in enumerate(recs): + x, y = _event[f's2_x_{_r}'], _event[f's2_y_{_r}'] + if np.isnan(x) or np.isnan(y): + continue + plt.scatter(x, y, + marker=shapes[_i], + c=colors[_i], + label=_r.upper(), + **scatter_kwargs, + ) + plt.legend(loc='best', fontsize="x-small", markerscale=0.5) + + +# Event display docstrings. +# Let's add them to the corresponding functions + +event_docs = """ +Make a waveform-display of a given event. Requires events, peaks and + peaklets (optionally: records). NB: time selection should return + only one event! + +:param context: strax.Context provided by the minianalysis wrapper +:param run_id: run-id of the event +:param events: events, provided by the minianalysis wrapper +:param to_pe: gains, provided by the minianalysis wrapper +:param records_matrix: False (no record matrix), True, or "raw" + (show raw-record matrix) +:param s2_fuzz: extra time around main S2 [ns] +:param s1_fuzz: extra time around main S1 [ns] +:param max_peaks: max peaks for plotting in the wf plot +:param xenon1t: True: is 1T, False: is nT +:param display_peak_info: tuple, items that will be extracted from + event and displayed in the event info panel see above for format +:param display_event_info: tuple, items that will be extracted from + event and displayed in the peak info panel see above for format +:param s1_hp_kwargs: dict, optional kwargs for S1 hitpatterns +:param s2_hp_kwargs: dict, optional kwargs for S2 hitpatterns +:param event_time_limit = overrides x-axis limits of event + plot +:param plot_all_positions if True, plot best-fit positions + from all posrec algorithms +""" +event_returns = """ +:return: axes used for plotting: + ax_s1, ax_s2, ax_s1_hp_t, ax_s1_hp_b, + ax_event_info, ax_peak_info, ax_s2_hp_t, ax_s2_hp_b, + ax_ev, + ax_rec + Where those panels (axes) are: + - ax_s1, main S1 peak + - ax_s2, main S2 peak + - ax_s1_hp_t, S1 top hit pattern + - ax_s1_hp_b, S1 bottom hit pattern + - ax_s2_hp_t, S2 top hit pattern + - ax_s2_hp_b, S2 bottom hit pattern + - ax_event_info, text info on the event + - ax_peak_info, text info on the main S1 and S2 + - ax_ev, waveform of the entire event + - ax_rec, (raw)record matrix (if any otherwise None) +""" + +# Add the same docstring to each of these functions +for event_function in (event_display, event_display_simple, _event_display): + doc = event_function.__doc__ + if doc is not None: + event_function.__doc__ = doc.format(event_docs=event_docs, + event_returns=event_returns) diff --git a/straxen/analyses/holoviews_waveform_display.py b/straxen/analyses/holoviews_waveform_display.py index d5e4162af..7206273a0 100644 --- a/straxen/analyses/holoviews_waveform_display.py +++ b/straxen/analyses/holoviews_waveform_display.py @@ -197,7 +197,8 @@ def _hvdisp_plot_records_2d(records, def plot_record_polygons(record_points, center_time=True, - scaling=10**-3,): + scaling=10**-3, + ): """ 
Plots record hv.Points as polygons for record matrix. @@ -389,6 +390,7 @@ def wrapped(x_range, **kwargzz): t_reference + int(x_range[1] * 1e9)), t_reference=t_reference, **kwargs) + return wrapped diff --git a/straxen/analyses/posrec_comparison.py b/straxen/analyses/posrec_comparison.py index 8cd729c98..7bcc6060c 100644 --- a/straxen/analyses/posrec_comparison.py +++ b/straxen/analyses/posrec_comparison.py @@ -9,7 +9,6 @@ def load_corrected_positions(context, run_id, events, cmt_version=None, posrec_algos=('mlp', 'gcn', 'cnn')): - """ Returns the corrected position for each position algorithm available, without the need to reprocess event_basics, as the needed @@ -21,9 +20,9 @@ def load_corrected_positions(context, run_id, events, :param posrec_algos: list of position reconstruction algorithms to use (default ['mlp', 'gcn', 'cnn']) """ - + posrec_algos = strax.to_str_tuple(posrec_algos) - + if cmt_version is None: fdc_config = None try: @@ -31,24 +30,30 @@ def load_corrected_positions(context, run_id, events, cmt_version = fdc_config[1] except IndexError as e: raise ValueError(f'CMT is not set? Your fdc config is {fdc_config}') from e - - if hasattr(cmt_version, '__len__') and not isinstance(cmt_version, str) and len(cmt_version) != len(posrec_algos): - raise TypeError(f"cmt_version is a list but does not match the posrec_algos ({posrec_algos}) length.") - cmt_version = (cmt_version, ) * len(posrec_algos) if isinstance(cmt_version, str) else cmt_version + if ( + isinstance(cmt_version, (tuple, list)) + and len(cmt_version) != len(posrec_algos) + ): + raise TypeError(f"cmt_version is a list but does not match the " + f"posrec_algos ({posrec_algos}) length.") + + cmt_version = ((cmt_version, ) * len(posrec_algos) + if isinstance(cmt_version, str) else cmt_version) # Get drift from CMT - drift_conf = context.get_single_plugin(run_id, 'event_positions').config.get('electron_drift_velocity') + ep = context.get_single_plugin(run_id, 'event_positions') + drift_conf = ep.config.get('electron_drift_velocity') drift_speed = straxen.get_correction_from_cmt(run_id, drift_conf) dtype = [] - + for algo in posrec_algos: for xyzr in 'x y z r'.split(): dtype += [ ((f'Interaction {xyzr}-position, field-distortion corrected (cm) - ' f'{algo.upper()} posrec algorithm', f'{xyzr}_{algo}'), np.float32), - ] + ] dtype += [ ((f'Interaction r-position using observed S2 positions directly (cm) -' f' {algo.upper()} posrec algorithm', f'r_naive_{algo}'), @@ -59,12 +64,12 @@ def load_corrected_positions(context, run_id, events, ((f'Interaction angular position (radians) - {algo.upper()} ' f'posrec algorithm', f'theta_{algo}'), np.float32)] - + dtype += [(('Interaction z-position using mean drift velocity only (cm)', 'z_naive'), np.float32)] result = np.zeros(len(events), dtype=dtype) z_obs = - drift_speed * events['drift_time'] - + for algo, v_cmt in zip(posrec_algos, cmt_version): fdc_tmp = (f'fdc_map_{algo}', v_cmt, True) map_tmp = straxen.get_correction_from_cmt(run_id, fdc_tmp) diff --git a/straxen/analyses/pulse_plots.py b/straxen/analyses/pulse_plots.py index 46a1ea0fb..f78efae61 100644 --- a/straxen/analyses/pulse_plots.py +++ b/straxen/analyses/pulse_plots.py @@ -1,39 +1,41 @@ -import straxen -import strax -import numpy as np -import matplotlib.pyplot as plt import os +import matplotlib.pyplot as plt +import numpy as np +import strax +import straxen + @straxen.mini_analysis(requires=('raw_records',), warn_beyond_sec=5) -def plot_pulses_tpc(context, raw_records, run_id, time_range, +def plot_pulses_tpc(context, 
raw_records, run_id, time_range=None, plot_hits=False, plot_median=False, max_plots=20, store_pdf=False, path=''): plot_pulses(context, raw_records, run_id, time_range, plot_hits, plot_median, max_plots, store_pdf, path) - + + @straxen.mini_analysis(requires=('raw_records_mv',), warn_beyond_sec=5) -def plot_pulses_mv(context, raw_records_mv, run_id, time_range, - plot_hits=False, plot_median=False, - max_plots=20, store_pdf=False, path=''): +def plot_pulses_mv(context, raw_records_mv, run_id, time_range=None, + plot_hits=False, plot_median=False, + max_plots=20, store_pdf=False, path=''): plot_pulses(context, raw_records_mv, run_id, time_range, plot_hits, plot_median, max_plots, store_pdf, path, detector_ending='_mv') @straxen.mini_analysis(requires=('raw_records_nv',), warn_beyond_sec=5) -def plot_pulses_nv(context, raw_records_nv, run_id, time_range, +def plot_pulses_nv(context, raw_records_nv, run_id, time_range=None, plot_hits=False, plot_median=False, max_plots=20, store_pdf=False, path=''): plot_pulses(context, raw_records_nv, run_id, time_range, plot_hits, plot_median, max_plots, store_pdf, path, detector_ending='_nv') - + def plot_pulses(context, raw_records, run_id, time_range, plot_hits=False, plot_median=False, - max_plots=20, store_pdf=False, path='', + max_plots=20, store_pdf=False, path='', detector_ending=''): """ Plots nveto pulses for a list of records. @@ -53,14 +55,14 @@ def plot_pulses(context, raw_records, run_id, time_range, string for TPC '_nv' for neutron-veto and '_mv' muon-veto. """ # Register records plugin to get settings - p = context.get_single_plugin(run_id, 'records'+detector_ending) + p = context.get_single_plugin(run_id, 'records' + detector_ending) # Compute strax baseline and baseline_rms: records = strax.raw_to_records(raw_records) records = strax.sort_by_time(records) strax.zero_out_of_bounds(records) - baseline_key = [key for key in p.config.keys() if 'baseline_samples' in key][0] + baseline_key = [key for key in p.config.keys() if 'baseline_samples' in key][0] if isinstance(p.config[baseline_key], int): baseline_samples = p.config[baseline_key] @@ -73,12 +75,16 @@ def plot_pulses(context, raw_records, run_id, time_range, flip=True) nfigs = 1 + if store_pdf and time_range is None: + raise ValueError(f'Specify time range!') if store_pdf: from matplotlib.backends.backend_pdf import PdfPages fname = f'pulses_{run_id}_{time_range[0]}_{time_range[1]}.pdf' fname = os.path.join(path, fname) pdf = PdfPages(fname) + hits = None # needed for delete if false + for inds in _yield_pulse_indices(raw_records): # Grouped our pulse so now plot: rr_pulse = raw_records[inds] @@ -93,18 +99,19 @@ def plot_pulses(context, raw_records, run_id, time_range, baseline = r_pulse[0]['baseline'] baseline_rms = r_pulse[0]['baseline_rms'] axes.axhline(baseline, ls='solid', - color='k', label=f'Strax Bas. +/-RMS:\n ({baseline:.2f}+/-{baseline_rms:.2f}) ADC') + color='k', + label=f'Strax Bas. 
+/-RMS:\n ({baseline:.2f}+/-{baseline_rms:.2f}) ADC') xlim = axes.get_xlim() axes.fill_between(xlim, [baseline + baseline_rms] * 2, [baseline - baseline_rms] * 2, color='gray', alpha=0.4 ) - + # check type of p.hit_thresholds - if isinstance(p.hit_thresholds,int): + if isinstance(p.hit_thresholds, int): thr = p.hit_thresholds - elif isinstance(p.hit_thresholds,np.ndarray): + elif isinstance(p.hit_thresholds, np.ndarray): thr = p.hit_thresholds[rr_pulse['channel']][0] if plot_median: @@ -121,10 +128,9 @@ def plot_pulses(context, raw_records, run_id, time_range, ls='dotted', color='orange' ) - hits = None # needed for delet if false if plot_hits: min_amplitude = thr - + axes.axhline(baseline - min_amplitude, color='orange', label='Hitfinder threshold') @@ -133,7 +139,7 @@ def plot_pulses(context, raw_records, run_id, time_range, ) if detector_ending != '_he': # We don't have 'save_outside_hits_he' at all! - le, re = p.config['save_outside_hits'+detector_ending] + le, re = p.config['save_outside_hits' + detector_ending] else: le, re = p.config['save_outside_hits'] start = (hits['time'] - r_pulse[0]['time']) / r_pulse[0]['dt'] - le diff --git a/straxen/analyses/quick_checks.py b/straxen/analyses/quick_checks.py index 98ecee2c1..fa22a56ea 100644 --- a/straxen/analyses/quick_checks.py +++ b/straxen/analyses/quick_checks.py @@ -1,8 +1,8 @@ -import numpy as np -from multihist import Hist1d, Histdd import matplotlib.pyplot as plt - +import numpy as np import straxen +from multihist import Hist1d, Histdd +from matplotlib.colors import LogNorm @straxen.mini_analysis(requires=('peak_basics',)) @@ -43,18 +43,18 @@ def std_axes(): plt.xlabel("Area [PE]") plt.ylabel("Range 50% area [ns]") labels = [ - (12, 8, "AP?", 'white'), - (3, 150, "1PE\npileup", 'gray'), + (12, 8, "AP?", 'white'), + (3, 150, "1PE\npileup", 'gray'), - (30, 200, "1e", 'gray'), - (100, 1000, "n-e", 'w'), - (2000, 2e4, "Train", 'gray'), + (30, 200, "1e", 'gray'), + (100, 1000, "n-e", 'w'), + (2000, 2e4, "Train", 'gray'), - (1200, 50, "S1", 'w'), - (45e3, 60, "αS1", 'w'), + (1200, 50, "S1", 'w'), + (45e3, 60, "αS1", 'w'), - (2e5, 800, "S2", 'w'), - ] + list(extra_labels) + (2e5, 800, "S2", 'w'), + ] + list(extra_labels) for x, w, text, color in labels: plt.text(x, w, text, color=color, @@ -63,8 +63,7 @@ def std_axes(): plt.sca(axes[0]) (mh / livetime_sec).sum(axis=2).plot( - log_scale=True, - vmin=rate_range[0], vmax=rate_range[1], + norm=LogNorm(vmin=rate_range[0], vmax=rate_range[1]), colorbar_kwargs=dict(extend='both'), cblabel='Peaks / (bin * s)') std_axes() @@ -73,7 +72,8 @@ def std_axes(): mh.average(axis=2).plot( vmin=aft_range[0], vmax=aft_range[1], colorbar_kwargs=dict(extend='max'), - cmap=plt.cm.jet, cblabel='Mean area fraction top') + cmap=plt.cm.jet, + cblabel='Mean area fraction top') std_axes() plt.tight_layout() @@ -138,8 +138,8 @@ def event_scatter(context, run_id, events, x = np.geomspace(*el_lim, num=1000) e_label = 1.2e-3 for e_const, label in [ - (0.1, ''), (1, '1\nkeV'), (10, '10\nkeV'), - (100, '100\nkeV'), (1e3, '1\nMeV'), (1e4, '')]: + (0.1, ''), (1, '1\nkeV'), (10, '10\nkeV'), + (100, '100\nkeV'), (1e3, '1\nMeV'), (1e4, '')]: plt.plot(x, e_const - x, c='k', alpha=0.2) plt.text(e_const - e_label, e_label, label, bbox=dict(facecolor='white', alpha=0.5, edgecolor='none'), @@ -196,8 +196,8 @@ def plot_energy_spectrum( if exposure_kg_sec is not None: unit = 'kg_day_kev' else: - unit = 'events' - + unit = 'events' + h = Hist1d(events['e_ces'], bins=(np.geomspace if geomspace else np.linspace)( min_energy, max_energy, 
n_bins)) @@ -205,6 +205,8 @@ def plot_energy_spectrum( if unit == 'events': scale, ylabel = 1, 'Events per bin' else: + if exposure_kg_sec is None: + raise ValueError('you did not specify exposure_kg_sec') exposure_kg_day = exposure_kg_sec / (3600 * 24) if unit == 'kg_day_kev': scale = exposure_kg_day @@ -218,13 +220,13 @@ def plot_energy_spectrum( else: raise ValueError(f"Invalid unit {unit}") scale *= h.bin_volumes() - + h.plot(errors=errors, error_style='band', color=color, label=label, linewidth=1, - scale_histogram_by=1/scale, + scale_histogram_by=1 / scale, error_alpha=error_alpha) plt.yscale('log') if geomspace: diff --git a/straxen/analyses/records_matrix.py b/straxen/analyses/records_matrix.py index efcc4d18b..d37918d26 100644 --- a/straxen/analyses/records_matrix.py +++ b/straxen/analyses/records_matrix.py @@ -2,7 +2,6 @@ import numba import numpy as np - import strax import straxen @@ -95,12 +94,17 @@ def raw_records_matrix(context, run_id, raw_records, time_range, **kwargs) -@numba.njit def _records_to_matrix(records, t0, window, n_channels, dt=10): - n_samples = window // dt + if np.any(records['amplitude_bit_shift'] > 0): + warnings.warn('Ignoring amplitude bitshift!') + return _records_to_matrix_inner(records, t0, window, n_channels, dt) + + +@numba.njit +def _records_to_matrix_inner(records, t0, window, n_channels, dt=10): + n_samples = (window // dt) + 1 # Use 32-bit integers, so downsampling saturated samples doesn't # cause wraparounds - # TODO: amplitude bit shift! y = np.zeros((n_samples, n_channels), dtype=np.int32) @@ -114,7 +118,12 @@ def _records_to_matrix(records, t0, window, n_channels, dt=10): if dt >= samples_per_record * r['dt']: # Downsample to single sample -> store area - y[(r['time'] - t0) // dt, r['channel']] += r['area'] + idx = (r['time'] - t0) // dt + if idx >= len(y): + print(len(y), idx) + raise IndexError('Despite n_samples = window // dt + 1, our ' + 'idx is too high?!') + y[idx, r['channel']] += r['area'] continue # Assume out-of-bounds data has been zeroed, so we do not @@ -125,7 +134,8 @@ def _records_to_matrix(records, t0, window, n_channels, dt=10): if dt > r['dt']: # Downsample duration = samples_per_record * r['dt'] - assert duration % dt == 0, "Cannot downsample fractionally" + if duration % dt != 0: + raise ValueError("Cannot downsample fractionally") # .astype here keeps numba happy ... ?? 
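            # Editor's note, clarifying comment only (not part of the patch): the
            # reshape/sum on the next line downsamples by summing groups of
            # dt // r['dt'] neighbouring samples. For example, with r['dt'] = 10 ns
            # and (typically) 110 samples per record, duration = 1100 ns; a target
            # dt = 20 ns reshapes the waveform to (55, 2) and sums over axis 1, so
            # each pair of 10 ns samples becomes a single 20 ns bin. The result is
            # kept as int32 so that summing saturated samples cannot wrap around.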
w = w.reshape(duration // dt, -1).sum(axis=1).astype(np.int32) diff --git a/straxen/analyses/waveform_plot.py b/straxen/analyses/waveform_plot.py index cd7a2058b..4f726e604 100644 --- a/straxen/analyses/waveform_plot.py +++ b/straxen/analyses/waveform_plot.py @@ -1,13 +1,12 @@ -import numpy as np import matplotlib import matplotlib.pyplot as plt -import warnings +import numpy as np import strax import straxen from mpl_toolkits.axes_grid1 import inset_locator -from datetime import datetime -from .records_matrix import DEFAULT_MAX_SAMPLES + from .daq_waveforms import group_by_daq +from .records_matrix import DEFAULT_MAX_SAMPLES export, __all__ = strax.exporter() __all__ += ['plot_wf'] @@ -44,6 +43,7 @@ def plot_waveform(context, else: f, axes = plt.subplots(2, 1, + constrained_layout=True, figsize=figsize, gridspec_kw={'height_ratios': [1, lower_panel_height]}) @@ -60,7 +60,6 @@ def plot_waveform(context, raw=deep == 'raw', single_figure=False) - straxen.quiet_tight_layout() plt.subplots_adjust(hspace=0) @@ -132,8 +131,8 @@ def plot_records_matrix(context, run_id, group_by=None, max_samples=DEFAULT_MAX_SAMPLES, ignore_max_sample_warning=False, - vmin = None, - vmax = None, + vmin=None, + vmax=None, **kwargs): if seconds_range is None: raise ValueError( @@ -141,7 +140,7 @@ def plot_records_matrix(context, run_id, "to plot_records_matrix.") if single_figure: - plt.figure(figsize=figsize) + plt.figure(figsize=figsize, constrained_layout=True) f = context.raw_records_matrix if raw else context.records_matrix @@ -150,7 +149,7 @@ def plot_records_matrix(context, run_id, ignore_max_sample_warning=ignore_max_sample_warning, **kwargs) if group_by is not None: - ylabs, wvm_mask = group_by_daq(context, run_id, group_by) + ylabs, wvm_mask = group_by_daq(run_id, group_by) wvm = wvm[:, wvm_mask] plt.ylabel(group_by) else: @@ -210,9 +209,6 @@ def plot_records_matrix(context, run_id, plt.sca(ax) - if single_figure: - straxen.quiet_tight_layout() - def seconds_range_xaxis(seconds_range, t0=None): """Make a pretty time axis given seconds_range""" diff --git a/straxen/common.py b/straxen/common.py index 0fa9fddf0..7405adbca 100644 --- a/straxen/common.py +++ b/straxen/common.py @@ -2,16 +2,13 @@ import configparser import gzip import inspect -import io +import typing as ty import commentjson import json import os import os.path as osp import pickle import dill -import socket -import sys -import tarfile import urllib.request import tqdm import numpy as np @@ -26,7 +23,7 @@ __all__ += ['straxen_dir', 'first_sr1_run', 'tpc_r', 'tpc_z', 'aux_repo', 'n_tpc_pmts', 'n_top_pmts', 'n_hard_aqmon_start', 'ADC_TO_E', 'n_nveto_pmts', 'n_mveto_pmts', 'tpc_pmt_radius', 'cryostat_outer_radius', - 'INFINITY_64BIT_SIGNED'] + 'perp_wire_angle', 'perp_wire_x_rot_pos', 'INFINITY_64BIT_SIGNED'] straxen_dir = os.path.dirname(os.path.abspath( inspect.getfile(inspect.currentframe()))) @@ -45,6 +42,9 @@ tpc_pmt_radius = 7.62 / 2 # cm +perp_wire_angle = np.deg2rad(30) +perp_wire_x_rot_pos = 13.06 #[cm] + # Convert from ADC * samples to electrons emitted by PMT # see pax.dsputils.adc_to_pe for calculation. Saving this number in straxen as # it's needed in analyses @@ -56,6 +56,31 @@ INFINITY_64BIT_SIGNED = 9223372036854775807 + +@export +def rotate_perp_wires(x_obs: np.ndarray, + y_obs: np.ndarray, + angle_extra: ty.Union[float, int] = 0): + """ + Returns x and y in the rotated plane where the perpendicular wires + area vertically aligned (parallel to the y-axis). 
Accepts addition to the + rotation angle with `angle_extra` [deg] + + :param x_obs: array of x coordinates + :param y_obs: array of y coordinates + :param angle_extra: extra rotation in [deg] + :return: x_rotated, y_rotated + """ + if len(x_obs) != len(y_obs): + raise ValueError('x and y are not of the same length') + angle_extra_rad = np.deg2rad(angle_extra) + x_rot = (np.cos(perp_wire_angle + angle_extra_rad) * x_obs + - np.sin(perp_wire_angle + angle_extra_rad) * y_obs) + y_rot = (np.sin(perp_wire_angle + angle_extra_rad) * x_obs + + np.cos(perp_wire_angle + angle_extra_rad) * y_obs) + return x_rot, y_rot + + @export def pmt_positions(xenon1t=False): """Return pandas dataframe with PMT positions @@ -94,9 +119,10 @@ def open_resource(file_name: str, fmt='text'): :param fmt: format of the file :return: opened file """ - if file_name in _resource_cache: + cached_name = _cache_name(file_name, fmt) + if cached_name in _resource_cache: # Retrieve from in-memory cache - return _resource_cache[file_name] + return _resource_cache[cached_name] # File resource if fmt in ['npy', 'npy_pickle']: result = np.load(file_name, allow_pickle=fmt == 'npy_pickle') @@ -136,7 +162,7 @@ def open_resource(file_name: str, fmt='text'): raise ValueError(f"Unsupported format {fmt}!") # Store in in-memory cache - _resource_cache[file_name] = result + _resource_cache[cached_name] = result return result @@ -161,8 +187,9 @@ def get_resource(x: str, fmt='text'): specified format """ # 1. load from memory - if x in _resource_cache: - return _resource_cache[x] + cached_name = _cache_name(x, fmt) + if cached_name in _resource_cache: + return _resource_cache[cached_name] # 2. load from file elif os.path.exists(x): return open_resource(x, fmt=fmt) @@ -180,6 +207,11 @@ def get_resource(x: str, fmt='text'): f'cannot download it from anywhere.') +def _cache_name(name: str, fmt: str)->str: + """Return a name under which to store the requested name with the given format in the _cache""" + return f'{fmt}::{name}' + + # Legacy loader for public URL files def resource_from_url(html: str, fmt='text'): """ @@ -243,70 +275,6 @@ def resource_from_url(html: str, fmt='text'): return result -@export -def get_secret(x): - """Return secret key x. In order of priority, we search: - * Environment variable: uppercase version of x - * xenon_secrets.py (if included with your straxen installation) - * A standard xenon_secrets.py located on the midway analysis hub - (if you are running on midway) - """ - warn("xenon_secrets is deprecated, and will be replaced with utilix" - "configuration file instead. See https://github.com/XENONnT/utilix") - env_name = x.upper() - if env_name in os.environ: - return os.environ[env_name] - - message = (f"Secret {x} requested, but there is no environment " - f"variable {env_name}, ") - - # now try using utilix. We need to check that it is not None first! - # this will be main method in a future release - if straxen.uconfig is not None and straxen.uconfig.has_option('straxen', x): - try: - return straxen.uconfig.get('straxen', x) - except configparser.NoOptionError: - warn(f'straxen.uconfig does not have {x}') - - # if that doesn't work, revert to xenon_secrets - try: - from . 
import xenon_secrets - except ImportError: - message += ("nor was there a valid xenon_secrets.py " - "included with your straxen installation, ") - - # If on midway, try loading a standard secrets file instead - if 'rcc' in socket.getfqdn(): - path_to_secrets = '/project2/lgrandi/xenonnt/xenon_secrets.py' - if os.path.exists(path_to_secrets): - sys.path.append(osp.dirname(path_to_secrets)) - import xenon_secrets - sys.path.pop() - else: - raise ValueError( - message + ' nor could we load the secrets module from ' - f'{path_to_secrets}, even though you seem ' - 'to be on the midway analysis hub.') - - else: - raise ValueError( - message + 'nor are you on the midway analysis hub.') - - if hasattr(xenon_secrets, x): - return getattr(xenon_secrets, x) - raise ValueError(message + " and the secret is not in xenon_secrets.py") - - -@export -def download_test_data(): - """Downloads strax test data to strax_test_data in the current directory""" - blob = get_resource('https://raw.githubusercontent.com/XENONnT/strax_auxiliary_files/609b492e1389369734c7d2cbabb38059f14fc05e/strax_files/strax_test_data_straxv0.9.tar', # noqa - fmt='binary') - f = io.BytesIO(blob) - tf = tarfile.open(fileobj=f) - tf.extractall() - - @export def get_livetime_sec(context, run_id, things=None): """Get the livetime of a run in seconds. If it is not in the run metadata, @@ -342,7 +310,7 @@ def pre_apply_function(data, run_id, target, function_name='pre_apply_function') if function_name not in _resource_cache: # only load the function once and put it in the resource cache function_file = f'{function_name}.py' - function_file = _overwrite_testing_function_file(function_file) + function_file = straxen.test_utils._overwrite_testing_function_file(function_file) function = get_resource(function_file, fmt='txt') # pylint: disable=exec-used exec(function) @@ -352,32 +320,9 @@ def pre_apply_function(data, run_id, target, function_name='pre_apply_function') return data -def _overwrite_testing_function_file(function_file): - """For testing purposes allow this function file to be loaded from HOME/testing_folder""" - if not straxen._is_on_pytest(): - # If we are not on a pytest, never try using a local file. - return function_file - - home = os.environ.get('HOME') - if home is None: - # Impossible to load from non-existent folder - return function_file - - testing_file = os.path.join(home, function_file) - - if os.path.exists(testing_file): - # For testing purposes allow loading from 'home/testing_folder' - warn(f'Using local function: {function_file} from {testing_file}! ' - f'If you are not integrated testing on github you should ' - f'absolutely remove this file. 
(See #559)') - function_file = testing_file - - return function_file - - @export def check_loading_allowed(data, run_id, target, - max_in_disallowed = 1, + max_in_disallowed=1, disallowed=('event_positions', 'corrected_areas', 'energy_estimates') @@ -431,16 +376,6 @@ def remap_channels(data, verbose=True, safe_copy=False, _tqdm=False, ): aux_repo + '/ecb6da7bd4deb98cd0a4e83b3da81c1e67505b16/remapped_channels_since_20200729_17.20UTC.csv', fmt='csv') - def wr_tqdm(x): - """Wrap input x with tqdm""" - if _tqdm: - try: - return tqdm.tqdm_notebook(x) - except (AttributeError, ModuleNotFoundError, ImportError): - # ok, sorry lets not wrap but return x - pass - return x - def convert_channel(_data, replace=('channel', 'max_pmt')): """ Given an array, replace the 'channel' entry if we had to remap it according to the @@ -520,7 +455,7 @@ def convert_channel_like(channel_data, n_chs=n_tpc_pmts): return channel_data # Create a buffer to overright buffer = channel_data.copy() - for k in wr_tqdm(get_dtypes(channel_data)): + for k in strax.utils.tqdm(get_dtypes(channel_data), disable=not _tqdm): if np.iterable(channel_data[k][0]) and len(channel_data[k][0]) == n_chs: if verbose: print(f'convert_channel_like::\tupdate {k}') @@ -615,6 +550,7 @@ def _swap_values_in_array(data_arr, buffer, items, replacements): break return buffer + ## # Old XENON1T Stuff ## diff --git a/straxen/contexts.py b/straxen/contexts.py index 5f22c99a0..9ce60f63b 100644 --- a/straxen/contexts.py +++ b/straxen/contexts.py @@ -2,9 +2,10 @@ import strax import straxen from copy import deepcopy -import socket -from warnings import warn +from .rucio import HAVE_ADMIX +import os +from straxen.common import pax_file common_opts = dict( register_all=[ @@ -29,7 +30,7 @@ check_available=('raw_records', 'peak_basics'), store_run_fields=( 'name', 'number', - 'start', 'end', 'livetime', 'mode')) + 'start', 'end', 'livetime', 'mode', 'source')) xnt_common_config = dict( n_tpc_pmts=straxen.n_tpc_pmts, @@ -52,11 +53,7 @@ nveto_blank=(2999, 2999)), # Clustering/classification parameters # Event level parameters - s2_xy_correction_map=('s2_xy_map', "ONLINE", True), fdc_map=('fdc_map', "ONLINE", True), - s1_xyz_correction_map=("s1_xyz_map", "ONLINE", True), - g1=0.1426, - g2=11.55, ) # these are placeholders to avoid calling cmt with non integer run_ids. Better solution pending. # s1,s2 and fd corrections are still problematic @@ -64,7 +61,7 @@ xnt_simulation_config.update(gain_model=("to_pe_placeholder", True), gain_model_nv=("adc_nv", True), gain_model_mv=("adc_mv", True), - elife_conf=('elife_constant', 1e6), + elife=1e6, ) # Plugins in these files have nT plugins, E.g. 
in pulse&peak(let) @@ -76,6 +73,7 @@ straxen.PeakPositionsMLP, straxen.PeakPositionsGCN, straxen.PeakPositionsNT, + straxen.S2ReconPosDiff, straxen.PeakBasicsHighEnergy, straxen.PeaksHighEnergy, straxen.PeakletsHighEnergy, @@ -83,7 +81,9 @@ straxen.MergedS2sHighEnergy, straxen.PeakVetoTagging, straxen.EventInfo, - ], + straxen.PeakShadow, + straxen.EventShadow, + ], 'register_all': common_opts['register_all'] + [straxen.veto_veto_regions, straxen.nveto_recorder, straxen.veto_pulse_processing, @@ -99,7 +99,6 @@ ], 'use_per_run_defaults': False, }) - ## # XENONnT ## @@ -113,15 +112,14 @@ def xenonnt(cmt_version='global_ONLINE', **kwargs): def xenonnt_online(output_folder='./strax_data', - use_rucio=None, - use_rucio_remote=False, we_are_the_daq=False, + download_heavy=False, _minimum_run_number=7157, _maximum_run_number=None, _database_init=True, _forbid_creation_of=None, - _rucio_path='/dali/lgrandi/rucio/', _include_rucio_remote=False, + _rucio_path='/dali/lgrandi/rucio/', _raw_path='/dali/lgrandi/xenonnt/raw', _processed_path='/dali/lgrandi/xenonnt/processed', _add_online_monitor_frontend=False, @@ -132,10 +130,8 @@ def xenonnt_online(output_folder='./strax_data', :param output_folder: str, Path of the strax.DataDirectory where new data can be stored - :param use_rucio: bool, whether or not to use the rucio frontend (by - default, we add the frontend when running on an rcc machine) - :param use_rucio_remote: bool, if download data from rucio directly :param we_are_the_daq: bool, if we have admin access to upload data + :param download_heavy: bool, whether or not to allow downloads of heavy data (raw_records*, less the aqmon) :param _minimum_run_number: int, lowest number to consider :param _maximum_run_number: Highest number to consider. When None (the default) consider all runs that are higher than the @@ -160,7 +156,7 @@ def xenonnt_online(output_folder='./strax_data', st = strax.Context( config=straxen.contexts.xnt_common_config, **context_options) - st.register([straxen.DAQReader, straxen.LEDCalibration]) + st.register([straxen.DAQReader, straxen.LEDCalibration, straxen.LEDAfterpulseProcessing]) st.storage = [ straxen.RunDB( @@ -182,22 +178,21 @@ def xenonnt_online(output_folder='./strax_data', readonly=True, )] if output_folder: - st.storage.append( - strax.DataDirectory(output_folder, - provide_run_metadata=True, - )) - + st.storage += [strax.DataDirectory(output_folder, + provide_run_metadata=True, + )] st.context_config['forbid_creation_of'] = straxen.daqreader.DAQReader.provides if _forbid_creation_of is not None: st.context_config['forbid_creation_of'] += strax.to_str_tuple(_forbid_creation_of) - # Add the rucio frontend to storage when asked to or if we did not - # specify anything and are on rcc - if use_rucio or (use_rucio is None and 'rcc' in socket.getfqdn()): - st.storage.append(straxen.rucio.RucioFrontend( - include_remote=use_rucio_remote, - staging_dir=output_folder, - )) + # Add the rucio frontend if we are able to + if HAVE_ADMIX: + rucio_frontend = straxen.rucio.RucioFrontend( + include_remote=_include_rucio_remote, + staging_dir=os.path.join(output_folder, 'rucio'), + download_heavy=download_heavy, + ) + st.storage += [rucio_frontend] # Only the online monitor backend for the DAQ if _database_init and (_add_online_monitor_frontend or we_are_the_daq): @@ -206,7 +201,9 @@ def xenonnt_online(output_folder='./strax_data', take_only=('veto_intervals', 'online_peak_monitor', 'event_basics', - 'online_monitor_nv'))] + 'online_monitor_nv', + 'online_monitor_mv', + ))] # 
Remap the data if it is before channel swap (because of wrongly cabled # signal cable connectors) These are runs older than run 8797. Runs @@ -224,7 +221,10 @@ def xenonnt_online(output_folder='./strax_data', def xenonnt_led(**kwargs): st = xenonnt_online(**kwargs) - st.context_config['check_available'] = ('raw_records', 'led_calibration') + st.set_context_config( + {'check_available': ('raw_records', 'led_calibration'), + 'free_options': list(xnt_common_config.keys()) + }) # Return a new context with only raw_records and led_calibration registered st = st.new_context( replace=True, @@ -243,9 +243,10 @@ def xenonnt_led(**kwargs): def xenonnt_simulation( output_folder='./strax_data', + wfsim_registry='RawRecordsFromFaxNT', cmt_run_id_sim=None, cmt_run_id_proc=None, - cmt_version='v3', + cmt_version='global_ONLINE', fax_config='fax_config_nt_design.json', overwrite_from_fax_file_sim=False, overwrite_from_fax_file_proc=False, @@ -255,7 +256,8 @@ def xenonnt_simulation( _config_overlap=immutabledict( drift_time_gate='electron_drift_time_gate', drift_velocity_liquid='electron_drift_velocity', - electron_lifetime_liquid='elife_conf'), + electron_lifetime_liquid='elife', + ), **kwargs): """ The most generic context that allows for setting full divergent @@ -266,7 +268,7 @@ def xenonnt_simulation( refer to detector simulation parameters. Arguments having _proc in their name refer to detector parameters that - are used for processing of simulations as done to the real datector + are used for processing of simulations as done to the real detector data. This means starting from already existing raw_records and finishing with higher level data, such as peaks, events etc. @@ -277,6 +279,7 @@ def xenonnt_simulation( CMT options can also be overwritten via fax config file. :param output_folder: Output folder for strax data. + :param wfsim_registry: Name of WFSim plugin used to generate data. :param cmt_run_id_sim: Run id for detector parameters from CMT to be used for creation of raw_records. :param cmt_run_id_proc: Run id for detector parameters from CMT to be used @@ -287,7 +290,7 @@ def xenonnt_simulation( parameters for truth/raw_records from from fax_config file istead of CMT :param overwrite_from_fax_file_proc: If true sets detector processing parameters after raw_records(peaklets/events/etc) from from fax_config - file istead of CMT + file instead of CMT :param cmt_option_overwrite_sim: Dictionary to overwrite CMT settings for the detector simulation part. 
:param cmt_option_overwrite_proc: Dictionary to overwrite CMT settings for @@ -307,11 +310,15 @@ def xenonnt_simulation( check_raw_record_overlaps=True, **straxen.contexts.xnt_common_config,), **straxen.contexts.xnt_common_opts, **kwargs) - st.register(wfsim.RawRecordsFromFaxNT) - if straxen.utilix_is_configured(): - st.apply_cmt_version(f'global_{cmt_version}') - else: - warn(f'Bad context as we cannot set CMT since we have no database access') + st.register(getattr(wfsim, wfsim_registry)) + + # Make sure that the non-simulated raw-record types are not requested + st.deregister_plugins_with_missing_dependencies() + + if straxen.utilix_is_configured( + warning_message='Bad context as we cannot set CMT since we ' + 'have no database access'''): + st.apply_cmt_version(cmt_version) if _forbid_creation_of is not None: st.context_config['forbid_creation_of'] += strax.to_str_tuple(_forbid_creation_of) @@ -330,20 +337,41 @@ def xenonnt_simulation( cmt_run_id_proc = cmt_id # Replace default cmt options with cmt_run_id tag + cmt run id - cmt_options = straxen.get_corrections.get_cmt_options(st) + cmt_options_full = straxen.get_corrections.get_cmt_options(st) + + # prune to just get the strax options + cmt_options = {key: val['strax_option'] + for key, val in cmt_options_full.items()} # First, fix gain model for simulation st.set_config({'gain_model_mc': ('cmt_run_id', cmt_run_id_sim, *cmt_options['gain_model'])}) fax_config_override_from_cmt = dict() for fax_field, cmt_field in _config_overlap.items(): + value = cmt_options[cmt_field] + + # URL configs need to be converted to the expected format + if isinstance(value, str): + opt_cfg = cmt_options_full[cmt_field] + version = straxen.URLConfig.kwarg_from_url(value, 'version') + # We now allow the cmt name to be different from the config name + # WFSim expects the cmt name + value = (opt_cfg['correction'], version, True) + fax_config_override_from_cmt[fax_field] = ('cmt_run_id', cmt_run_id_sim, - *cmt_options[cmt_field]) + *value) st.set_config({'fax_config_override_from_cmt': fax_config_override_from_cmt}) # and all other parameters for processing for option in cmt_options: - st.config[option] = ('cmt_run_id', cmt_run_id_proc, *cmt_options[option]) + value = cmt_options[option] + if isinstance(value, str): + # for URL configs we can just replace the run_id keyword argument + # This will become the proper way to override the run_id for cmt configs + st.config[option] = straxen.URLConfig.format_url_kwargs(value, run_id=cmt_run_id_proc) + else: + # FIXME: Remove once all cmt configs are URLConfigs + st.config[option] = ('cmt_run_id', cmt_run_id_proc, *value) # Done with "default" usage, now to overwrites from file # @@ -352,11 +380,20 @@ def xenonnt_simulation( fax_config = straxen.get_resource(fax_config, fmt='json') for fax_field, cmt_field in _config_overlap.items(): if overwrite_from_fax_file_proc: - st.config[cmt_field] = ( cmt_options[cmt_field][0] + '_constant', - fax_config[fax_field]) + if isinstance(cmt_options[cmt_field], str): + # URLConfigs can just be set to a constant + st.config[cmt_field] = fax_config[fax_field] + else: + # FIXME: Remove once all cmt configs are URLConfigs + st.config[cmt_field] = (cmt_options[cmt_field][0] + '_constant', + fax_config[fax_field]) if overwrite_from_fax_file_sim: + # CMT name allowed to be different from the config name + # WFSim needs the cmt name + cmt_name = cmt_options_full[cmt_field]['correction'] + st.config['fax_config_override_from_cmt'][fax_field] = ( - cmt_options[cmt_field][0] + 
'_constant',fax_config[fax_field]) + cmt_name + '_constant', fax_config[fax_field]) # And as the last step - manual overrrides, since they have the highest priority # User customized for simulation @@ -364,26 +401,32 @@ def xenonnt_simulation( if option not in cmt_options: raise ValueError(f'Overwrite option {option} is not using CMT by default ' 'you should just use set config') - if not option in _config_overlap.values(): + if option not in _config_overlap.values(): raise ValueError(f'Overwrite option {option} does not have mapping from ' - 'CMT to fax config! ') - for fax_key,cmt_key in _config_overlap.items(): - if cmt_key==option: - _name_index = 2 if 'cmt_run_id' in cmt_options[option] else 0 + f'CMT to fax config!') + for fax_key, cmt_key in _config_overlap.items(): + if cmt_key == option: + cmt_name = cmt_options_full[option]['correction'] st.config['fax_config_override_from_cmt'][fax_key] = ( - cmt_options[option][_name_index] + '_constant', + cmt_name + '_constant', cmt_option_overwrite_sim[option]) - del(_name_index) del(fax_key, cmt_key) # User customized for simulation for option in cmt_option_overwrite_proc: if option not in cmt_options: raise ValueError(f'Overwrite option {option} is not using CMT by default ' 'you should just use set config') - _name_index = 2 if 'cmt_run_id' in cmt_options[option] else 0 - st.config[option] = (cmt_options[option][_name_index] + '_constant', - cmt_option_overwrite_proc[option]) - del(_name_index) + + if isinstance(cmt_options[option], str): + # URLConfig options can just be set to constants, no hacks needed + # But for now lets keep things consistent for people + st.config[option] = cmt_option_overwrite_proc[option] + else: + # CMT name allowed to be different from the config name + # WFSim needs the cmt name + cmt_name = cmt_options_full[option]['correction'] + st.config[option] = (cmt_name + '_constant', + cmt_option_overwrite_proc[option]) # Only for simulations st.set_config({"event_info_function": "disabled"}) @@ -393,7 +436,6 @@ def xenonnt_simulation( # XENON1T ## - x1t_context_config = { **common_opts, **dict( @@ -445,16 +487,22 @@ def xenonnt_simulation( # Peaks # Smaller right extension since we applied the filter peak_right_extension=30, - s1_max_rise_time=60, s1_max_rise_time_post100=150, s1_min_coincidence=3, # Events* left_event_extension=int(0.3e6), right_event_extension=int(1e6), - elife_conf=('elife_xenon1t', 'v1', False), + elife=1e6, electron_drift_velocity=("electron_drift_velocity_constant", 1.3325e-4), max_drift_length=96.9, electron_drift_time_gate=("electron_drift_time_gate_constant", 1700), + se_gain=28.2, + avg_se_gain=28.2, + rel_extraction_eff=1.0, + s1_xyz_map=f'itp_map://resource://{pax_file("XENON1T_s1_xyz_lce_true_kr83m_SR1_pax-680_fdc-3d_v0.json")}?fmt=json', # noqa + s2_xy_map=f'itp_map://resource://{pax_file("XENON1T_s2_xy_ly_SR1_v2.2.json")}?fmt=json', + g1=0.1426, + g2=11.55/(1 - 0.63), ) @@ -476,8 +524,13 @@ def demo(): st.set_config(dict( hev_gain_model=('1T_to_pe_placeholder', False), gain_model=('1T_to_pe_placeholder', False), - elife_conf=('elife_constant', 1e6), + elife=1e6, electron_drift_velocity=("electron_drift_velocity_constant", 1.3325e-4), + se_gain=28.2, + avg_se_gain=28.2, + rel_extraction_eff=1.0, + s1_xyz_map=f'itp_map://resource://{pax_file("XENON1T_s1_xyz_lce_true_kr83m_SR1_pax-680_fdc-3d_v0.json")}?fmt=json', + s2_xy_map=f'itp_map://resource://{pax_file("XENON1T_s2_xy_ly_SR1_v2.2.json")}?fmt=json', )) return st @@ -529,7 +582,10 @@ def xenon1t_dali(output_folder='./strax_data', 
build_lowlevel=False, **kwargs): def xenon1t_led(**kwargs): st = xenon1t_dali(**kwargs) - st.context_config['check_available'] = ('raw_records', 'led_calibration') + st.set_context_config( + {'check_available': ('raw_records', 'led_calibration'), + 'free_options': list(x1t_context_config.keys()) + }) # Return a new context with only raw_records and led_calibration registered st = st.new_context( replace=True, @@ -550,4 +606,5 @@ def xenon1t_simulation(output_folder='./strax_data'): **x1t_common_config), **x1t_context_config) st.register(wfsim.RawRecordsFromFax1T) + st.deregister_plugins_with_missing_dependencies() return st diff --git a/straxen/corrections_services.py b/straxen/corrections_services.py index 3c844c4ad..b83cbcea9 100644 --- a/straxen/corrections_services.py +++ b/straxen/corrections_services.py @@ -1,7 +1,6 @@ """Return corrections from corrections DB """ import warnings - import pytz import numpy as np from functools import lru_cache @@ -9,31 +8,36 @@ import utilix import straxen import os -from immutabledict import immutabledict - +from urllib.parse import urlparse, parse_qs export, __all__ = strax.exporter() -corrections_w_file = ['mlp_model', 'gcn_model', 'cnn_model', - 's2_xy_map', 's1_xyz_map_mlp', 's1_xyz_map_cnn', - 's1_xyz_map_gcn', 'fdc_map_mlp', 'fdc_map_gcn', - 'fdc_map_cnn'] +corrections_w_file = ['mlp_model', 'cnn_model', 'gcn_model', + 's2_xy_map_mlp', 's2_xy_map_cnn', 's2_xy_map_gcn', 's2_xy_map', + 's1_xyz_map_mlp', 's1_xyz_map_cnn', 's1_xyz_map_gcn', + 'fdc_map_mlp', 'fdc_map_cnn', 'fdc_map_gcn'] single_value_corrections = ['elife_xenon1t', 'elife', 'baseline_samples_nv', - 'electron_drift_velocity', 'electron_drift_time_gate'] + 'electron_drift_velocity', 'electron_drift_time_gate', + 'se_gain', 'rel_extraction_eff'] arrays_corrections = ['hit_thresholds_tpc', 'hit_thresholds_he', 'hit_thresholds_nv', 'hit_thresholds_mv'] # needed because we pass these names as strax options which then get paired with the default reconstruction algorithm # important for apply_cmt_version -posrec_corrections_basenames = ['s1_xyz_map', 'fdc_map'] +posrec_corrections_basenames = ['s1_xyz_map', 'fdc_map', 's2_xy_map'] +@export class CMTVersionError(Exception): pass +class CMTnanValueError(Exception): + pass + + @export class CorrectionsManagementServices(): """ @@ -104,7 +108,6 @@ def get_corrections_config(self, run_id, config_model=None): f"available {single_value_corrections}, {arrays_corrections} and " f"{corrections_w_file} ") - # TODO add option to extract 'when'. Also, the start time might not be the best # entry for e.g. 
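A rough usage sketch of how the single-value corrections listed above (now including se_gain and rel_extraction_eff) are typically retrieved through the public accessor, using the (name, version, nT) tuple form this module documents; the run id and version are placeholders and runDB access via utilix is assumed:

    import straxen

    run_id = '012345'  # placeholder run number
    # tuple form: (correction name, CMT version, nT flag); requires runDB access
    elife = straxen.get_correction_from_cmt(run_id, ('elife', 'ONLINE', True))
    se_gain = straxen.get_correction_from_cmt(run_id, ('se_gain', 'ONLINE', True))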
for super runs # cache results, this would help when looking at the same gains @lru_cache(maxsize=None) @@ -130,14 +133,22 @@ def _get_correction(self, run_id, correction, version): pmts = list(gains.keys()) for it_correction in pmts: # loop over all PMTs if correction in it_correction: - df = self.interface.read(it_correction) + df = self.interface.read_at(it_correction, when) + if df[version].isnull().values.any(): + raise CMTnanValueError(f"For {it_correction} there are NaN values, this means no correction available " + f"for {run_id} in version {version}, please check e-logbook for more info ") + if version in 'ONLINE': df = self.interface.interpolate(df, when, how='fill') else: df = self.interface.interpolate(df, when) values.append(df.loc[df.index == when, version].values[0]) else: - df = self.interface.read(correction) + df = self.interface.read_at(correction, when) + if df[version].isnull().values.any(): + raise CMTnanValueError(f"For {correction} there are NaN values, this means no correction available " + f"for {run_id} in version {version}, please check e-logbook for more info ") + if correction in corrections_w_file or correction in arrays_corrections or version in 'ONLINE': df = self.interface.interpolate(df, when, how='fill') else: @@ -189,12 +200,10 @@ def get_pmt_gains(self, run_id, model_type, version, to_pe = self._get_correction(run_id, target_detector, version) # be cautious with very early runs, check that not all are None - if np.isnan(to_pe).all(): + if np.isnan(to_pe).any(): raise ValueError( f"to_pe(PMT gains) values are NaN, no data available " - f"for {run_id} in the gain model with version " - f"{version}, please set constant values for " - f"{run_id}") + f"for {run_id} in the gain model with version") else: raise ValueError(f"{model_type} not implemented for to_pe values") @@ -241,7 +250,6 @@ def get_config_from_cmt(self, run_id, model_type, version='ONLINE'): f"Please contact CMT manager and yell at him") return file_name - # TODO change to st.estimate_start_time def get_start_time(self, run_id): """ Smart logic to return start time from runsDB @@ -319,8 +327,13 @@ def get_cmt_local_versions(global_version): return cmt.get_local_versions(global_version) +def args_idx(x): + """Get the idx of "?" in the string""" + return x.rfind('?') if '?' 
in x else None + + @strax.Context.add_method -def apply_cmt_version(context: strax.Context, cmt_global_version: str): +def apply_cmt_version(context: strax.Context, cmt_global_version: str) -> None: """Sets all the relevant correction variables :param cmt_global_version: A specific CMT global version, or 'latest' to get the newest one :returns None @@ -342,20 +355,53 @@ def apply_cmt_version(context: strax.Context, cmt_global_version: str): # we want this error to occur in order to keep fixed global versions cmt_config = dict() failed_keys = [] - for option, tup in cmt_options.items(): - try: - # might need to modify correction name to include position reconstruction algo - correction_name = tup[0] - if correction_name in posrec_corrections_basenames: - correction_name += f"_{posrec_algo}" - new_tup = (tup[0], local_versions[correction_name], tup[2]) - except KeyError: - failed_keys.append(option) + + for option, option_info in cmt_options.items(): + # name of the CMT correction, this is not always equal to the strax option + correction_name = option_info['correction'] + # actual config option + # this could be either a CMT tuple or a URLConfig + value = option_info['strax_option'] + + # might need to modify correction name to include position reconstruction algo + # this is a bit of a mess, but posrec configs are treated differently in the tuples + # URL configs should already include the posrec suffix + # (it's real mess -- we should drop tuple configs) + if correction_name in posrec_corrections_basenames: + correction_name += f"_{posrec_algo}" + + # now see if our correction is in our local_versions dict + if correction_name in local_versions: + if isinstance(value, str) and 'cmt://' in value: + new_value = replace_url_version(value, local_versions[correction_name]) + # if it is a tuple, make a new tuple + else: + new_value = (value[0], local_versions[correction_name], value[2]) + else: + if correction_name not in failed_keys: + failed_keys.append(correction_name) continue - cmt_config[option] = new_tup + + cmt_config[option] = new_value + if len(failed_keys): failed_keys = ', '.join(failed_keys) - raise CMTVersionError(f"CMT version {cmt_global_version} is not compatible with this straxen version! " - f"CMT {cmt_global_version} is missing these corrections: {failed_keys}") + msg = f"CMT version {cmt_global_version} is not compatible with this straxen version! 
" \ + f"CMT {cmt_global_version} is missing these corrections: {failed_keys}" + + # only raise a warning if we are working with the online context + if cmt_global_version == "global_ONLINE": + warnings.warn(msg, UserWarning) + else: + raise CMTVersionError(msg) context.set_config(cmt_config) + + +def replace_url_version(url, version): + """Replace the local version of a correction in a CMT config""" + kwargs = {k: v[0] for k, v in parse_qs(urlparse(url).query).items()} + kwargs['version'] = version + args = [f"{k}={v}" for k, v in kwargs.items()] + args_str = "&".join(args) + return f'{url[:args_idx(url)]}?{args_str}' diff --git a/straxen/get_corrections.py b/straxen/get_corrections.py index b12c5aa97..e54e1874b 100644 --- a/straxen/get_corrections.py +++ b/straxen/get_corrections.py @@ -1,7 +1,7 @@ import numpy as np import strax import straxen -from warnings import warn +import typing as ty from functools import wraps from straxen.corrections_services import corrections_w_file from straxen.corrections_services import single_value_corrections @@ -63,7 +63,7 @@ def get_correction_from_cmt(run_id, conf): where True means looking at nT runs, e.g. get_correction_from_cmt(run_id, conf[:2]) special cases: - version can be replaced by consant int, float or array + version can be replaced by constant int, float or array when user specify value(s) :param run_id: run id from runDB :param conf: configuration @@ -131,6 +131,7 @@ def get_cmt_resource(run_id, conf, fmt=''): return straxen.get_resource(get_correction_from_cmt(run_id, conf), fmt=fmt) +@export def is_cmt_option(config): """ Check if the input configuration is cmt style. @@ -140,6 +141,9 @@ def is_cmt_option(config): @correction_options def _is_cmt_option(run_id, config): + # Compatibilty with URLConfig + if isinstance(config, str) and "cmt://" in config: + return True is_cmt = (isinstance(config, tuple) and len(config)==3 and isinstance(config[0], str) @@ -149,23 +153,54 @@ def _is_cmt_option(run_id, config): return is_cmt -def get_cmt_options(context): +def get_cmt_options(context: strax.Context) -> ty.Dict[str, ty.Dict[str, tuple]]: """ Function which loops over all plugin configs and returns dictionary - with option name as key and current settings as values. + with option name as key and a nested dict of CMT correction name and strax option as values. :param context: Context with registered plugins. """ + cmt_options = {} + runid_test_str = 'norunids!' + for data_type, plugin in context._plugin_class_registry.items(): for option_key, option in plugin.takes_config.items(): + if option_key in cmt_options: + # let's not do work twice if needed by > 1 plugin + continue + if (option_key in context.config and - straxen.get_corrections.is_cmt_option(context.config[option_key]) - ): - cmt_options[option_key] = context.config[option_key] - elif straxen.get_corrections.is_cmt_option(option.default): - cmt_options[option_key] = option.default + is_cmt_option(context.config[option_key])): + opt = context.config[option_key] + elif is_cmt_option(option.default): + opt = option.default + else: + continue + + # check if it's a URLConfig + if isinstance(opt, str) and 'cmt://' in opt: + before_cmt, cmt, after_cmt = opt.partition('cmt://') + p = context.get_single_plugin(runid_test_str, data_type) + p.config[option_key] = after_cmt + correction_name = getattr(p, option_key) + # make sure the correction name does not depend on runid + if runid_test_str in correction_name: + raise RuntimeError("Correction names should not depend on runids! 
" + f"Please check your option for {option_key}") + + # if there is no other protocol being called before cmt, + # we will get a string back including the query part + if option.QUERY_SEP in correction_name: + correction_name, _ = option.split_url_kwargs(correction_name) + cmt_options[option_key] = {'correction': correction_name, + 'strax_option': opt, + } + else: + cmt_options[option_key] = {'correction': opt[0], + 'strax_option': opt, + } return cmt_options diff --git a/straxen/itp_map.py b/straxen/itp_map.py index 487c1fe3d..0876993ef 100644 --- a/straxen/itp_map.py +++ b/straxen/itp_map.py @@ -6,7 +6,7 @@ import numpy as np from scipy.spatial import cKDTree from scipy.interpolate import RectBivariateSpline, RegularGridInterpolator - +import straxen import strax export, __all__ = strax.exporter() @@ -147,7 +147,10 @@ def __init__(self, data, method='WeightedNearestNeighbors', **kwargs): for map_name in self.map_names: # Specify dtype float to set Nones to nan map_data = np.array(self.data[map_name], dtype=np.float) - array_valued = len(map_data.shape) == self.dimensions + 1 + if len(self.coordinate_system) == len(map_data): + array_valued = len(map_data.shape) == 2 + else: + array_valued = len(map_data.shape) == self.dimensions + 1 if self.dimensions == 0: # 0 D -- placeholder maps which take no arguments @@ -185,6 +188,7 @@ def _rect_bivariate_spline(csys, map_data, array_valued, **kwargs): assert dimensions == 2, 'RectBivariateSpline interpolate maps of dimension 2' assert not array_valued, 'RectBivariateSpline does not support interpolating array values' map_data = map_data.reshape(*grid_shape) + kwargs = straxen.filter_kwargs(RectBivariateSpline, kwargs) rbs = RectBivariateSpline(grid[0], grid[1], map_data, **kwargs) def arg_formated_rbs(positions): @@ -206,14 +210,18 @@ def _regular_grid_interpolator(csys, map_data, array_valued, **kwargs): map_data = map_data.reshape(*grid_shape) config = dict(bounds_error=False, fill_value=None) + kwargs = straxen.filter_kwargs(RegularGridInterpolator, kwargs) config.update(kwargs) + return RegularGridInterpolator(tuple(grid), map_data, **config) @staticmethod def _weighted_nearest_neighbors(csys, map_data, array_valued, **kwargs): if array_valued: map_data = map_data.reshape((-1, map_data.shape[-1])) - + else: + map_data = map_data.flatten() + kwargs = straxen.filter_kwargs(InterpolateAndExtrapolate, kwargs) return InterpolateAndExtrapolate(csys, map_data, array_valued=array_valued, **kwargs) def scale_coordinates(self, scaling_factor, map_name='map'): @@ -235,7 +243,10 @@ def scale_coordinates(self, scaling_factor, map_name='map'): alt_csys[i] = [gc * k for (gc, k) in zip(gp, self._sf)] map_data = np.array(self.data[map_name]) - array_valued = len(map_data.shape) == self.dimensions + 1 + if len(self.coordinate_system) == len(map_data): + array_valued = len(map_data.shape) == 2 + else: + array_valued = len(map_data.shape) == self.dimensions + 1 if array_valued: map_data = map_data.reshape((-1, map_data.shape[-1])) itp_fun = InterpolateAndExtrapolate(points=np.array(alt_csys), diff --git a/straxen/matplotlib_utils.py b/straxen/matplotlib_utils.py index 390ce9a2c..122fbc767 100644 --- a/straxen/matplotlib_utils.py +++ b/straxen/matplotlib_utils.py @@ -192,6 +192,7 @@ def logticks(tmin, tmax=None, tick_at=None): @export def quiet_tight_layout(): + warnings.warn('Don\'t use quiet_tight_layout it will be removed in a future release') with warnings.catch_warnings(): warnings.simplefilter("ignore") plt.tight_layout() @@ -203,6 +204,7 @@ def draw_box(x, 
y, **kwargs): plt.gca().add_patch(matplotlib.patches.Rectangle( (x[0], y[0]), x[1] - x[0], y[1] - y[0], facecolor='none', **kwargs)) + @export def plot_single_pulse(records, run_id, pulse_i=''): """ diff --git a/straxen/mini_analysis.py b/straxen/mini_analysis.py index 21a5e4ddb..7e9fef1be 100644 --- a/straxen/mini_analysis.py +++ b/straxen/mini_analysis.py @@ -50,6 +50,9 @@ def wrapped_f(context: strax.Context, run_id: str, **kwargs): # Say magic words to enable holoviews if hv_bokeh: + # Generally using globals is not great, but it would be + # the same as doing a slow import on the top of this file + # pylint: disable=global-statement global _hv_bokeh_initialized if not _hv_bokeh_initialized: import holoviews @@ -94,7 +97,7 @@ def wrapped_f(context: strax.Context, run_id: str, **kwargs): for dkind, dtypes in deps_by_kind.items(): if dkind in kwargs: # Already have data, just apply cuts - kwargs[dkind] = context.apply_selection( + kwargs[dkind] = strax.apply_selection( kwargs[dkind], selection_str=kwargs['selection_str'], time_range=kwargs['time_range'], diff --git a/straxen/misc.py b/straxen/misc.py index 6e41ef0be..590a0a3bc 100644 --- a/straxen/misc.py +++ b/straxen/misc.py @@ -2,14 +2,24 @@ import pandas as pd import socket import strax +import inspect import straxen import sys import warnings import datetime import pytz -from os import environ as os_environ +from sys import getsizeof, stderr +from itertools import chain +from collections import OrderedDict, deque from importlib import import_module +from git import Repo, InvalidGitRepositoryError from configparser import NoSectionError +import typing as ty +try: + # pylint: disable=redefined-builtin + from reprlib import repr +except ImportError: + pass export, __all__ = strax.exporter() @@ -41,7 +51,9 @@ def do_round(x): @export -def print_versions(modules=('strax', 'straxen', 'cutax'), return_string=False): +def print_versions(modules=('strax', 'straxen', 'cutax'), + return_string=False, + include_git=True): """ Print versions of modules installed. @@ -50,6 +62,8 @@ def print_versions(modules=('strax', 'straxen', 'cutax'), return_string=False): 'cutax', 'pema')) :param return_string: optional. 
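A minimal sketch of the extended print_versions call described here; the module list is illustrative and gitpython must be importable for the git details:

    import straxen

    # Print versions, install paths and, where a module lives in a git repo,
    # the active branch plus short commit hash
    straxen.print_versions(modules=('strax', 'straxen'), include_git=True)

    # Or collect the same report as a string instead of printing it
    msg = straxen.print_versions(return_string=True)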
Instead of printing the message, return a string + :param include_git: Include the current branch and latest + commit hash :return: optional, the message that would have been printed """ message = (f'Working on {socket.getfqdn()} with the following ' @@ -59,29 +73,65 @@ def print_versions(modules=('strax', 'straxen', 'cutax'), return_string=False): for m in strax.to_str_tuple(modules): try: mod = import_module(m) - message += f'\n{m}' - if hasattr(mod, '__version__'): - message += f'\tv{mod.__version__}' - if hasattr(mod, '__path__'): - message += f'\t{mod.__path__[0]}' except (ModuleNotFoundError, ImportError): print(f'{m} is not installed') + continue + + message += f'\n{m}' + if hasattr(mod, '__version__'): + message += f'\tv{mod.__version__}' + if hasattr(mod, '__path__'): + module_path = mod.__path__[0] + message += f'\t{module_path}' + if include_git: + try: + repo = Repo(module_path, search_parent_directories=True) + except InvalidGitRepositoryError: + # not a git repo + pass + else: + try: + branch = repo.active_branch + except TypeError: + branch = 'unknown' + try: + commit_hash = repo.head.object.hexsha + except TypeError: + commit_hash = 'unknown' + message += f'\tgit branch:{branch} | {commit_hash[:7]}' if return_string: return message print(message) @export -def utilix_is_configured(header='RunDB', section='xent_database') -> bool: +def utilix_is_configured(header: str = 'RunDB', + section: str = 'xent_database', + warning_message: ty.Union[None, bool, str] = None, + ) -> bool: """ Check if we have the right connection to :return: bool, can we connect to the Mongo database? + + :param header: Which header to check in the utilix config file + :param section: Which entry in the header to check to exist + :param warning_message: If utilix is not configured, warn the user. + if None -> generic warning + if str -> use the string to warn + if False -> don't warn """ try: - return (hasattr(straxen.uconfig, 'get') and - straxen.uconfig.get(header, section) is not None) + is_configured = (hasattr(straxen.uconfig, 'get') and + straxen.uconfig.get(header, section) is not None) except NoSectionError: - return False + is_configured = False + + should_report = bool(warning_message) or warning_message is None + if not is_configured and should_report: + if warning_message is None: + warning_message = 'Utilix is not configured, cannot proceed' + warnings.warn(warning_message) + return is_configured @export @@ -203,6 +253,27 @@ def _convert_to_datetime(time_widget, time_zone): return time, time_ns +@strax.Context.add_method +def extract_latest_comment(self): + """ + Extract the latest comment in the runs-database. 
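A short sketch of the guarded-access pattern enabled by the new warning_message argument; the warning text is only an example:

    import straxen

    # Returns False (after emitting the given warning) when no utilix config is found
    if straxen.utilix_is_configured(
            warning_message='No runDB access; falling back to defaults'):
        print('Database access available')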
This just adds info to st.runs + + Example: + st.extract_latest_comment() + st.select_runs(available=('raw_records')) + """ + if self.runs is None or 'comments' not in self.runs.keys(): + self.scan_runs(store_fields=('comments',)) + latest_comments = _parse_to_last_comment(self.runs['comments']) + self.runs['comments'] = latest_comments + return self.runs + + +def _parse_to_last_comment(comments): + """Unpack to get the last comment (hence the -1) or give '' when there is none""" + return [(c[-1]['comment'] if hasattr(c, '__len__') else '') for c in comments] + + @export def convert_array_to_df(array: np.ndarray) -> pd.DataFrame: """ @@ -217,6 +288,84 @@ def convert_array_to_df(array: np.ndarray) -> pd.DataFrame: @export -def _is_on_pytest(): - """Check if we are on a pytest""" - return 'PYTEST_CURRENT_TEST' in os_environ +def filter_kwargs(func, kwargs): + """Filter out keyword arguments that + are not in the call signature of func + and return filtered kwargs dictionary + """ + params = inspect.signature(func).parameters + if any([str(p).startswith('**') for p in params.values()]): + # if func accepts wildcard kwargs, return all + return kwargs + return {k: v for k, v in kwargs.items() if k in params} + + +@export +class CacheDict(OrderedDict): + """Dict with a limited length, ejecting LRUs as needed. + copied from + https://gist.github.com/davesteele/44793cd0348f59f8fadd49d7799bd306 + """ + + def __init__(self, *args, cache_len: int = 10, **kwargs): + assert cache_len > 0 + self.cache_len = cache_len + + super().__init__(*args, **kwargs) + + def __setitem__(self, key, value): + super().__setitem__(key, value) + super().move_to_end(key) + + while len(self) > self.cache_len: + oldkey = next(iter(self)) + super().__delitem__(oldkey) + + def __getitem__(self, key): + val = super().__getitem__(key) + super().move_to_end(key) + + return val + +@export +def total_size(o, handlers=None, verbose=False): + """ Returns the approximate memory footprint an object and all of its contents. + + Automatically finds the contents of the following builtin containers and + their subclasses: tuple, list, deque, dict, set and frozenset. 
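The filter_kwargs and CacheDict helpers added above can be exercised as follows (toy values, assuming both are exported at package level as usual):

    from scipy.interpolate import RectBivariateSpline
    import straxen

    mixed = {'kx': 3, 'ky': 3, 'not_a_spline_argument': 42}
    # Only keys accepted by RectBivariateSpline's signature survive the filtering
    print(straxen.filter_kwargs(RectBivariateSpline, mixed))  # {'kx': 3, 'ky': 3}

    cache = straxen.CacheDict(cache_len=2)
    cache['a'] = 1
    cache['b'] = 2
    cache['c'] = 3            # evicts 'a', the least recently used entry
    print(list(cache))        # ['b', 'c']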
+ To search other containers, add handlers to iterate over their contents: + + handlers = {SomeContainerClass: iter, + OtherContainerClass: OtherContainerClass.get_elements} + + from: https://code.activestate.com/recipes/577504/ + """ + dict_handler = lambda d: chain.from_iterable(d.items()) + all_handlers = {tuple: iter, + list: iter, + deque: iter, + dict: dict_handler, + set: iter, + frozenset: iter, + } + if handlers is not None: + all_handlers.update(handlers) # user handlers take precedence + seen = set() # track which object id's have already been seen + default_size = getsizeof(0) # estimate sizeof object without __sizeof__ + + def sizeof(o): + if id(o) in seen: # do not double count the same object + return 0 + seen.add(id(o)) + s = getsizeof(o, default_size) + + if verbose: + print(s, type(o), repr(o), file=stderr) + + for typ, handler in all_handlers.items(): + if isinstance(o, typ): + s += sum(map(sizeof, handler(o))) + break + return s + + return sizeof(o) diff --git a/straxen/mongo_storage.py b/straxen/mongo_storage.py index f0e3683f6..7891277c0 100644 --- a/straxen/mongo_storage.py +++ b/straxen/mongo_storage.py @@ -1,393 +1,397 @@ -import os -import tempfile -from datetime import datetime -from warnings import warn -import pytz -from strax import exporter, to_str_tuple -import gridfs -from tqdm import tqdm -from shutil import move -import hashlib -from pymongo.collection import Collection as pymongo_collection -import utilix -from straxen import uconfig - -export, __all__ = exporter() - - -@export -class GridFsInterface: - """ - Base class to upload/download the files to a database using GridFS - for PyMongo: - https://pymongo.readthedocs.io/en/stable/api/gridfs/index.html#module-gridfs - - This class does the basic shared initiation of the downloader and - uploader classes. - - """ - - def __init__(self, - readonly=True, - file_database='files', - config_identifier='config_name', - collection=None, - ): - """ - GridFsInterface - - :param readonly: bool, can one read or also write to the - database. - :param file_database: str, name of the database. Default should - not be changed. - :param config_identifier: str, header of the files that are - saved in Gridfs - :param collection: pymongo.collection.Collection, (Optional) - PyMongo DataName Collection to bypass normal initiation - using utilix. Should be an object of the form: - pymongo.MongoClient(..).DATABASE_NAME.COLLECTION_NAME - """ - if collection is None: - if not readonly: - # We want admin access to start writing data! - mongo_url = uconfig.get('rundb_admin', 'mongo_rdb_url') - mongo_user = uconfig.get('rundb_admin', 'mongo_rdb_username') - mongo_password = uconfig.get('rundb_admin', 'mongo_rdb_password') - else: - # We can safely use the Utilix defaults - mongo_url = mongo_user = mongo_password = None - - # If no collection arg is passed, it defaults to the 'files' - # collection, see for more details: - # https://github.com/XENONnT/utilix/blob/master/utilix/rundb.py - mongo_kwargs = { - 'url': mongo_url, - 'user': mongo_user, - 'password': mongo_password, - 'database': file_database, - } - # We can safely hard-code the collection as that is always - # the same with GridFS. - collection = utilix.rundb.xent_collection( - **mongo_kwargs, - collection='fs.files') - else: - # Check the user input is fine for what we want to do. - if not isinstance(collection, pymongo_collection): - raise ValueError('Provide PyMongo collection (see docstring)!') - assert file_database is None, "Already provided a collection!" 
- - # Set collection and make sure it can at least do a 'find' operation - self.collection = collection - self.test_find() - - # This is the identifier under which we store the files. - self.config_identifier = config_identifier - - # The GridFS used in this database - self.grid_fs = gridfs.GridFS(collection.database) - - def get_query_config(self, config): - """ - Generate identifier to query against. This is just the configs - name. - - :param config: str, name of the file of interest - :return: dict, that can be used in queries - """ - return {self.config_identifier: config} - - def document_format(self, config): - """ - Format of the document to upload - - :param config: str, name of the file of interest - :return: dict, that will be used to add the document - """ - doc = self.get_query_config(config) - doc.update({ - 'added': datetime.now(tz=pytz.utc), - }) - return doc - - def config_exists(self, config): - """ - Quick check if this config is already saved in the collection - - :param config: str, name of the file of interest - :return: bool, is this config name stored in the database - """ - query = self.get_query_config(config) - return self.collection.count_documents(query) > 0 - - def md5_stored(self, abs_path): - """ - NB: RAM intensive operation! - Carefully compare if the MD5 identifier is the same as the file - as stored under abs_path. - - :param abs_path: str, absolute path to the file name - :return: bool, returns if the exact same file is already stored - in the database - - """ - if not os.path.exists(abs_path): - # A file that does not exist does not have the same MD5 - return False - query = {'md5': self.compute_md5(abs_path)} - return self.collection.count_documents(query) > 0 - - def test_find(self): - """ - Test the connection to the self.collection to see if we can - perform a collection.find operation. - """ - if self.collection.find_one(projection="_id") is None: - raise ConnectionError('Could not find any data in this collection') - - def list_files(self): - """ - Get a complete list of files that are stored in the database - - :return: list, list of the names of the items stored in this - database - - """ - return [doc[self.config_identifier] - for doc in - self.collection.find( - projection= - {self.config_identifier: 1}) - if self.config_identifier in doc - ] - - @staticmethod - def compute_md5(abs_path): - """ - NB: RAM intensive operation! - Get the md5 hash of a file stored under abs_path - - :param abs_path: str, absolute path to a file - :return: str, the md5-hash of the requested file - """ - # This function is copied from: - # stackoverflow.com/questions/3431825/generating-an-md5-checksum-of-a-file - - if not os.path.exists(abs_path): - # if there is no file, there is nothing to compute - return "" - # Also, disable all the Use of insecure MD2, MD4, MD5, or SHA1 - # hash function violations in this function. - # disable bandit - hash_md5 = hashlib.md5() - with open(abs_path, "rb") as f: - for chunk in iter(lambda: f.read(4096), b""): - hash_md5.update(chunk) - return hash_md5.hexdigest() - - -@export -class MongoUploader(GridFsInterface): - """ - Class to upload files to GridFs - """ - - def __init__(self, readonly=False, *args, **kwargs): - # Same as parent. Just check the readonly_argument - if readonly: - raise PermissionError( - "How can you upload if you want to operate in readonly?") - super().__init__(*args, readonly=readonly, **kwargs) - - def upload_from_dict(self, file_path_dict): - """ - Upload all files in the dictionary to the database. 
- - :param file_path_dict: dict, dictionary of paths to upload. The - dict should be of the format: - file_path_dict = {'config_name': '/the_config_path', ...} - - :return: None - """ - if not isinstance(file_path_dict, dict): - raise ValueError(f'file_path_dict must be dict of form ' - f'"dict(NAME=ABSOLUTE_PATH,...)". Got ' - f'{type(file_path_dict)} instead') - - for config, abs_path in tqdm(file_path_dict.items()): - # We need to do this expensive check here. It is not enough - # to just check that the file is stored under the - # 'config_identifier'. What if the file changed? Then we - # want to upload a new file! Otherwise we could have done - # the self.config_exists-query. If it turns out we have the - # exact same file, forget about uploading it. - if self.config_exists(config) and self.md5_stored(abs_path): - continue - else: - # This means we are going to upload the file because its - # not stored yet. - try: - self.upload_single(config, abs_path) - except (CouldNotLoadError, ConfigTooLargeError): - # Perhaps we should fail then? - warn(f'Cannot upload {config}') - - def upload_single(self, config, abs_path): - """ - Upload a single file to gridfs - - :param config: str, the name under which this file should be - stored - - :param abs_path: str, the absolute path of the file - """ - doc = self.document_format(config) - if not os.path.exists(abs_path): - raise CouldNotLoadError(f'{abs_path} does not exits') - - print(f'uploading {config}') - with open(abs_path, 'rb') as file: - self.grid_fs.put(file, **doc) - - -@export -class MongoDownloader(GridFsInterface): - """ - Class to download files from GridFs - """ - - def __init__(self, - store_files_at=None, - *args, **kwargs): - super().__init__(*args, **kwargs) - - # We are going to set a place where to store the files. It's - # either specified by the user or we use these defaults: - if store_files_at is None: - store_files_at = ('/tmp/straxen_resource_cache/', - '/dali/lgrandi/strax/resource_cache', - './resource_cache', - ) - elif not isinstance(store_files_at, (tuple, str, list)): - raise ValueError(f'{store_files_at} should be tuple of paths!') - elif isinstance(store_files_at, str): - store_files_at = to_str_tuple(store_files_at) - - self.storage_options = store_files_at - - def download_single(self, - config_name: str, - human_readable_file_name=False): - """ - Download the config_name if it exists - - :param config_name: str, the name under which the file is stored - - :param human_readable_file_name: bool, store the file also under - it's human readable name. It is better not to use this as - the user might not know if the version of the file is the - latest. - - :return: str, the absolute path of the file requested - """ - if self.config_exists(config_name): - # Query by name - query = self.get_query_config(config_name) - try: - # This could return multiple since we upload files if - # they have changed again! Therefore just take the last. - fs_object = self.grid_fs.get_last_version(**query) - except gridfs.NoFile as e: - raise CouldNotLoadError( - f'{config_name} cannot be downloaded from GridFs') from e - - # Ok, so we can open it. We will store the file under it's - # md5-hash as that allows to easily compare if we already - # have the correct file. - if human_readable_file_name: - target_file_name = config_name - else: - target_file_name = fs_object.md5 - - for cache_folder in self.storage_options: - possible_path = os.path.join(cache_folder, target_file_name) - if os.path.exists(possible_path): - # Great! 
This already exists. Let's just return - # where it is stored. - return possible_path - - # Apparently the file does not exist, let's find a place to - # store the file and download it. - store_files_at = self._check_store_files_at(self.storage_options) - destination_path = os.path.join(store_files_at, target_file_name) - - # Let's open a temporary directory, download the file, and - # try moving it to the destination_path. This prevents - # simultaneous writes of the same file. - with tempfile.TemporaryDirectory() as temp_directory_name: - temp_path = os.path.join(temp_directory_name, target_file_name) - - with open(temp_path, 'wb') as stored_file: - # This is were we do the actual downloading! - warn(f'Downloading {config_name} to {destination_path}') - stored_file.write(fs_object.read()) - - if not os.path.exists(destination_path): - # Move the file to the place we want to store it. - move(temp_path, destination_path) - return destination_path - - else: - raise ValueError(f'Config {config_name} cannot be downloaded ' - f'since it is not stored') - - def get_abs_path(self, config_name): - return self.download_single(config_name) - - def download_all(self): - """Download all the files that are stored in the mongo collection""" - raise NotImplementedError('This feature is disabled for now') - # Disable the inspection of `Unreachable code` - # pylint: disable=unreachable - for config in self.list_files(): - self.download_single(config) - - @staticmethod - def _check_store_files_at(cache_folder_alternatives): - """ - Iterate over the options in cache_options until we find a folder - where we can store data. Order does matter as we iterate - until we find one folder that is willing. - - :param cache_folder_alternatives: tuple, this tuple must be a - list of paths one can try to store the downloaded data - - :return: str, the folder that we can write to. - """ - if not isinstance(cache_folder_alternatives, (tuple, list)): - raise ValueError('cache_folder_alternatives must be tuple') - for folder in cache_folder_alternatives: - if not os.path.exists(folder): - try: - os.makedirs(folder) - except (PermissionError, OSError): - continue - if os.access(folder, os.W_OK): - return folder - raise PermissionError( - f'Cannot write to any of the cache_folder_alternatives: ' - f'{cache_folder_alternatives}') - - -class CouldNotLoadError(Exception): - """Raise if we cannot load this kind of data""" - # Disable the inspection of 'Unnecessary pass statement' - # pylint: disable=unnecessary-pass - pass - - -class ConfigTooLargeError(Exception): - """Raise if the data is to large to be uploaded into mongo""" - # Disable the inspection of 'Unnecessary pass statement' - # pylint: disable=unnecessary-pass - pass +import os +import tempfile +from datetime import datetime +from warnings import warn +import pytz +from strax import exporter, to_str_tuple +import gridfs +from tqdm import tqdm +from shutil import move +import hashlib +from pymongo.collection import Collection as pymongo_collection +import utilix +from straxen import uconfig + +export, __all__ = exporter() + + +@export +class GridFsInterface: + """ + Base class to upload/download the files to a database using GridFS + for PyMongo: + https://pymongo.readthedocs.io/en/stable/api/gridfs/index.html#module-gridfs + + This class does the basic shared initiation of the downloader and + uploader classes. 
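A sketch of constructing the interface with an explicit collection, as the docstring above describes; the MongoDB URI is an assumption for the example:

    import pymongo
    import straxen

    client = pymongo.MongoClient('mongodb://localhost:27017/')
    downloader = straxen.MongoDownloader(
        collection=client['files']['fs.files'],
        file_database=None,    # must be None when a collection is passed explicitly
        _test_on_init=False,   # skip the find() check, e.g. for an empty test database
    )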
+ + """ + + def __init__(self, + readonly=True, + file_database='files', + config_identifier='config_name', + collection=None, + _test_on_init=False, + ): + """ + GridFsInterface + + :param readonly: bool, can one read or also write to the + database. + :param file_database: str, name of the database. Default should + not be changed. + :param config_identifier: str, header of the files that are + saved in Gridfs + :param collection: pymongo.collection.Collection, (Optional) + PyMongo DataName Collection to bypass normal initiation + using utilix. Should be an object of the form: + pymongo.MongoClient(..).DATABASE_NAME.COLLECTION_NAME + :param _test_on_init: Test if the collection is empty on init + (only deactivate if you are using a brand new database)! + """ + + if collection is None: + if not readonly: + # We want admin access to start writing data! + mongo_url = uconfig.get('rundb_admin', 'mongo_rdb_url') + mongo_user = uconfig.get('rundb_admin', 'mongo_rdb_username') + mongo_password = uconfig.get('rundb_admin', 'mongo_rdb_password') + else: + # We can safely use the Utilix defaults + mongo_url = mongo_user = mongo_password = None + + # If no collection arg is passed, it defaults to the 'files' + # collection, see for more details: + # https://github.com/XENONnT/utilix/blob/master/utilix/rundb.py + mongo_kwargs = { + 'url': mongo_url, + 'user': mongo_user, + 'password': mongo_password, + 'database': file_database, + } + # We can safely hard-code the collection as that is always + # the same with GridFS. + collection = utilix.rundb.xent_collection( + **mongo_kwargs, + collection='fs.files') + else: + # Check the user input is fine for what we want to do. + if not isinstance(collection, pymongo_collection): + raise ValueError('Provide PyMongo collection (see docstring)!') + if file_database is not None: + raise ValueError("Already provided a collection!") + + # Set collection and make sure it can at least do a 'find' operation + self.collection = collection + if _test_on_init: + self.test_find() + + # This is the identifier under which we store the files. + self.config_identifier = config_identifier + + # The GridFS used in this database + self.grid_fs = gridfs.GridFS(collection.database) + + def get_query_config(self, config): + """ + Generate identifier to query against. This is just the configs + name. + + :param config: str, name of the file of interest + :return: dict, that can be used in queries + """ + return {self.config_identifier: config} + + def document_format(self, config): + """ + Format of the document to upload + + :param config: str, name of the file of interest + :return: dict, that will be used to add the document + """ + doc = self.get_query_config(config) + doc.update({ + 'added': datetime.now(tz=pytz.utc), + }) + return doc + + def config_exists(self, config): + """ + Quick check if this config is already saved in the collection + + :param config: str, name of the file of interest + :return: bool, is this config name stored in the database + """ + query = self.get_query_config(config) + return self.collection.count_documents(query) > 0 + + def md5_stored(self, abs_path): + """ + NB: RAM intensive operation! + Carefully compare if the MD5 identifier is the same as the file + as stored under abs_path. 
+ + :param abs_path: str, absolute path to the file name + :return: bool, returns if the exact same file is already stored + in the database + + """ + if not os.path.exists(abs_path): + # A file that does not exist does not have the same MD5 + return False + query = {'md5': self.compute_md5(abs_path)} + return self.collection.count_documents(query) > 0 + + def test_find(self): + """ + Test the connection to the self.collection to see if we can + perform a collection.find operation. + """ + if self.collection.find_one(projection="_id") is None: + raise ConnectionError('Could not find any data in this collection') + + def list_files(self): + """ + Get a complete list of files that are stored in the database + + :return: list, list of the names of the items stored in this + database + + """ + return [doc[self.config_identifier] + for doc in + self.collection.find( + projection= + {self.config_identifier: 1}) + if self.config_identifier in doc + ] + + @staticmethod + def compute_md5(abs_path): + """ + NB: RAM intensive operation! + Get the md5 hash of a file stored under abs_path + + :param abs_path: str, absolute path to a file + :return: str, the md5-hash of the requested file + """ + # This function is copied from: + # stackoverflow.com/questions/3431825/generating-an-md5-checksum-of-a-file + + if not os.path.exists(abs_path): + # if there is no file, there is nothing to compute + return "" + # Also, disable all the Use of insecure MD2, MD4, MD5, or SHA1 + # hash function violations in this function. + # disable bandit + hash_md5 = hashlib.md5() + with open(abs_path, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() + + +@export +class MongoUploader(GridFsInterface): + """ + Class to upload files to GridFs + """ + + def __init__(self, readonly=False, *args, **kwargs): + # Same as parent. Just check the readonly_argument + if readonly: + raise PermissionError( + "How can you upload if you want to operate in readonly?") + super().__init__(*args, readonly=readonly, **kwargs) + + def upload_from_dict(self, file_path_dict): + """ + Upload all files in the dictionary to the database. + + :param file_path_dict: dict, dictionary of paths to upload. The + dict should be of the format: + file_path_dict = {'config_name': '/the_config_path', ...} + + :return: None + """ + if not isinstance(file_path_dict, dict): + raise ValueError(f'file_path_dict must be dict of form ' + f'"dict(NAME=ABSOLUTE_PATH,...)". Got ' + f'{type(file_path_dict)} instead') + + for config, abs_path in tqdm(file_path_dict.items()): + # We need to do this expensive check here. It is not enough + # to just check that the file is stored under the + # 'config_identifier'. What if the file changed? Then we + # want to upload a new file! Otherwise we could have done + # the self.config_exists-query. If it turns out we have the + # exact same file, forget about uploading it. + if self.config_exists(config) and self.md5_stored(abs_path): + continue + else: + # This means we are going to upload the file because its + # not stored yet. + try: + self.upload_single(config, abs_path) + except (CouldNotLoadError, ConfigTooLargeError): + # Perhaps we should fail then? 
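For completeness, a hedged sketch of an upload matching the dictionary format documented above (names and paths are placeholders; rundb_admin credentials are required):

    import straxen

    uploader = straxen.MongoUploader()
    # keys: names under which to store the files, values: absolute paths on disk
    uploader.upload_from_dict({'my_map.json': '/path/to/my_map.json'})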
+ warn(f'Cannot upload {config}') + + def upload_single(self, config, abs_path): + """ + Upload a single file to gridfs + + :param config: str, the name under which this file should be + stored + + :param abs_path: str, the absolute path of the file + """ + doc = self.document_format(config) + doc['md5'] = self.compute_md5(abs_path) + if not os.path.exists(abs_path): + raise CouldNotLoadError(f'{abs_path} does not exits') + + print(f'uploading {config}') + with open(abs_path, 'rb') as file: + self.grid_fs.put(file, **doc) + + +@export +class MongoDownloader(GridFsInterface): + """ + Class to download files from GridFs + """ + + def __init__(self, + store_files_at=None, + *args, **kwargs): + super().__init__(*args, **kwargs) + + # We are going to set a place where to store the files. It's + # either specified by the user or we use these defaults: + if store_files_at is None: + store_files_at = ('/tmp/straxen_resource_cache/', + '/dali/lgrandi/strax/resource_cache', + './resource_cache', + ) + elif not isinstance(store_files_at, (tuple, str, list)): + raise ValueError(f'{store_files_at} should be tuple of paths!') + elif isinstance(store_files_at, str): + store_files_at = to_str_tuple(store_files_at) + + self.storage_options = store_files_at + + def download_single(self, + config_name: str, + human_readable_file_name=False): + """ + Download the config_name if it exists + + :param config_name: str, the name under which the file is stored + + :param human_readable_file_name: bool, store the file also under + it's human readable name. It is better not to use this as + the user might not know if the version of the file is the + latest. + + :return: str, the absolute path of the file requested + """ + if self.config_exists(config_name): + # Query by name + query = self.get_query_config(config_name) + try: + # This could return multiple since we upload files if + # they have changed again! Therefore just take the last. + fs_object = self.grid_fs.get_last_version(**query) + except gridfs.NoFile as e: + raise CouldNotLoadError( + f'{config_name} cannot be downloaded from GridFs') from e + + # Ok, so we can open it. We will store the file under it's + # md5-hash as that allows to easily compare if we already + # have the correct file. + if human_readable_file_name: + target_file_name = config_name + else: + target_file_name = fs_object.md5 + + for cache_folder in self.storage_options: + possible_path = os.path.join(cache_folder, target_file_name) + if os.path.exists(possible_path): + # Great! This already exists. Let's just return + # where it is stored. + return possible_path + + # Apparently the file does not exist, let's find a place to + # store the file and download it. + store_files_at = self._check_store_files_at(self.storage_options) + destination_path = os.path.join(store_files_at, target_file_name) + + # Let's open a temporary directory, download the file, and + # try moving it to the destination_path. This prevents + # simultaneous writes of the same file. + with tempfile.TemporaryDirectory() as temp_directory_name: + temp_path = os.path.join(temp_directory_name, target_file_name) + + with open(temp_path, 'wb') as stored_file: + # This is were we do the actual downloading! + warn(f'Downloading {config_name} to {destination_path}') + stored_file.write(fs_object.read()) + + if not os.path.exists(destination_path): + # Move the file to the place we want to store it. 
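A minimal sketch of the read side; the stored file name is a placeholder:

    import straxen

    downloader = straxen.MongoDownloader()
    # Downloads the file (or reuses a cached copy) and returns its absolute path
    path = downloader.download_single('some_stored_config.json')
    # get_abs_path is a convenience wrapper around download_single
    assert path == downloader.get_abs_path('some_stored_config.json')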
+ move(temp_path, destination_path) + return destination_path + + else: + raise ValueError(f'Config {config_name} cannot be downloaded ' + f'since it is not stored') + + def get_abs_path(self, config_name): + return self.download_single(config_name) + + def download_all(self): + """Download all the files that are stored in the mongo collection""" + for config in self.list_files(): + print(config, self.download_single(config)) + + @staticmethod + def _check_store_files_at(cache_folder_alternatives): + """ + Iterate over the options in cache_options until we find a folder + where we can store data. Order does matter as we iterate + until we find one folder that is willing. + + :param cache_folder_alternatives: tuple, this tuple must be a + list of paths one can try to store the downloaded data + + :return: str, the folder that we can write to. + """ + if not isinstance(cache_folder_alternatives, (tuple, list)): + raise ValueError('cache_folder_alternatives must be tuple') + for folder in cache_folder_alternatives: + if not os.path.exists(folder): + try: + os.makedirs(folder) + except (PermissionError, OSError): + continue + if os.access(folder, os.W_OK): + return folder + raise PermissionError( + f'Cannot write to any of the cache_folder_alternatives: ' + f'{cache_folder_alternatives}') + + +class CouldNotLoadError(Exception): + """Raise if we cannot load this kind of data""" + # Disable the inspection of 'Unnecessary pass statement' + # pylint: disable=unnecessary-pass + pass + + +class ConfigTooLargeError(Exception): + """Raise if the data is to large to be uploaded into mongo""" + # Disable the inspection of 'Unnecessary pass statement' + # pylint: disable=unnecessary-pass + pass diff --git a/straxen/online_monitor.py b/straxen/online_monitor.py index e8e904995..9fd45303f 100644 --- a/straxen/online_monitor.py +++ b/straxen/online_monitor.py @@ -1,55 +1,55 @@ -from strax import MongoFrontend, exporter -from straxen import uconfig - -export, __all__ = exporter() - -default_online_collection = 'online_monitor' - - -@export -def get_mongo_uri(user_key='pymongo_user', - pwd_key='pymongo_password', - url_key='pymongo_url', - header='RunDB'): - user = uconfig.get(header, user_key) - pwd = uconfig.get(header, pwd_key) - url = uconfig.get(header, url_key) - return f"mongodb://{user}:{pwd}@{url}" - - -@export -class OnlineMonitor(MongoFrontend): - """ - Online monitor Frontend for Saving data temporarily to the - database - """ - - def __init__(self, - uri=None, - take_only=None, - database=None, - col_name=default_online_collection, - readonly=True, - *args, **kwargs): - if take_only is None: - raise ValueError(f'Specify which data_types to accept! Otherwise ' - f'the DataBase will be overloaded') - if uri is None and readonly: - uri = get_mongo_uri() - elif uri is None and not readonly: - # 'not readonly' means that you want to write. 
Let's get - # your admin credentials: - uri = get_mongo_uri(header='rundb_admin', - user_key='mongo_rdb_username', - pwd_key='mongo_rdb_password', - url_key='mongo_rdb_url') - - if database is None: - database = uconfig.get('RunDB', 'pymongo_database') - - super().__init__(uri=uri, - database=database, - take_only=take_only, - col_name=col_name, - *args, **kwargs) - self.readonly = readonly +from strax import MongoFrontend, exporter +from straxen import uconfig + +export, __all__ = exporter() + +default_online_collection = 'online_monitor' + + +@export +def get_mongo_uri(user_key='pymongo_user', + pwd_key='pymongo_password', + url_key='pymongo_url', + header='RunDB'): + user = uconfig.get(header, user_key) + pwd = uconfig.get(header, pwd_key) + url = uconfig.get(header, url_key) + return f"mongodb://{user}:{pwd}@{url}" + + +@export +class OnlineMonitor(MongoFrontend): + """ + Online monitor Frontend for Saving data temporarily to the + database + """ + + def __init__(self, + uri=None, + take_only=None, + database=None, + col_name=default_online_collection, + readonly=True, + *args, **kwargs): + if take_only is None: + raise ValueError(f'Specify which data_types to accept! Otherwise ' + f'the DataBase will be overloaded') + if uri is None and readonly: + uri = get_mongo_uri() + elif uri is None and not readonly: + # 'not readonly' means that you want to write. Let's get + # your admin credentials: + uri = get_mongo_uri(header='rundb_admin', + user_key='mongo_rdb_username', + pwd_key='mongo_rdb_password', + url_key='mongo_rdb_url') + + if database is None: + database = uconfig.get('RunDB', 'pymongo_database') + + super().__init__(uri=uri, + database=database, + take_only=take_only, + col_name=col_name, + *args, **kwargs) + self.readonly = readonly diff --git a/straxen/plugins/__init__.py b/straxen/plugins/__init__.py index 934a4907f..dbf37b3e9 100644 --- a/straxen/plugins/__init__.py +++ b/straxen/plugins/__init__.py @@ -28,6 +28,9 @@ from . import event_processing from .event_processing import * +from . import afterpulse_processing +from .afterpulse_processing import * + from . import double_scatter from .double_scatter import * diff --git a/straxen/plugins/acqmon_processing.py b/straxen/plugins/acqmon_processing.py index 8489d51b0..879fddeff 100644 --- a/straxen/plugins/acqmon_processing.py +++ b/straxen/plugins/acqmon_processing.py @@ -15,9 +15,9 @@ @export -@strax.takes_config(strax.Option('hit_min_amplitude_aqmon', default=50, track=True, +@strax.takes_config(strax.Option('hit_min_amplitude_aqmon', default=50, track=True, infer_type=False, help='Minimum hit threshold in ADC*counts above baseline'), - strax.Option('baseline_samples_aqmon', default=10, track=True, + strax.Option('baseline_samples_aqmon', default=10, track=True, infer_type=False, help='Number of samples to use at the start of the pulse to determine the baseline')) class AqmonHits(strax.Plugin): """ Find hits in acquisition monitor data. These hits could be diff --git a/straxen/plugins/afterpulse_processing.py b/straxen/plugins/afterpulse_processing.py new file mode 100644 index 000000000..71ade3a41 --- /dev/null +++ b/straxen/plugins/afterpulse_processing.py @@ -0,0 +1,298 @@ +import numba +import numpy as np +import strax +import straxen +from straxen.get_corrections import is_cmt_option + +export, __all__ = strax.exporter() + + +@export +@strax.takes_config( + strax.Option('gain_model', infer_type=False, + help='PMT gain model. 
Specify as (model_type, model_config)', + ), + strax.Option('n_tpc_pmts', + type=int, + help="Number of PMTs in TPC", + ), + strax.Option('LED_window_left', + default=50, infer_type=False, + help='Left boundary of sample range for LED pulse integration', + ), + strax.Option('LED_window_right', + default=100, infer_type=False, + help='Right boundary of sample range for LED pulse integration', + ), + strax.Option('baseline_samples', + default=40, infer_type=False, + help='Number of samples to use at start of WF to determine the baseline', + ), + strax.Option('hit_min_amplitude', + track=True, infer_type=False, + default=('hit_thresholds_tpc', 'ONLINE', True), + help='Minimum hit amplitude in ADC counts above baseline. ' + 'Specify as a tuple of length n_tpc_pmts, or a number,' + 'or a string like "pmt_commissioning_initial" which means calling' + 'hitfinder_thresholds.py' + 'or a tuple like (correction=str, version=str, nT=boolean),' + 'which means we are using cmt.', + ), + strax.Option('hit_min_height_over_noise', + default=4, infer_type=False, + help='Minimum hit amplitude in numbers of baseline_rms above baseline.' + 'Actual threshold used is max(hit_min_amplitude, hit_min_' + 'height_over_noise * baseline_rms).', + ), + strax.Option('save_outside_hits', + default=(3, 20), infer_type=False, + help='Save (left, right) samples besides hits; cut the rest', + ), +) +class LEDAfterpulseProcessing(strax.Plugin): + __version__ = '0.5.1' + depends_on = 'raw_records' + data_kind = 'afterpulses' + provides = 'afterpulses' + compressor = 'zstd' + parallel = 'process' + rechunk_on_save = True + + def infer_dtype(self): + dtype = dtype_afterpulses() + return dtype + + def setup(self): + self.to_pe = straxen.get_correction_from_cmt(self.run_id, self.config['gain_model']) + self.hit_left_extension, self.hit_right_extension = self.config['save_outside_hits'] + + # Check config of `hit_min_amplitude` and define hit thresholds + # if cmt config + if is_cmt_option(self.config['hit_min_amplitude']): + self.hit_thresholds = straxen.get_correction_from_cmt( + self.run_id, + self.config['hit_min_amplitude']) + # if hitfinder_thresholds config + elif isinstance(self.config['hit_min_amplitude'], str): + self.hit_thresholds = straxen.hit_min_amplitude( + self.config['hit_min_amplitude']) + else: # int or array + self.hit_thresholds = self.config['hit_min_amplitude'] + + def compute(self, raw_records): + + # Convert everything to the records data type -- adds extra fields. 
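A usage sketch for the new afterpulse plugin; the context and run id are placeholders, and the shared options (e.g. gain_model) are assumed to be provided by that context:

    import straxen

    st = straxen.contexts.xenonnt_online()       # any context providing raw_records
    st.register(straxen.LEDAfterpulseProcessing)
    aps = st.get_array('012345', 'afterpulses')  # placeholder run id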
+ records = strax.raw_to_records(raw_records) + del raw_records + + # calculate baseline and baseline rms + strax.baseline(records, + baseline_samples=self.config['baseline_samples'], + flip=True) + + # find all hits + hits = strax.find_hits(records, + min_amplitude=self.hit_thresholds, + min_height_over_noise=self.config['hit_min_height_over_noise'], + ) + + # sort hits by record_i and time, then find LED hit and afterpulse + # hits within the same record + hits_ap = find_ap(hits, + records, + LED_window_left=self.config['LED_window_left'], + LED_window_right=self.config['LED_window_right'], + hit_left_extension=self.hit_left_extension, + hit_right_extension=self.hit_right_extension, + ) + + hits_ap['area_pe'] = hits_ap['area'] * self.to_pe[hits_ap['channel']] + hits_ap['height_pe'] = hits_ap['height'] * self.to_pe[hits_ap['channel']] + + return hits_ap + + +@export +def find_ap(hits, records, LED_window_left, LED_window_right, hit_left_extension, + hit_right_extension): + buffer = np.zeros(len(hits), dtype=dtype_afterpulses()) + + if not len(hits): + return buffer + + # sort hits first by record_i, then by time + hits_sorted = np.sort(hits, order=('record_i', 'time')) + res = _find_ap(hits_sorted, records, LED_window_left, LED_window_right, + hit_left_extension, hit_right_extension, buffer=buffer) + return res + + +@numba.jit(nopython=True, nogil=True, cache=True) +def _find_ap(hits, records, LED_window_left, LED_window_right, hit_left_extension, + hit_right_extension, buffer=None): + # hits need to be sorted by record_i, then time! + offset = 0 + + is_LED = False + t_LED = None + + prev_record_i = hits[0]['record_i'] + record_data = records[prev_record_i]['data'] + record_len = records[prev_record_i]['length'] + baseline_fpart = records[prev_record_i]['baseline'] % 1 + + for h_i, h in enumerate(hits): + + if h['record_i'] > prev_record_i: + # start of a new record + is_LED = False + # only increment buffer if the old one is not empty! 
this happens + # when no (LED) hit is found in the previous record + if not buffer[offset]['time'] == 0: + offset += 1 + prev_record_i = h['record_i'] + record_data = records[prev_record_i]['data'] + baseline_fpart = records[prev_record_i]['baseline'] % 1 + + res = buffer[offset] + + if h['left'] < LED_window_left: + # if hit is before LED window: discard + continue + + if h['left'] < LED_window_right: + # hit is in LED window + if not is_LED: + # this is the first hit in the LED window + fill_hitpars(res, h, hit_left_extension, hit_right_extension, + record_data, record_len, baseline_fpart) + + # set the LED time in the current WF + t_LED = res['sample_10pc_area'] + is_LED = True + + continue + + # more hits in LED window: extend the first (merging all hits in the LED window) + fill_hitpars(res, h, hit_left_extension, hit_right_extension, + record_data, record_len, + baseline_fpart, extend=True) + + t_LED = res['sample_10pc_area'] + + continue + + # Here begins a new hit after the LED window + if (h['left'] >= LED_window_right) and not is_LED: + # no LED hit found: ignore and go to next hit (until new record begins) + continue + + # if a hit is completely inside the previous hit's right_extension, + # then skip it (because it's already included in the previous hit) + if h['right'] <= res['right_integration']: + continue + + # if a hit only partly overlaps with the previous hit's right_ + # extension, merge them (extend previous hit by this one) + if h['left'] <= res['right_integration']: + fill_hitpars(res, h, hit_left_extension, hit_right_extension, record_data, record_len, + baseline_fpart, extend=True) + + res['tdelay'] = res['sample_10pc_area'] - t_LED + + continue + + # an actual new hit increases the buffer index + offset += 1 + res = buffer[offset] + + fill_hitpars(res, h, hit_left_extension, hit_right_extension, record_data, record_len, + baseline_fpart) + + res['tdelay'] = res['sample_10pc_area'] - t_LED + + return buffer[:offset] + + +@export +@numba.jit(nopython=True, nogil=True, cache=True) +def get_sample_area_quantile(data, quantile, baseline_fpart): + """ + returns first sample index in hit where integrated area of hit is above total area + """ + + area = 0 + area_tot = data.sum() + len(data) * baseline_fpart + + for d_i, d in enumerate(data): + area += d + baseline_fpart + if area > (quantile * area_tot): + return d_i + if d_i == len(data) - 1: + # if no quantile was found, something went wrong + # (negative area due to wrong baseline, caused by real events that + # by coincidence fall in the first samples of the trigger window) + # print('no quantile found: set to 0') + return 0 + + # What happened here?! 
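+    # Only reached when `data` is empty, so the loop above never executes; fall back to index 0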
+ return 0 + + +@numba.jit(nopython=True, nogil=True, cache=True) +def fill_hitpars(result, hit, hit_left_extension, hit_right_extension, record_data, record_len, + baseline_fpart, extend=False): + if not extend: # fill first time only + result['time'] = hit['time'] - hit_left_extension * hit['dt'] + result['dt'] = hit['dt'] + result['channel'] = hit['channel'] + result['left'] = hit['left'] + result['record_i'] = hit['record_i'] + result['threshold'] = hit['threshold'] + result['left_integration'] = hit['left'] - hit_left_extension + result['height'] = hit['height'] + + # fill always (if hits are merged, only these will be updated) + result['right'] = hit['right'] + result['right_integration'] = hit['right'] + hit_right_extension + if result['right_integration'] > record_len: + result['right_integration'] = record_len # cap right_integration at end of record + result['length'] = result['right_integration'] - result['left_integration'] + + hit_data = record_data[result['left_integration']:result['right_integration']] + result['area'] = hit_data.sum() + result['length'] * baseline_fpart + result['sample_10pc_area'] = result['left_integration'] + get_sample_area_quantile( + hit_data, 0.1, baseline_fpart) + result['sample_50pc_area'] = result['left_integration'] + get_sample_area_quantile( + hit_data, 0.5, baseline_fpart) + if len(hit_data): + result['max'] = result['left_integration'] + hit_data.argmax() + + if extend: # only when merging hits + result['height'] = max(result['height'], hit['height']) + + +@export +def dtype_afterpulses(): + # define new data type for afterpulse data + dtype_ap = [ + (('Channel/PMT number', 'channel'), 'self.config['s2_min_area_pattern_fit']) - cur_s2_bool &= (events[t_+'_index']!=-1) - cur_s2_bool &= (events[t_+'_area_fraction_top']>0) - cur_s2_bool &= (x**2 + y**2) < self.config['max_r_pattern_fit']**2 + s2_mask = (events[t_+'_area']>self.config['s2_min_area_pattern_fit']) + s2_mask &= (events[t_+'_area_fraction_top']>0) + s2_mask &= (x**2 + y**2) < self.config['max_r_pattern_fit']**2 # default value is nan, it will be ovewrite if the event satisfy the requirments result[t_+'_2llh'][:] = np.nan # Making expectation patterns [ in PE ] - if np.sum(cur_s2_bool): - s2_map_effs = self.s2_pattern_map(np.array([x, y]).T)[cur_s2_bool, 0:self.config['n_top_pmts']] + if np.sum(s2_mask): + s2_map_effs = self.s2_pattern_map(np.array([x, y]).T)[s2_mask, 0:self.config['n_top_pmts']] s2_map_effs = s2_map_effs[:, self.pmtbool_top] - s2_top_area = (events[t_+'_area_fraction_top']*events[t_+'_area'])[cur_s2_bool] + s2_top_area = (events[t_+'_area_fraction_top']*events[t_+'_area'])[s2_mask] s2_pattern = s2_top_area[:, None]*s2_map_effs/np.sum(s2_map_effs, axis=1)[:,None] # Getting pattern from data - s2_top_area_per_channel = events[t_+'_area_per_channel'][cur_s2_bool, 0:self.config['n_top_pmts']] + s2_top_area_per_channel = events[t_+'_area_per_channel'][s2_mask, 0:self.config['n_top_pmts']] s2_top_area_per_channel = s2_top_area_per_channel[:, self.pmtbool_top] # Calculating LLH, this is shifted Poisson @@ -281,17 +319,33 @@ def compute_s2_llhvalue(self, events, result): areas = s2_top_area_per_channel, mean_pe_photon=self.mean_pe_photon) ) - result[t_+'_2llh'][cur_s2_bool] = np.sum(norm_llh_val, axis=1) + result[t_+'_2llh'][s2_mask] = np.sum(norm_llh_val, axis=1) if self.config['store_per_channel']: store_patterns = np.zeros((s2_pattern.shape[0], self.config['n_top_pmts']) ) store_patterns[:, self.pmtbool_top] = s2_pattern - result[t_+'_pattern'][cur_s2_bool] = 
store_patterns#:s2_pattern[cur_s2_bool] + result[t_+'_pattern'][s2_mask] = store_patterns#:s2_pattern[s2_mask] store_2LLH_ch = np.zeros((norm_llh_val.shape[0], self.config['n_top_pmts']) ) store_2LLH_ch[:, self.pmtbool_top] = norm_llh_val - result[t_+'_2llh_per_channel'][cur_s2_bool] = store_2LLH_ch - + result[t_+'_2llh_per_channel'][s2_mask] = store_2LLH_ch + + def compute_s2_neural_llhvalue(self, events, result): + for t_ in ['s2', 'alt_s2']: + x, y = events[t_ + '_x'], events[t_ + '_y'] + s2_mask = (events[t_ + '_area'] > self.config['s2_min_area_pattern_fit']) + s2_mask &= (events[t_ + '_area_fraction_top'] > 0) + + # default value is nan, it will be ovewrite if the event satisfy the requirements + result[t_ + '_neural_2llh'][:] = np.nan + + # Produce position and top pattern to feed tensorflow model, return chi2/N + if np.sum(s2_mask): + s2_pos = np.stack((x, y)).T[s2_mask] + s2_pat = events[t_ + '_area_per_channel'][s2_mask, 0:self.config['n_top_pmts']] + # Output[0]: loss function, -2*log-likelihood, Output[1]: chi2 + result[t_ + '_neural_2llh'][s2_mask] = self.model_chi2.predict({'xx': s2_pos, 'yy': s2_pat})[1] + @staticmethod def _infer_map_format(map_name, known_formats=('pkl', 'json', 'json.gz')): for fmt in known_formats: @@ -325,162 +379,124 @@ def neg2llh_modpoisson(mu=None, areas=None, mean_pe_photon=1.0): # continuous and discrete binomial test -# https://github.com/poliastro/cephes/blob/master/src/bdtr.c -@numba.vectorize([numba.float64(numba.float64, numba.float64, numba.float64)]) -def binom_pmf(k, n, p): +@numba.njit +def lbinom_pmf(k, n, p): + """Log of binomial probability mass function approximated with gamma function""" scale_log = numba_gammaln(n + 1) - numba_gammaln(n - k + 1) - numba_gammaln(k + 1) ret_log = scale_log + k * np.log(p) + (n - k) * np.log(1 - p) - return np.exp(ret_log) + return ret_log + + +@numba.njit +def binom_pmf(k, n, p): + """Binomial probability mass function approximated with gamma function""" + return np.exp(lbinom_pmf(k, n, p)) + @numba.njit def binom_cdf(k, n, p): - if k < 0: - return np.nan - if k == n: + if k >= n: return 1.0 - dn = n - k - if k == 0: - dk = np.exp(dn * np.log(1.0 - p)) - else: - dk = k + 1 - dk = numba_betainc(dn, dk, 1.0 - p) - return dk + return numba_betainc(n - k, k + 1, 1.0 - p) + @numba.njit def binom_sf(k, n, p): - if k < 0: - return 1.0 - if k == n: - return 0.0 - dn = n - k - if k == 0: - if p < .01: - dk = -np.expm1(dn * np.log1p(-p)) - else: - dk = 1.0 - np.exp(dn * np.log(1.0 - p)) - else: - dk = k + 1 - dk = numba_betainc(dk, dn, p) - return dk + return 1 - binom_cdf(k, n, p) + @numba.njit -def _find_interval(n, p): - mu = n*p - sigma = np.sqrt(n*p*(1-p)) - if mu-2*sigma < 0: - s1_min_range = 0 - else: - s1_min_range = mu-2*sigma - if mu+2*sigma > n: - s1_max_range = n +def lbinom_pmf_diriv(k, n, p, dk=1e-7): + """Numerical dirivitive of Binomial pmf approximated with gamma function""" + if k + dk < n: + return (lbinom_pmf(k + dk, n, p) - lbinom_pmf(k, n, p)) / dk else: - s1_max_range = mu+2*sigma - return s1_min_range, s1_max_range + return (lbinom_pmf(k - dk, n, p) - lbinom_pmf(k, n, p)) / - dk + + +@numba.njit(cache=True) +def _numeric_derivative(y0, y1, err, target, x_min, x_max, x0, x1): + """Get close to by doing a numeric derivative""" + if abs(y1 - y0) < err: + # break by passing dx == 0 + return 0., x1, x1 + + x = (target - y0) / (y1 - y0) * (x1 - x0) + x0 + x = min(x, x_max) + x = max(x, x_min) + + dx = abs(x - x1) + x0 = x1 + x1 = x + + return dx, x0, x1 + @numba.njit -def 
_get_min_and_max(s1_min, s1_max, n, p, shift): - s1 = np.arange(s1_min, s1_max, 0.01) - ds1 = binom_pmf(s1, n, p) - s1argmax = np.argmax(ds1) - if np.argmax(ds1) - shift > 0: - minimum = s1[s1argmax - shift] - else: - minimum = s1[0] - maximum = s1[s1argmax + shift] - return minimum, maximum, s1[s1argmax] +def lbinom_pmf_mode(x_min, x_max, target, args, err=1e-7, max_iter=50): + """Find the root of the derivative of log Binomial pmf with secant method""" + x0 = x_min + x1 = x_max + dx = abs(x1 - x0) + + while (dx > err) and (max_iter > 0): + y0 = lbinom_pmf_diriv(x0, *args) + y1 = lbinom_pmf_diriv(x1, *args) + dx, x0, x1 = _numeric_derivative(y0, y1, err, target, x_min, x_max, x0, x1) + max_iter -= 1 + return x1 + + +@numba.njit +def lbinom_pmf_inverse(x_min, x_max, target, args, err=1e-7, max_iter=50): + """Find the where the log Binomial pmf cross target with secant method""" + x0 = x_min + x1 = x_max + dx = abs(x1 - x0) + + while (dx > err) and (max_iter > 0): + y0 = lbinom_pmf(x0, *args) + y1 = lbinom_pmf(x1, *args) + dx, x0, x1 = _numeric_derivative(y0, y1, err, target, x_min, x_max, x0, x1) + max_iter -= 1 + return x1 + @numba.njit def binom_test(k, n, p): """ The main purpose of this algorithm is to find the value j on the - other side of the mean that has the same probability as k, and + other side of the mode that has the same probability as k, and integrate the tails outward from k and j. In the case where either k or j are zero, only the non-zero tail is integrated. """ - # define the S1 interval for finding the maximum - _s1_min_range, _s1_max_range = _find_interval(n, p) - - # compute the binomial probability for each S1 and define the Binom range for later - minimum, maximum, s1max = _get_min_and_max(_s1_min_range, _s1_max_range, n, p, shift=2) - - # comments TODO - _d0 = binom_pmf(0, n, p) - _dmax = binom_pmf(s1max, n, p) - _dn = binom_pmf(n, n, p) - d = binom_pmf(k, n, p) - rerr = 1 + 1e-7 - d = d * rerr - _d0 = _d0 * rerr - _dmax = _dmax * rerr - _dn = _dn * rerr - # define number of interaction for finding the the value j - # the exceptional case of n<=0, is avoid since n_iter is at least 2 - if n > 0: - n_iter = int(np.round_(np.log10(n)) + 1) - n_iter = max(n_iter, 2) - else: - n_iter = 2 - # comments TODO - if k < minimum: - if (_d0 >= d) and (_d0 > _dmax): - n_iter, j_min, j_max = -1, 0, 0 - elif _dn > d: - n_iter, j_min, j_max = -2, 0, 0 - else: - j_min, j_max = s1max, n - def _check_(d, y0, y1): - return (d>y1) and (d<=y0) - # comments TODO - elif k>maximum: - if _d0 >= d: - n_iter, j_min, j_max = -1, 0, 0 - else: - j_min, j_max = 0, s1max - def _check_(d, y0, y1): - return (d>=y0) and (d=y0) and (d 0: + pval += binom_cdf(min(k, j), n, p) + if max(k, j) > 0: + pval += binom_sf(max(k, j), n, p) + pval = min(1.0, pval) + return pval -def _s1_area_fraction_top_probability(aft_prob, area_tot, area_fraction_top, mode='continuous'): - ''' - Wrapper that does the S1 AFT probability calculation for you - ''' - + +@np.vectorize +@numba.njit +def s1_area_fraction_top_probability(aft_prob, area_tot, area_fraction_top, mode='continuous'): + """Function to compute the S1 AFT probability""" area_top = area_tot * area_fraction_top - + # Raise a warning in case one of these three condition is verified # and return binomial test equal to nan since they are not physical # k: size_top, n: size_tot, p: aft_prob @@ -497,11 +513,11 @@ def _s1_area_fraction_top_probability(aft_prob, area_tot, area_fraction_top, mod # warnings.warn(f'k {area_top} must be >= 0') binomial_test = np.nan 
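# Illustrative cross-check (not part of the patch): binom_cdf above relies on the
# identity CDF(k; n, p) = I_{1-p}(n - k, k + 1), with I the regularized incomplete
# beta function. scipy is assumed here purely for the comparison.
import numpy as np
from scipy.special import betainc
from scipy.stats import binom

k, n, p = 3, 10, 0.3
via_beta = betainc(n - k, k + 1, 1.0 - p)    # relation used by binom_cdf
via_scipy = binom.cdf(k, n, p)               # reference value for integer k
print(np.isclose(via_beta, via_scipy))       # -> True
# Unlike scipy's binom, the gamma/beta-function form also accepts non-integer k,
# which is what the continuous S1 area-fraction-top test needs.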
do_test = False - + if do_test: if mode == 'discrete': binomial_test = binom_pmf(area_top, area_tot, aft_prob) else: binomial_test = binom_test(area_top, area_tot, aft_prob) - + return binomial_test diff --git a/straxen/plugins/event_processing.py b/straxen/plugins/event_processing.py index 6d354c8e9..b5dc04b9c 100644 --- a/straxen/plugins/event_processing.py +++ b/straxen/plugins/event_processing.py @@ -2,9 +2,8 @@ import numpy as np import numba import straxen -from warnings import warn -from .position_reconstruction import DEFAULT_POSREC_ALGO_OPTION -from straxen.common import pax_file, get_resource, first_sr1_run, pre_apply_function +from .position_reconstruction import DEFAULT_POSREC_ALGO +from straxen.common import pax_file, get_resource, first_sr1_run from straxen.get_corrections import get_correction_from_cmt, get_cmt_resource, is_cmt_option from straxen.itp_map import InterpolatingMap export, __all__ = strax.exporter() @@ -12,30 +11,34 @@ @export @strax.takes_config( - strax.Option('trigger_min_area', default=100, + strax.Option('trigger_min_area', default=100, type=(int,float), help='Peaks must have more area (PE) than this to ' 'cause events'), - strax.Option('trigger_max_competing', default=7, + strax.Option('trigger_max_competing', default=7, type=int, help='Peaks must have FEWER nearby larger or slightly smaller' ' peaks to cause events'), - strax.Option('left_event_extension', default=int(0.25e6), + strax.Option('left_event_extension', default=int(0.25e6), type=(int, float), help='Extend events this many ns to the left from each ' 'triggering peak. This extension is added to the maximum ' 'drift time.', ), - strax.Option('right_event_extension', default=int(0.25e6), + strax.Option('right_event_extension', default=int(0.25e6), type=(int, float), help='Extend events this many ns to the right from each ' 'triggering peak.', ), - strax.Option(name='electron_drift_velocity', + strax.Option(name='electron_drift_velocity', infer_type=False, default=("electron_drift_velocity", "ONLINE", True), help='Vertical electron drift velocity in cm/ns (1e4 m/ms)', ), strax.Option(name='max_drift_length', - default=straxen.tpc_z, + default=straxen.tpc_z, type=(int, float), help='Total length of the TPC from the bottom of gate to the ' 'top of cathode wires [cm]', ), + strax.Option(name='exclude_s1_as_triggering_peaks', + default=True, type=bool, + help='If true exclude S1s as triggering peaks.', + ), ) class Events(strax.OverlapWindowPlugin): """ @@ -56,7 +59,7 @@ class Events(strax.OverlapWindowPlugin): depends_on = ['peak_basics', 'peak_proximity'] provides = 'events' data_kind = 'events' - __version__ = '0.0.1' + __version__ = '0.1.0' save_when = strax.SaveWhen.NEVER dtype = [ @@ -71,6 +74,10 @@ def setup(self): self.run_id, self.config['electron_drift_velocity']) self.drift_time_max = int(self.config['max_drift_length'] / electron_drift_velocity) + # Left_extension and right_extension should be computed in setup to be + # reflected in cutax too. 
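# Worked numbers (not part of the patch) for the extensions set just below.
# The drift velocity and TPC length here are placeholders; in the plugin they
# come from the 'electron_drift_velocity' (CMT) and 'max_drift_length' options.
left_event_extension = int(0.25e6)       # ns, config default above
right_event_extension = int(0.25e6)      # ns, config default above
max_drift_length = 148.6                 # cm, placeholder for straxen.tpc_z
electron_drift_velocity = 6.75e-5        # cm/ns, placeholder CMT value

drift_time_max = int(max_drift_length / electron_drift_velocity)   # ~2.2e6 ns
left_extension = left_event_extension + drift_time_max             # ~2.45e6 ns
right_extension = right_event_extension                            # 0.25e6 ns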
+ self.left_extension = self.config['left_event_extension'] + self.drift_time_max + self.right_extension = self.config['right_event_extension'] def get_window_size(self): # Take a large window for safety, events can have long tails @@ -79,19 +86,19 @@ def get_window_size(self): + self.config['right_event_extension']) def compute(self, peaks, start, end): - le = self.config['left_event_extension'] + self.drift_time_max - re = self.config['right_event_extension'] + _is_triggering = peaks['area'] > self.config['trigger_min_area'] + _is_triggering &= (peaks['n_competing'] <= self.config['trigger_max_competing']) + if self.config['exclude_s1_as_triggering_peaks']: + _is_triggering &= peaks['type'] == 2 - triggers = peaks[ - (peaks['area'] > self.config['trigger_min_area']) - & (peaks['n_competing'] <= self.config['trigger_max_competing'])] + triggers = peaks[_is_triggering] # Join nearby triggers t0, t1 = strax.find_peak_groups( triggers, - gap_threshold=le + re + 1, - left_extension=le, - right_extension=re) + gap_threshold=self.left_extension + self.right_extension + 1, + left_extension=self.left_extension, + right_extension=self.right_extension) # Don't extend beyond the chunk boundaries # This will often happen for events near the invalid boundary of the @@ -115,15 +122,15 @@ def compute(self, peaks, start, end): @export @strax.takes_config( strax.Option( - name='allow_posts2_s1s', default=False, + name='allow_posts2_s1s', default=False, infer_type=False, help="Allow S1s past the main S2 to become the main S1 and S2"), strax.Option( - name='force_main_before_alt', default=False, + name='force_main_before_alt', default=False, infer_type=False, help="Make the alternate S1 (and likewise S2) the main S1 if " "occurs before the main S1."), strax.Option( name='event_s1_min_coincidence', - default=2, + default=2, infer_type=False, help="Event level S1 min coincidence. Should be >= s1_min_coincidence " "in the peaklet classification"), ) @@ -135,7 +142,7 @@ class EventBasics(strax.Plugin): The main S2 and alternative S2 are given by the largest two S2-Peaks within the event. By default this is also true for S1. 
""" - __version__ = '1.1.1' + __version__ = '1.2.1' depends_on = ('events', 'peak_basics', @@ -153,7 +160,7 @@ def infer_dtype(self): dtype += strax.time_fields dtype += [('n_peaks', np.int32, 'Number of peaks in the event'), - ('drift_time', np.int32, + ('drift_time', np.float32, 'Drift time between main S1 and S2 in ns'), ('event_number', np.int64, 'Event number in this dataset'), @@ -195,7 +202,9 @@ def _set_dtype_requirements(self): ('range_50p_area', np.float32, 'width, 50% area [ns]'), ('range_90p_area', np.float32, 'width, 90% area [ns]'), ('rise_time', np.float32, 'time between 10% and 50% area quantiles [ns]'), - ('area_fraction_top', np.float32, 'fraction of area seen by the top PMT array') + ('area_fraction_top', np.float32, 'fraction of area seen by the top PMT array'), + ('tight_coincidence', np.int16, 'Channel within tight range of mean'), + ('n_saturated_channels', np.int16, 'Total number of saturated channels'), ) @staticmethod @@ -215,7 +224,7 @@ def _get_si_dtypes(peak_properties): # Drifts and delays si_dtype += [ - (f'alt_s{s_i}_interaction_drift_time', np.int32, + (f'alt_s{s_i}_interaction_drift_time', np.float32, f'Drift time using alternate S{s_i} [ns]'), (f'alt_s{s_i}_delay', np.int32, f'Time between main and alternate S{s_i} [ns]')] @@ -419,17 +428,17 @@ def set_sx_index(res, s1_idx, s2_idx): @export @strax.takes_config( strax.Option( - name='electron_drift_velocity', + name='electron_drift_velocity', infer_type=False, help='Vertical electron drift velocity in cm/ns (1e4 m/ms)', default=("electron_drift_velocity", "ONLINE", True) ), strax.Option( - name='electron_drift_time_gate', + name='electron_drift_time_gate', infer_type=False, help='Electron drift time from the gate in ns', default=("electron_drift_time_gate", "ONLINE", True) ), strax.Option( - name='fdc_map', + name='fdc_map', infer_type=False, help='3D field distortion correction map path', default_by_run=[ (0, pax_file('XENON1T_FDC_SR0_data_driven_3d_correction_tf_nn_v0.json.gz')), # noqa @@ -438,7 +447,6 @@ def set_sx_index(res, s1_idx, s2_idx): (170704_0556, pax_file('XENON1T_FDC_SR1_data_driven_time_dependent_3d_correction_tf_nn_part3_v1.json.gz')), # noqa (170925_0622, pax_file('XENON1T_FDC_SR1_data_driven_time_dependent_3d_correction_tf_nn_part4_v1.json.gz'))], # noqa ), - *DEFAULT_POSREC_ALGO_OPTION ) class EventPositions(strax.Plugin): """ @@ -453,6 +461,11 @@ class EventPositions(strax.Plugin): __version__ = '0.1.4' + default_reconstruction_algorithm = straxen.URLConfig( + default=DEFAULT_POSREC_ALGO, + help="default reconstruction algorithm that provides (x,y)" + ) + dtype = [ ('x', np.float32, 'Interaction x-position, field-distortion corrected (cm)'), @@ -476,8 +489,10 @@ class EventPositions(strax.Plugin): def setup(self): - self.electron_drift_velocity = get_correction_from_cmt(self.run_id, self.config['electron_drift_velocity']) - self.electron_drift_time_gate = get_correction_from_cmt(self.run_id, self.config['electron_drift_time_gate']) + self.electron_drift_velocity = get_correction_from_cmt( + self.run_id, self.config['electron_drift_velocity']) + self.electron_drift_time_gate = get_correction_from_cmt( + self.run_id, self.config['electron_drift_time_gate']) if isinstance(self.config['fdc_map'], str): self.map = InterpolatingMap( @@ -528,6 +543,8 @@ def compute(self, events): 'r_field_distortion_correction': delta_r, 'theta': np.arctan2(orig_pos[:, 1], orig_pos[:, 0]), 'z_naive': z_obs, + # using z_obs in agreement with the dtype description + # the FDC for z (z_cor) is found to be not 
reliable (see #527) 'z': z_obs, 'z_field_distortion_correction': delta_z }) @@ -590,31 +607,6 @@ def get_veto_tags(events, split_tags, result): @export -@strax.takes_config( - strax.Option( - 's1_xyz_correction_map', - help="S1 relative (x, y, z) correction map", - default_by_run=[ - (0, pax_file('XENON1T_s1_xyz_lce_true_kr83m_SR0_pax-680_fdc-3d_v0.json')), # noqa - (first_sr1_run, pax_file('XENON1T_s1_xyz_lce_true_kr83m_SR1_pax-680_fdc-3d_v0.json'))]), # noqa - strax.Option( - 's2_xy_correction_map', - help="S2 (x, y) correction map. Correct S2 position dependence " - "manly due to bending of anode/gate-grid, PMT quantum efficiency " - "and extraction field distribution, as well as other geometric factors.", - default_by_run=[ - (0, pax_file('XENON1T_s2_xy_ly_SR0_24Feb2017.json')), - (170118_1327, pax_file('XENON1T_s2_xy_ly_SR1_v2.2.json'))]), - strax.Option( - 'elife_conf', - default=("elife", "ONLINE", True), - help='Electron lifetime ' - 'Specify as (model_type->str, model_config->str, is_nT->bool) ' - 'where model_type can be "elife" or "elife_constant" ' - 'and model_config can be a version.' - ), - *DEFAULT_POSREC_ALGO_OPTION -) class CorrectedAreas(strax.Plugin): """ Plugin which applies light collection efficiency maps and electron @@ -626,88 +618,136 @@ class CorrectedAreas(strax.Plugin): Note: Please be aware that for both, the main and alternative S1, the area is corrected according to the xy-position of the main S2. + + There are now 3 components of cS2s: cs2_top, cS2_bottom and cs2. + cs2_top and cs2_bottom are corrected by the corresponding maps, + and cs2 is the sum of the two. """ - __version__ = '0.1.1' + __version__ = '0.2.0' depends_on = ['event_basics', 'event_positions'] - dtype = [('cs1', np.float32, 'Corrected S1 area [PE]'), - ('cs2', np.float32, 'Corrected S2 area [PE]'), - ('alt_cs1', np.float32, 'Corrected area of the alternate S1 [PE]'), - ('alt_cs2', np.float32, 'Corrected area of the alternate S2 [PE]') - ] + strax.time_fields - - def setup(self): - self.elife = get_correction_from_cmt(self.run_id, self.config['elife_conf']) - if isinstance(self.config['s1_xyz_correction_map'], str): - self.config['s1_xyz_correction_map'] = [self.config['s1_xyz_correction_map']] - if isinstance(self.config['s2_xy_correction_map'], str): - self.config['s2_xy_correction_map'] = [self.config['s2_xy_correction_map']] + # Descriptor configs + elife = straxen.URLConfig( + default='cmt://elife?version=ONLINE&run_id=plugin.run_id', + help='electron lifetime in [ns]') + + # default posrec, used to determine which LCE map to use + default_reconstruction_algorithm = straxen.URLConfig( + default=DEFAULT_POSREC_ALGO, + help="default reconstruction algorithm that provides (x,y)" + ) + s1_xyz_map = straxen.URLConfig( + default='itp_map://resource://cmt://format://' + 's1_xyz_map_{algo}?version=ONLINE&run_id=plugin.run_id' + '&fmt=json&algo=plugin.default_reconstruction_algorithm', + cache=True) + s2_xy_map = straxen.URLConfig( + default='itp_map://resource://cmt://format://' + 's2_xy_map_{algo}?version=ONLINE&run_id=plugin.run_id' + '&fmt=json&algo=plugin.default_reconstruction_algorithm', + cache=True) + + # average SE gain for a given time period. default to the value of this run in ONLINE model + # thus, by default, there will be no time-dependent correction according to se gain + avg_se_gain = straxen.URLConfig( + default='cmt://se_gain?version=ONLINE&run_id=plugin.run_id', + help='Nominal single electron (SE) gain in PE / electron extracted. 
' + 'Data will be corrected to this value') + + # se gain for this run, allowing for using CMT. default to online + se_gain = straxen.URLConfig( + default='cmt://se_gain?version=ONLINE&run_id=plugin.run_id', + help='Actual SE gain for a given run (allows for time dependence)') + + # relative extraction efficiency which can change with time and modeled by CMT. + # defaults to no correction + rel_extraction_eff = straxen.URLConfig( + default=1.0, + help='Relative extraction efficiency for this run (allows for time dependence)') - self.s1_map = InterpolatingMap( - get_cmt_resource(self.run_id, - tuple(['suffix', - self.config['default_reconstruction_algorithm'], - *self.config['s1_xyz_correction_map']]), - fmt='text')) + def infer_dtype(self): + dtype = [] + dtype += strax.time_fields - self.s2_map = InterpolatingMap( - get_cmt_resource(self.run_id, - tuple([*self.config['s2_xy_correction_map']]), - fmt='text')) + for peak_type, peak_name in zip(['', 'alt_'], ['main', 'alternate']): + dtype += [(f'{peak_type}cs1', np.float32, f'Corrected area of {peak_name} S1 [PE]'), + (f'{peak_type}cs2_wo_elifecorr', np.float32, + f'Corrected area of {peak_name} S2 before elife correction ' + f'(s2 xy correction + SEG/EE correction applied) [PE]'), + (f'{peak_type}cs2_wo_timecorr', np.float32, + f'Corrected area of {peak_name} S2 before SEG/EE and elife corrections' + f'(s2 xy correction applied) [PE]'), + (f'{peak_type}cs2_area_fraction_top', np.float32, + f'Fraction of area seen by the top PMT array for corrected {peak_name} S2'), + (f'{peak_type}cs2_bottom', np.float32, + f'Corrected area of {peak_name} S2 in the bottom PMT array [PE]'), + (f'{peak_type}cs2', np.float32, f'Corrected area of {peak_name} S2 [PE]'), ] + return dtype def compute(self, events): + result = dict( + time=events['time'], + endtime=strax.endtime(events) + ) + # S1 corrections depend on the actual corrected event position. # We use this also for the alternate S1; for e.g. Kr this is # fine as the S1 correction varies slowly. event_positions = np.vstack([events['x'], events['y'], events['z']]).T - # For electron lifetime corrections to the S2s, - # use lifetimes computed using the main S1. - lifetime_corr = np.exp(events['drift_time'] / self.elife) - alt_lifetime_corr = ( - np.exp((events['alt_s2_interaction_drift_time']) - / self.elife)) - - # S2(x,y) corrections use the observed S2 positions - s2_positions = np.vstack([events['s2_x'], events['s2_y']]).T - alt_s2_positions = np.vstack([events['alt_s2_x'], events['alt_s2_y']]).T - - return dict( - time=events['time'], - endtime=strax.endtime(events), + for peak_type in ["", "alt_"]: + result[f"{peak_type}cs1"] = events[f'{peak_type}s1_area'] / self.s1_xyz_map(event_positions) - cs1=events['s1_area'] / self.s1_map(event_positions), - alt_cs1=events['alt_s1_area'] / self.s1_map(event_positions), + # s2 corrections + # S2 top and bottom are corrected separately, and cS2 total is the sum of the two + # figure out the map name + if len(self.s2_xy_map.map_names) > 1: + s2_top_map_name = "map_top" + s2_bottom_map_name = "map_bottom" + else: + s2_top_map_name = "map" + s2_bottom_map_name = "map" + + for peak_type in ["", "alt_"]: + # S2(x,y) corrections use the observed S2 positions + s2_positions = np.vstack([events[f'{peak_type}s2_x'], events[f'{peak_type}s2_y']]).T + + # corrected s2 with s2 xy map only, i.e. 
no elife correction + # this is for s2-only events which don't have drift time info + cs2_top_xycorr = (events[f'{peak_type}s2_area'] * events[f'{peak_type}s2_area_fraction_top'] / + self.s2_xy_map(s2_positions, map_name=s2_top_map_name)) + cs2_bottom_xycorr = (events[f'{peak_type}s2_area'] * + (1 - events[f'{peak_type}s2_area_fraction_top']) / + self.s2_xy_map(s2_positions, map_name=s2_bottom_map_name)) + + # Correct for SEgain and extraction efficiency + seg_ee_corr = (self.se_gain / self.avg_se_gain) * self.rel_extraction_eff + cs2_top_wo_elifecorr = cs2_top_xycorr / seg_ee_corr + cs2_bottom_wo_elifecorr = cs2_bottom_xycorr / seg_ee_corr + result[f"{peak_type}cs2_wo_elifecorr"] = cs2_top_wo_elifecorr + cs2_bottom_wo_elifecorr + + # cs2aft doesn't need elife/time corrections as they cancel + result[f"{peak_type}cs2_area_fraction_top"] = cs2_top_wo_elifecorr / result[f"{peak_type}cs2_wo_elifecorr"] + + + # For electron lifetime corrections to the S2s, + # use drift time computed using the main S1. + el_string = peak_type + "s2_interaction_" if peak_type == "alt_" else peak_type + elife_correction = np.exp(events[f'{el_string}drift_time'] / self.elife) + result[f"{peak_type}cs2_wo_timecorr"] = (cs2_top_xycorr + cs2_bottom_xycorr) * elife_correction + result[f"{peak_type}cs2"] = result[f"{peak_type}cs2_wo_elifecorr"] * elife_correction + result[f"{peak_type}cs2_bottom"] = cs2_bottom_wo_elifecorr * elife_correction - cs2=(events['s2_area'] * lifetime_corr - / self.s2_map(s2_positions)), - alt_cs2=(events['alt_s2_area'] * alt_lifetime_corr - / self.s2_map(alt_s2_positions))) + return result @export -@strax.takes_config( - strax.Option( - 'g1', - help="S1 gain in PE / photons produced", - default_by_run=[(0, 0.1442), - (first_sr1_run, 0.1426)]), - strax.Option( - 'g2', - help="S2 gain in PE / electrons produced", - default_by_run=[(0, 11.52/(1 - 0.63)), - (first_sr1_run, 11.55/(1 - 0.63))]), - strax.Option( - 'lxe_w', - help="LXe work function in quanta/keV", - default=13.7e-3), -) class EnergyEstimates(strax.Plugin): """ Plugin which converts cS1 and cS2 into energies (from PE to KeVee). """ - __version__ = '0.1.0' + __version__ = '0.1.1' depends_on = ['corrected_areas'] dtype = [ ('e_light', np.float32, 'Energy in light signal [keVee]'), @@ -716,6 +756,20 @@ class EnergyEstimates(strax.Plugin): ] + strax.time_fields save_when = strax.SaveWhen.TARGET + # config options don't double cache things from the resource cache! + g1 = straxen.URLConfig( + default='bodega://g1?bodega_version=v2', + help="S1 gain in PE / photons produced", + ) + g2 = straxen.URLConfig( + default='bodega://g2?bodega_version=v2', + help="S2 gain in PE / electrons produced", + ) + lxe_w = straxen.URLConfig( + default=13.7e-3, + help="LXe work function in quanta/keV" + ) + def compute(self, events): el = self.cs1_to_e(events['cs1']) ec = self.cs2_to_e(events['cs2']) @@ -726,7 +780,65 @@ def compute(self, events): endtime=strax.endtime(events)) def cs1_to_e(self, x): - return self.config['lxe_w'] * x / self.config['g1'] + return self.lxe_w * x / self.g1 def cs2_to_e(self, x): - return self.config['lxe_w'] * x / self.config['g2'] + return self.lxe_w * x / self.g2 + + +@export +class EventShadow(strax.Plugin): + """ + This plugin can calculate shadow at event level. + It depends on peak-level shadow. + The event-level shadow is its first S2 peak's shadow. + If no S2 peaks, the event shadow will be nan. + It also gives the position infomation of the previous S2s + and main peaks' shadow. 
+ """ + __version__ = '0.0.8' + depends_on = ('event_basics', 'peak_basics', 'peak_shadow') + provides = 'event_shadow' + save_when = strax.SaveWhen.EXPLICIT + + def infer_dtype(self): + dtype = [('s1_shadow', np.float32, 'main s1 shadow [PE/ns]'), + ('s2_shadow', np.float32, 'main s2 shadow [PE/ns]'), + ('shadow', np.float32, 'shadow of event [PE/ns]'), + ('pre_s2_area', np.float32, 'previous s2 area [PE]'), + ('shadow_dt', np.int64, 'time difference to the previous s2 [ns]'), + ('shadow_index', np.int32, 'max shadow peak index in event'), + ('pre_s2_x', np.float32, 'x of previous s2 peak causing shadow [cm]'), + ('pre_s2_y', np.float32, 'y of previous s2 peak causing shadow [cm]'), + ('shadow_distance', np.float32, 'distance to the s2 peak with max shadow [cm]')] + dtype += strax.time_fields + return dtype + + def compute(self, events, peaks): + split_peaks = strax.split_by_containment(peaks, events) + res = np.zeros(len(events), self.dtype) + + res['shadow_index'] = -1 + res['pre_s2_x'] = np.nan + res['pre_s2_y'] = np.nan + + for event_i, (event, sp) in enumerate(zip(events, split_peaks)): + if event['s1_index'] >= 0: + res['s1_shadow'][event_i] = sp['shadow'][event['s1_index']] + if event['s2_index'] >= 0: + res['s2_shadow'][event_i] = sp['shadow'][event['s2_index']] + if (sp['type'] == 2).sum() > 0: + # Define event shadow as the first S2 peak shadow + first_s2_index = np.argwhere(sp['type'] == 2)[0] + res['shadow_index'][event_i] = first_s2_index + res['shadow'][event_i] = sp['shadow'][first_s2_index] + res['pre_s2_area'][event_i] = sp['pre_s2_area'][first_s2_index] + res['shadow_dt'][event_i] = sp['shadow_dt'][first_s2_index] + res['pre_s2_x'][event_i] = sp['pre_s2_x'][first_s2_index] + res['pre_s2_y'][event_i] = sp['pre_s2_y'][first_s2_index] + res['shadow_distance'] = ((res['pre_s2_x'] - events['s2_x'])**2 + + (res['pre_s2_y'] - events['s2_y'])**2 + )**0.5 + res['time'] = events['time'] + res['endtime'] = strax.endtime(events) + return res diff --git a/straxen/plugins/led_calibration.py b/straxen/plugins/led_calibration.py index 62f0f4bbd..1fba7570e 100644 --- a/straxen/plugins/led_calibration.py +++ b/straxen/plugins/led_calibration.py @@ -16,16 +16,16 @@ @export @strax.takes_config( strax.Option('baseline_window', - default=(0,40), + default=(0,40), infer_type=False, help="Window (samples) for baseline calculation."), strax.Option('led_window', - default=(78, 116), + default=(78, 116), infer_type=False, help="Window (samples) where we expect the signal in LED calibration"), strax.Option('noise_window', - default=(10, 48), + default=(10, 48), infer_type=False, help="Window (samples) to analysis the noise"), strax.Option('channel_list', - default=(tuple(channel_list)), + default=(tuple(channel_list)), infer_type=False, help="List of PMTs. 
Defalt value: all the PMTs")) class LEDCalibration(strax.Plugin): diff --git a/straxen/plugins/nveto_recorder.py b/straxen/plugins/nveto_recorder.py index b5bc33eb8..59028e44d 100644 --- a/straxen/plugins/nveto_recorder.py +++ b/straxen/plugins/nveto_recorder.py @@ -19,11 +19,11 @@ help="Pretrigger time before coincidence window in ns."), strax.Option('resolving_time_recorder_nv', type=int, default=600, help="Resolving time of the coincidence in ns."), - strax.Option('baseline_samples_nv', + strax.Option('baseline_samples_nv', infer_type=False, default=('baseline_samples_nv', 'ONLINE', True), track=True, help="Number of samples used in baseline rms calculation"), strax.Option( - 'hit_min_amplitude_nv', + 'hit_min_amplitude_nv', infer_type=False, default=('hit_thresholds_nv', 'ONLINE', True), track=True, help='Minimum hit amplitude in ADC counts above baseline. ' 'Specify as a tuple of length n_nveto_pmts, or a number, ' @@ -37,7 +37,7 @@ help="frozendict mapping subdetector to (min, max) " "channel number."), strax.Option('check_raw_record_overlaps_nv', - default=True, track=False, + default=True, track=False, infer_type=False, help='Crash if any of the pulses in raw_records overlap with others ' 'in the same channel'), ) diff --git a/straxen/plugins/online_monitor.py b/straxen/plugins/online_monitor.py index b05168c5d..1445515a0 100644 --- a/straxen/plugins/online_monitor.py +++ b/straxen/plugins/online_monitor.py @@ -1,240 +1,302 @@ -import strax -import numpy as np -from immutabledict import immutabledict - -export, __all__ = strax.exporter() - - -@export -@strax.takes_config( - strax.Option( - 'area_vs_width_nbins', - type=int, default=60, - help='Number of bins for area vs width histogram for online monitor. ' - 'NB: this is a 2D histogram'), - strax.Option( - 'area_vs_width_bounds', - type=tuple, default=((0, 5), (0, 5)), - help='Boundaries of log-log histogram of area vs width'), - strax.Option( - 'area_vs_width_cut_string', - type=str, default='', - help='Selection (like selection_str) applied to data for ' - '"area_vs_width_hist_clean", cuts should be separated using "&"' - 'For example: (tight_coincidence > 2) & (area_fraction_top < 0.1)' - 'Default is no selection (other than "area_vs_width_min_gap")'), - strax.Option( - 'lone_hits_area_bounds', - type=tuple, default=(0, 1500), - help='Boundaries area histogram of lone hits [ADC]'), - strax.Option( - 'online_peak_monitor_nbins', - type=int, default=100, - help='Number of bins of histogram of online monitor. Will be used ' - 'for: ' - 'lone_hits_area-histogram, ' - 'area_fraction_top-histogram, ' - 'online_se_gain estimate (histogram is not stored), ' - ), - strax.Option( - 'lone_hits_cut_string', - type=str, - default='(area >= 50) & (area <= 250)', - help='Selection (like selection_str) applied to data for ' - '"lone-hits", cuts should be separated using "&")'), - strax.Option( - 'lone_hits_min_gap', - type=int, - default=15_000, - help='Minimal gap [ns] between consecutive lone-hits. To turn off ' - 'this cut, set to 0.'), - strax.Option( - 'n_tpc_pmts', type=int, - help='Number of TPC PMTs'), - strax.Option( - 'online_se_bounds', - type=tuple, default=(7, 70), - help='Window for online monitor [PE] to look for the SE gain, value' - ) -) -class OnlinePeakMonitor(strax.Plugin): - """ - Plugin to write data to the online-monitor. Data that is written by - this plugin should be small such as to not overload the runs- - database. - - This plugin takes 'peak_basics' and 'lone_hits'. 
Although they are - not strictly related, they are aggregated into a single data_type - in order to minimize the number of documents in the online monitor. - - Produces 'online_peak_monitor' with info on the lone-hits and peaks - """ - depends_on = ('peak_basics', 'lone_hits') - provides = 'online_peak_monitor' - data_kind = 'online_peak_monitor' - __version__ = '0.0.5' - rechunk_on_save = False - - def infer_dtype(self): - n_bins_area_width = self.config['area_vs_width_nbins'] - bounds_area_width = self.config['area_vs_width_bounds'] - - n_bins = self.config['online_peak_monitor_nbins'] - - n_tpc_pmts = self.config['n_tpc_pmts'] - dtype = [ - (('Start time of the chunk', 'time'), - np.int64), - (('End time of the chunk', 'endtime'), - np.int64), - (('Area vs width histogram (log-log)', 'area_vs_width_hist'), - (np.int64, (n_bins_area_width, n_bins_area_width))), - (('Area vs width edges (log-space)', 'area_vs_width_bounds'), - (np.float64, np.shape(bounds_area_width))), - (('Lone hits areas histogram [ADC-counts]', 'lone_hits_area_hist'), - (np.int64, n_bins)), - (('Lone hits areas bounds [ADC-counts]', 'lone_hits_area_bounds'), - (np.float64, 2)), - (('Lone hits per channel', 'lone_hits_per_channel'), - (np.int64, n_tpc_pmts)), - (('AFT histogram', 'aft_hist'), - (np.int64, n_bins)), - (('AFT bounds', 'aft_bounds'), - (np.float64, 2)), - (('Number of contributing channels histogram', 'n_channel_hist'), - (np.int64, n_tpc_pmts)), - (('Single electron gain', 'online_se_gain'), - np.float32), - ] - return dtype - - def compute(self, peaks, lone_hits, start, end): - # General setup - res = np.zeros(1, dtype=self.dtype) - res['time'] = start - res['endtime'] = end - n_pmt = self.config['n_tpc_pmts'] - n_bins = self.config['online_peak_monitor_nbins'] - - # Bounds for histograms - res['area_vs_width_bounds'] = self.config['area_vs_width_bounds'] - res['lone_hits_area_bounds'] = self.config['lone_hits_area_bounds'] - - # -- Peak vs area 2D histogram -- - # Always cut out unphysical peaks - sel = (peaks['area'] > 0) & (peaks['range_50p_area'] > 0) - res['area_vs_width_hist'] = self.area_width_hist(peaks[sel]) - del sel - - # -- Lone hit properties -- - # Make a mask with the cuts. - # Now only take lone hits that are separated in time. - if len(lone_hits): - lh_timedelta = lone_hits[1:]['time'] - strax.endtime(lone_hits)[:-1] - # Hits on the left are far away? (assume first is because of chunk bound) - mask = np.hstack([True, lh_timedelta > self.config['lone_hits_min_gap']]) - # Hits on the right are far away? (assume last is because of chunk bound) - mask &= np.hstack([lh_timedelta > self.config['lone_hits_min_gap'], True]) - else: - mask = [] - masked_lh = strax.apply_selection(lone_hits[mask], - selection_str=self.config['lone_hits_cut_string']) - - # Make histogram of ADC counts - # NB: LONE HITS AREA ARE IN ADC! 
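# Illustrative sketch (not part of the patch): further down in this compute the
# online single-electron (SE) gain is estimated as the mode of the peak-area
# histogram inside 'online_se_bounds'. The toy areas are invented for the sketch.
import numpy as np

toy_areas = np.array([8., 20., 21., 22., 23., 40., 65.])         # PE
se_hist, se_bins = np.histogram(toy_areas, bins=10, range=(7, 70))
bin_centers = (se_bins[1:] + se_bins[:-1]) / 2
online_se_gain = bin_centers[np.argmax(se_hist)]                  # most populated bin -> ~22.75 PE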
- lone_hit_areas, _ = np.histogram(masked_lh['area'], - bins=n_bins, - range=self.config['lone_hits_area_bounds']) - - lone_hit_channel_count, _ = np.histogram(masked_lh['channel'], - bins=n_pmt, - range=[0, n_pmt]) - # Count number of lone-hits per PMT - res['lone_hits_area_hist'] = lone_hit_areas - res['lone_hits_per_channel'] = lone_hit_channel_count - # Clear mask, don't re-use - del mask - - # -- AFT histogram -- - aft_b = [0, 1] - aft_hist, _ = np.histogram(peaks['area_fraction_top'], bins=n_bins, range=aft_b) - res['aft_hist'] = aft_hist - res['aft_bounds'] = aft_b - - # Estimate Single Electron (SE) gain - se_hist, se_bins = np.histogram(peaks['area'], bins=n_bins, - range=self.config['online_se_bounds']) - bin_centers = (se_bins[1:] + se_bins[:-1]) / 2 - res['online_se_gain'] = bin_centers[np.argmax(se_hist)] - return res - - def area_width_hist(self, data): - """Make area vs width 2D-hist""" - hist, _, _ = np.histogram2d( - np.log10(data['area']), - np.log10(data['range_50p_area']), - range=self.config['area_vs_width_bounds'], - bins=self.config['area_vs_width_nbins']) - return hist.T - - -@export -@strax.takes_config( - strax.Option( - 'channel_map', - track=False, - type=immutabledict, - help="immutabledict mapping subdetector to (min, max) \ - channel number.") -) -class OnlineMonitorNV(strax.Plugin): - """ - Plugin to write data of nVeto detector to the online-monitor. - Data that is written by this plugin should be small (~MB/chunk) - to not overload the runs-database. - - This plugin takes 'hitlets_nv' and 'events_nv'. Although they are - not strictly related, they are aggregated into a single data_type - in order to minimize the number of documents in the online monitor. - - Produces 'online_monitor_nv' with info on the hitlets_nv and events_nv - """ - - depends_on = ('hitlets_nv', 'events_nv') - provides = 'online_monitor_nv' - data_kind = 'online_monitor_nv' - __version__ = '0.0.2' - rechunk_on_save = False - - def infer_dtype(self): - min_pmt, max_pmt = self.config['channel_map']['nveto'] - n_pmt = (max_pmt - min_pmt) + 1 - dtype = [ - (('Start time of the chunk', 'time'), - np.int64), - (('End time of the chunk', 'endtime'), - np.int64), - (('hitlets_nv per channel', 'hitlets_nv_per_channel'), - (np.int64, n_pmt)), - (('events_nv per chunk', 'events_nv_per_chunk'), - np.int64) - ] - return dtype - - def compute(self, hitlets_nv, events_nv, start, end): - # General setup - res = np.zeros(1, dtype=self.dtype) - res['time'] = start - res['endtime'] = end - min_pmt, max_pmt = self.config['channel_map']['nveto'] - n_pmt = (max_pmt - min_pmt) + 1 - - # Count number of hitlets_nv per PMT - hitlets_channel_count, _ = np.histogram(hitlets_nv['channel'], - bins=n_pmt, - range=[min_pmt, max_pmt + 1]) - res['hitlets_nv_per_channel'] = hitlets_channel_count - - # Count number of events_nv per chunk - res['events_nv_per_chunk'] = len(events_nv) - return res +import strax +import numpy as np +from immutabledict import immutabledict + +export, __all__ = strax.exporter() + + +@export +@strax.takes_config( + strax.Option( + 'area_vs_width_nbins', + type=int, default=60, + help='Number of bins for area vs width histogram for online monitor. 
' + 'NB: this is a 2D histogram'), + strax.Option( + 'area_vs_width_bounds', + type=tuple, default=((0, 5), (0, 5)), + help='Boundaries of log-log histogram of area vs width'), + strax.Option( + 'area_vs_width_cut_string', + type=str, default='', + help='Selection (like selection_str) applied to data for ' + '"area_vs_width_hist_clean", cuts should be separated using "&"' + 'For example: (tight_coincidence > 2) & (area_fraction_top < 0.1)' + 'Default is no selection (other than "area_vs_width_min_gap")'), + strax.Option( + 'lone_hits_area_bounds', + type=tuple, default=(0, 1500), + help='Boundaries area histogram of lone hits [ADC]'), + strax.Option( + 'online_peak_monitor_nbins', + type=int, default=100, + help='Number of bins of histogram of online monitor. Will be used ' + 'for: ' + 'lone_hits_area-histogram, ' + 'area_fraction_top-histogram, ' + 'online_se_gain estimate (histogram is not stored), ' + ), + strax.Option( + 'lone_hits_cut_string', + type=str, + default='(area >= 50) & (area <= 250)', + help='Selection (like selection_str) applied to data for ' + '"lone-hits", cuts should be separated using "&")'), + strax.Option( + 'lone_hits_min_gap', + type=int, + default=15_000, + help='Minimal gap [ns] between consecutive lone-hits. To turn off ' + 'this cut, set to 0.'), + strax.Option( + 'n_tpc_pmts', type=int, + help='Number of TPC PMTs'), + strax.Option( + 'online_se_bounds', + type=tuple, default=(7, 70), + help='Window for online monitor [PE] to look for the SE gain, value' + ) +) +class OnlinePeakMonitor(strax.Plugin): + """ + Plugin to write data to the online-monitor. Data that is written by + this plugin should be small such as to not overload the runs- + database. + + This plugin takes 'peak_basics' and 'lone_hits'. Although they are + not strictly related, they are aggregated into a single data_type + in order to minimize the number of documents in the online monitor. 
+ + Produces 'online_peak_monitor' with info on the lone-hits and peaks + """ + depends_on = ('peak_basics', 'lone_hits') + provides = 'online_peak_monitor' + data_kind = 'online_peak_monitor' + __version__ = '0.0.5' + rechunk_on_save = False + + def infer_dtype(self): + n_bins_area_width = self.config['area_vs_width_nbins'] + bounds_area_width = self.config['area_vs_width_bounds'] + + n_bins = self.config['online_peak_monitor_nbins'] + + n_tpc_pmts = self.config['n_tpc_pmts'] + dtype = [ + (('Start time of the chunk', 'time'), + np.int64), + (('End time of the chunk', 'endtime'), + np.int64), + (('Area vs width histogram (log-log)', 'area_vs_width_hist'), + (np.int64, (n_bins_area_width, n_bins_area_width))), + (('Area vs width edges (log-space)', 'area_vs_width_bounds'), + (np.float64, np.shape(bounds_area_width))), + (('Lone hits areas histogram [ADC-counts]', 'lone_hits_area_hist'), + (np.int64, n_bins)), + (('Lone hits areas bounds [ADC-counts]', 'lone_hits_area_bounds'), + (np.float64, 2)), + (('Lone hits per channel', 'lone_hits_per_channel'), + (np.int64, n_tpc_pmts)), + (('AFT histogram', 'aft_hist'), + (np.int64, n_bins)), + (('AFT bounds', 'aft_bounds'), + (np.float64, 2)), + (('Number of contributing channels histogram', 'n_channel_hist'), + (np.int64, n_tpc_pmts)), + (('Single electron gain', 'online_se_gain'), + np.float32), + ] + return dtype + + def compute(self, peaks, lone_hits, start, end): + # General setup + res = np.zeros(1, dtype=self.dtype) + res['time'] = start + res['endtime'] = end + n_pmt = self.config['n_tpc_pmts'] + n_bins = self.config['online_peak_monitor_nbins'] + + # Bounds for histograms + res['area_vs_width_bounds'] = self.config['area_vs_width_bounds'] + res['lone_hits_area_bounds'] = self.config['lone_hits_area_bounds'] + + # -- Peak vs area 2D histogram -- + # Always cut out unphysical peaks + sel = (peaks['area'] > 0) & (peaks['range_50p_area'] > 0) + res['area_vs_width_hist'] = self.area_width_hist(peaks[sel]) + del sel + + # -- Lone hit properties -- + # Make a mask with the cuts. + # Now only take lone hits that are separated in time. + if len(lone_hits): + lh_timedelta = lone_hits[1:]['time'] - strax.endtime(lone_hits)[:-1] + # Hits on the left are far away? (assume first is because of chunk bound) + mask = np.hstack([True, lh_timedelta > self.config['lone_hits_min_gap']]) + # Hits on the right are far away? (assume last is because of chunk bound) + mask &= np.hstack([lh_timedelta > self.config['lone_hits_min_gap'], True]) + else: + mask = [] + masked_lh = strax.apply_selection(lone_hits[mask], + selection_str=self.config['lone_hits_cut_string']) + + # Make histogram of ADC counts + # NB: LONE HITS AREA ARE IN ADC! 
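# Illustrative sketch (not part of the patch) of the two-sided isolation cut
# built just above: a lone hit is kept only if the gaps to both neighbours
# exceed lone_hits_min_gap. The toy times below are invented for the sketch.
import numpy as np

times = np.array([0, 100, 40_000, 90_000, 90_050])     # ns, toy lone-hit times
endtimes = times + 10                                   # toy 10 ns long hits
min_gap = 15_000                                        # config default above

gaps = times[1:] - endtimes[:-1]
far_from_previous = np.hstack([True, gaps > min_gap])
far_from_next = np.hstack([gaps > min_gap, True])
keep = far_from_previous & far_from_next                # -> [False, False, True, False, False]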
+ lone_hit_areas, _ = np.histogram(masked_lh['area'], + bins=n_bins, + range=self.config['lone_hits_area_bounds']) + + lone_hit_channel_count, _ = np.histogram(masked_lh['channel'], + bins=n_pmt, + range=[0, n_pmt]) + # Count number of lone-hits per PMT + res['lone_hits_area_hist'] = lone_hit_areas + res['lone_hits_per_channel'] = lone_hit_channel_count + # Clear mask, don't re-use + del mask + + # -- AFT histogram -- + aft_b = [0, 1] + aft_hist, _ = np.histogram(peaks['area_fraction_top'], bins=n_bins, range=aft_b) + res['aft_hist'] = aft_hist + res['aft_bounds'] = aft_b + + # Estimate Single Electron (SE) gain + se_hist, se_bins = np.histogram(peaks['area'], bins=n_bins, + range=self.config['online_se_bounds']) + bin_centers = (se_bins[1:] + se_bins[:-1]) / 2 + res['online_se_gain'] = bin_centers[np.argmax(se_hist)] + return res + + def area_width_hist(self, data): + """Make area vs width 2D-hist""" + hist, _, _ = np.histogram2d( + np.log10(data['area']), + np.log10(data['range_50p_area']), + range=self.config['area_vs_width_bounds'], + bins=self.config['area_vs_width_nbins']) + return hist.T + + +@export +@strax.takes_config( + strax.Option( + 'channel_map', + track=False, + type=immutabledict, + help='immutabledict mapping subdetector to (min, max) ' + 'channel number.'), + strax.Option( + 'events_area_bounds', + type=tuple, default=(-0.5, 130.5), + help='Boundaries area histogram of events_nv_area_per_chunk [PE]'), + strax.Option( + 'events_area_nbins', + type=int, default=131, + help='Number of bins of histogram of events_nv_area_per_chunk, ' + 'defined value 1 PE/bin') +) +class OnlineMonitorNV(strax.Plugin): + """ + Plugin to write data of nVeto detector to the online-monitor. + Data that is written by this plugin should be small (~MB/chunk) + to not overload the runs-database. + + This plugin takes 'hitlets_nv' and 'events_nv'. Although they are + not strictly related, they are aggregated into a single data_type + in order to minimize the number of documents in the online monitor. + + Produces 'online_monitor_nv' with info on the hitlets_nv and events_nv + """ + depends_on = ('hitlets_nv', 'events_nv') + provides = 'online_monitor_nv' + data_kind = 'online_monitor_nv' + rechunk_on_save = False + + # Needed in case we make again an muVETO child. 
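# Illustrative sketch (not part of the patch): the per-chunk coincidence counters
# filled in compute() below simply count events above a PMT-multiplicity
# threshold. The toy structured array is invented for the sketch.
import numpy as np

toy_events = np.zeros(5, dtype=[('n_contributing_pmt', np.int32)])
toy_events['n_contributing_pmt'] = [3, 4, 5, 8, 12]

for n_coinc in (4, 5, 8, 10):
    n_passing = np.sum(toy_events['n_contributing_pmt'] >= n_coinc)
    print(f'events_nv_{n_coinc}coinc_per_chunk = {n_passing}')   # -> 4, 3, 2, 1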
+ ends_with = '_nv' + + __version__ = '0.0.4' + + def infer_dtype(self): + self.channel_range = self.config['channel_map']['nveto'] + self.n_channel = (self.channel_range[1] - self.channel_range[0]) + 1 + return veto_monitor_dtype(self.ends_with, self.n_channel, self.config['events_area_nbins']) + + def compute(self, hitlets_nv, events_nv, start, end): + # General setup + res = np.zeros(1, dtype=self.dtype) + res['time'] = start + res['endtime'] = end + + # Count number of hitlets_nv per PMT + hitlets_channel_count, _ = np.histogram(hitlets_nv['channel'], + bins=self.n_channel, + range=[self.channel_range[0], + self.channel_range[1] + 1]) + res[f'hitlets{self.ends_with}_per_channel'] = hitlets_channel_count + + # Count number of events_nv with coincidence cut + res[f'events{self.ends_with}_per_chunk'] = len(events_nv) + sel = events_nv['n_contributing_pmt'] >= 4 + res[f'events{self.ends_with}_4coinc_per_chunk'] = np.sum(sel) + sel = events_nv['n_contributing_pmt'] >= 5 + res[f'events{self.ends_with}_5coinc_per_chunk'] = np.sum(sel) + sel = events_nv['n_contributing_pmt'] >= 8 + res[f'events{self.ends_with}_8coinc_per_chunk'] = np.sum(sel) + sel = events_nv['n_contributing_pmt'] >= 10 + res[f'events{self.ends_with}_10coinc_per_chunk'] = np.sum(sel) + + # Get histogram of events_nv_area per chunk + events_area, bins_ = np.histogram(events_nv['area'], + bins=self.config['events_area_nbins'], + range=self.config['events_area_bounds']) + res[f'events{self.ends_with}_area_per_chunk'] = events_area + return res + + +def veto_monitor_dtype(veto_name: str = '_nv', + n_pmts: int = 120, + n_bins: int = 131) -> list: + dtype = [] + dtype += strax.time_fields # because mutable + dtype += [((f'hitlets{veto_name} per channel', f'hitlets{veto_name}_per_channel'), (np.int64, n_pmts)), + ((f'events{veto_name}_area per chunk', f'events{veto_name}_area_per_chunk'), np.int64, n_bins), + ((f'events{veto_name} per chunk', f'events{veto_name}_per_chunk'), np.int64), + ((f'events{veto_name} 4-coincidence per chunk', f'events{veto_name}_4coinc_per_chunk'), np.int64), + ((f'events{veto_name} 5-coincidence per chunk', f'events{veto_name}_5coinc_per_chunk'), np.int64), + ((f'events{veto_name} 8-coincidence per chunk', f'events{veto_name}_8coinc_per_chunk'), np.int64), + ((f'events{veto_name} 10-coincidence per chunk', f'events{veto_name}_10coinc_per_chunk'), np.int64) + ] + return dtype + + +@export +@strax.takes_config( + strax.Option( + 'adc_to_pe_mv', + type=int, default=170.0, + help='conversion factor from ADC to PE for muon Veto') +) +class OnlineMonitorMV(OnlineMonitorNV): + __doc__ = OnlineMonitorNV.__doc__.replace('_nv', '_mv').replace('nVeto', 'muVeto') + depends_on = ('hitlets_mv', 'events_mv') + provides = 'online_monitor_mv' + data_kind = 'online_monitor_mv' + rechunk_on_save = False + + # Needed in case we make again an muVETO child. 
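# Illustrative sketch (not part of the patch): the muVeto child converts event
# areas from ADC to PE before reusing the nVeto compute, so the shared area
# histogram is filled in PE. The toy areas are invented for the sketch.
import numpy as np

adc_to_pe_mv = 170.0                                    # 'adc_to_pe_mv' default above
toy_events_mv = np.zeros(3, dtype=[('area', np.float32)])
toy_events_mv['area'] = [170., 340., 850.]              # ADC
events_pe = np.copy(toy_events_mv)                      # copy first, as in compute(),
events_pe['area'] *= 1. / adc_to_pe_mv                  # so the input chunk is untouched
# events_pe['area'] -> [1., 2., 5.] PE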
+ ends_with = '_mv' + child_plugin = True + + __version__ = '0.0.1' + + def infer_dtype(self): + self.channel_range = self.config['channel_map']['mv'] + self.n_channel = (self.channel_range[1] - self.channel_range[0]) + 1 + return veto_monitor_dtype(self.ends_with, self.n_channel, self.config['events_area_nbins']) + + def compute(self, hitlets_mv, events_mv, start, end): + events_mv = np.copy(events_mv) + events_mv['area'] *= 1./self.config['adc_to_pe_mv'] + return super().compute(hitlets_mv, events_mv, start, end) diff --git a/straxen/plugins/pax_interface.py b/straxen/plugins/pax_interface.py index 70209cf54..18f707c0e 100644 --- a/straxen/plugins/pax_interface.py +++ b/straxen/plugins/pax_interface.py @@ -120,13 +120,14 @@ def finish_results(): @export @strax.takes_config( - strax.Option('pax_raw_dir', default='/data/xenon/raw', track=False, + strax.Option('pax_raw_dir', default='/data/xenon/raw', track=False, infer_type=False, help="Directory with raw pax datasets"), - strax.Option('stop_after_zips', default=0, track=False, + strax.Option('stop_after_zips', default=0, track=False, infer_type=False, help="Convert only this many zip files. 0 = all."), - strax.Option('events_per_chunk', default=50, track=False, + strax.Option('events_per_chunk', default=50, track=False, infer_type=False, help="Number of events to yield per chunk"), - strax.Option('samples_per_record', default=strax.DEFAULT_RECORD_LENGTH, track=False, + strax.Option('samples_per_record', default=strax.DEFAULT_RECORD_LENGTH, + track=False, infer_type=False, help="Number of samples per record") ) class RecordsFromPax(strax.Plugin): diff --git a/straxen/plugins/peak_processing.py b/straxen/plugins/peak_processing.py index 92cf527f9..201d5d0d3 100644 --- a/straxen/plugins/peak_processing.py +++ b/straxen/plugins/peak_processing.py @@ -14,9 +14,9 @@ @export @strax.takes_config( - strax.Option('n_top_pmts', default=straxen.n_top_pmts, + strax.Option('n_top_pmts', default=straxen.n_top_pmts, infer_type=False, help="Number of top PMTs"), - strax.Option('check_peak_sum_area_rtol', default=None, track=False, + strax.Option('check_peak_sum_area_rtol', default=None, track=False, infer_type=False, help="Check if the sum area and the sum of area per " "channel are the same. If None, don't do the " "check. To perform the check, set to the desired " @@ -28,7 +28,7 @@ class PeakBasics(strax.Plugin): arrays. NB: This plugin can therefore be loaded as a pandas DataFrame. 
""" - __version__ = "0.0.9" + __version__ = "0.1.0" parallel = True depends_on = ('peaks',) provides = 'peak_basics' @@ -47,6 +47,8 @@ class PeakBasics(strax.Plugin): 'max_pmt'), np.int16), (('Area of signal in the largest-contributing PMT (PE)', 'max_pmt_area'), np.float32), + (('Total number of saturated channels', + 'n_saturated_channels'), np.int16), (('Width (in ns) of the central 50% area of the peak', 'range_50p_area'), np.float32), (('Width (in ns) of the central 90% area of the peak', @@ -62,6 +64,8 @@ class PeakBasics(strax.Plugin): 'rise_time'), np.float32), (('Hits within tight range of mean', 'tight_coincidence'), np.int16), + (('PMT channel within tight range of mean', + 'tight_coincidence_channel'), np.int16), (('Classification of the peak(let)', 'type'), np.int8) ] @@ -78,6 +82,7 @@ def compute(self, peaks): r['max_pmt'] = np.argmax(p['area_per_channel'], axis=1) r['max_pmt_area'] = np.max(p['area_per_channel'], axis=1) r['tight_coincidence'] = p['tight_coincidence'] + r['n_saturated_channels'] = p['n_saturated_channels'] n_top = self.config['n_top_pmts'] area_top = p['area_per_channel'][:, :n_top].sum(axis=1) @@ -159,21 +164,21 @@ def compute(self, peaks_he): @export @strax.takes_config( strax.Option( - 'nn_architecture', + 'nn_architecture', infer_type=False, help='Path to JSON of neural net architecture', default_by_run=[ (0, pax_file('XENON1T_tensorflow_nn_pos_20171217_sr0.json')), (first_sr1_run, straxen.aux_repo + '3548132b55f81a43654dba5141366041e1daaf01/strax_files/XENON1T_tensorflow_nn_pos_20171217_sr1_reformatted.json')]), # noqa strax.Option( - 'nn_weights', + 'nn_weights', infer_type=False, help='Path to HDF5 of neural net weights', default_by_run=[ (0, pax_file('XENON1T_tensorflow_nn_pos_weights_20171217_sr0.h5')), (first_sr1_run, pax_file('XENON1T_tensorflow_nn_pos_weights_20171217_sr1.h5'))]), # noqa strax.Option('min_reconstruction_area', help='Skip reconstruction if area (PE) is less than this', - default=10), - strax.Option('n_top_pmts', default=straxen.n_top_pmts, + default=10, infer_type=False,), + strax.Option('n_top_pmts', default=straxen.n_top_pmts, infer_type=False, help="Number of top PMTs") ) class PeakPositions1T(strax.Plugin): @@ -253,13 +258,13 @@ def compute(self, peaks): @export @strax.takes_config( - strax.Option('min_area_fraction', default=0.5, + strax.Option('min_area_fraction', default=0.5, infer_type=False, help='The area of competing peaks must be at least ' 'this fraction of that of the considered peak'), - strax.Option('nearby_window', default=int(1e7), + strax.Option('nearby_window', default=int(1e7), infer_type=False, help='Peaks starting within this time window (on either side)' 'in ns count as nearby.'), - strax.Option('peak_max_proximity_time', default=int(1e8), + strax.Option('peak_max_proximity_time', default=int(1e8), infer_type=False, help='Maximum value for proximity values such as ' 't_to_next_peak [ns]')) class PeakProximity(strax.OverlapWindowPlugin): @@ -326,6 +331,103 @@ def find_n_competing(peaks, windows, fraction): return n_left, n_tot +@export +@strax.takes_config( + strax.Option(name='pre_s2_area_threshold', default=1000, + help='Only take S2s larger than this into account ' + 'when calculating PeakShadow [PE]'), + strax.Option(name='deltatime_exponent', default=-1.0, + help='The exponent of delta t when calculating shadow'), + strax.Option('time_window_backward', default=int(3e9), + help='Search for S2s causing shadow in this time window [ns]'), + strax.Option(name='electron_drift_velocity', + 
default=('electron_drift_velocity', 'ONLINE', True), + help='Vertical electron drift velocity in cm/ns (1e4 m/ms)'), + strax.Option(name='max_drift_length', default=straxen.tpc_z, + help='Total length of the TPC from the bottom of gate to the ' + 'top of cathode wires [cm]'), + strax.Option(name='exclude_drift_time', default=False, + help='Subtract max drift time to avoid peak interference in ' + 'a single event [ns]')) +class PeakShadow(strax.OverlapWindowPlugin): + """ + This plugin can find and calculate the previous S2 shadow at peak level, + with time window backward and previous S2 area as options. + It also gives the area and position information of these previous S2s. + """ + + __version__ = '0.1.0' + depends_on = ('peak_basics', 'peak_positions') + provides = 'peak_shadow' + save_when = strax.SaveWhen.EXPLICIT + + def setup(self): + self.time_window_backward = self.config['time_window_backward'] + if self.config['exclude_drift_time']: + electron_drift_velocity = straxen.get_correction_from_cmt( + self.run_id, + self.config['electron_drift_velocity']) + drift_time_max = int(self.config['max_drift_length'] / electron_drift_velocity) + self.n_drift_time = drift_time_max + else: + self.n_drift_time = 0 + self.s2_threshold = self.config['pre_s2_area_threshold'] + self.exponent = self.config['deltatime_exponent'] + + def get_window_size(self): + return 3 * self.config['time_window_backward'] + + def infer_dtype(self): + dtype = [('shadow', np.float32, 'previous s2 shadow [PE/ns]'), + ('pre_s2_area', np.float32, 'previous s2 area [PE]'), + ('shadow_dt', np.int64, 'time difference to the previous s2 [ns]'), + ('pre_s2_x', np.float32, 'x of previous s2 peak causing shadow [cm]'), + ('pre_s2_y', np.float32, 'y of previous s2 peak causing shadow [cm]')] + dtype += strax.time_fields + return dtype + + def compute(self, peaks): + roi_shadow = np.zeros(len(peaks), dtype=strax.time_fields) + roi_shadow['time'] = peaks['center_time'] - self.time_window_backward + roi_shadow['endtime'] = peaks['center_time'] - self.n_drift_time + + mask_pre_s2 = peaks['area'] > self.s2_threshold + mask_pre_s2 &= peaks['type'] == 2 + split_peaks = strax.touching_windows(peaks[mask_pre_s2], roi_shadow) + res = np.zeros(len(peaks), self.dtype) + res['pre_s2_x'] = np.nan + res['pre_s2_y'] = np.nan + if len(peaks): + self.compute_shadow(peaks, peaks[mask_pre_s2], split_peaks, self.exponent, res) + + res['time'] = peaks['time'] + res['endtime'] = strax.endtime(peaks) + return res + + @staticmethod + @numba.njit + def compute_shadow(peaks, pre_s2_peaks, touching_windows, exponent, res): + """ + For each peak in peaks, check if there is a shadow-casting S2 peak + and check if it casts the largest shadow + """ + for p_i, p_a in enumerate(peaks): + # reset for every peak + new_shadow = 0 + s2_indices = touching_windows[p_i] + for s2_idx in range(s2_indices[0], s2_indices[1]): + s2_a = pre_s2_peaks[s2_idx] + if p_a['center_time'] - s2_a['center_time'] <= 0: + continue + new_shadow = s2_a['area'] * ( + p_a['center_time'] - s2_a['center_time'])**exponent + if new_shadow > res['shadow'][p_i]: + res['shadow'][p_i] = new_shadow + res['pre_s2_area'][p_i] = s2_a['area'] + res['shadow_dt'][p_i] = p_a['center_time'] - s2_a['center_time'] + res['pre_s2_x'][p_i] = s2_a['x'] + res['pre_s2_y'][p_i] = s2_a['y'] + @export class VetoPeakTags(IntEnum): diff --git a/straxen/plugins/peaklet_processing.py b/straxen/plugins/peaklet_processing.py index bc35c7a0f..deb8714fe 100644 --- a/straxen/plugins/peaklet_processing.py +++ 
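The shadow quantity computed in PeakShadow above boils down to a simple rule: for every earlier S2 above the area threshold, shadow = pre_s2_area * dt**exponent (default exponent -1.0), and the largest such value is kept. A minimal numpy sketch of that rule on toy arrays (not the real peak dtype):

    import numpy as np

    def max_shadow(peak_center_time, pre_s2_times, pre_s2_areas, exponent=-1.0):
        """Largest area * dt**exponent over all strictly earlier S2s (toy version)."""
        dt = peak_center_time - pre_s2_times
        valid = dt > 0                    # only S2s before the peak can cast a shadow
        if not np.any(valid):
            return 0.0
        return np.max(pre_s2_areas[valid] * dt[valid] ** exponent)

    # a 2e4 PE S2 one millisecond earlier casts a shadow of 2e4 / 1e6 = 0.02 PE/ns
    print(max_shadow(2_000_000, np.array([1_000_000.]), np.array([2e4])))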
b/straxen/plugins/peaklet_processing.py @@ -1,6 +1,7 @@ import numba import numpy as np import strax +from immutabledict import immutabledict from strax.processing.general import _touching_windows import straxen from .pulse_processing import HITFINDER_OPTIONS, HITFINDER_OPTIONS_he, HE_PREAMBLE @@ -13,13 +14,13 @@ @export @strax.takes_config( - strax.Option('peaklet_gap_threshold', default=700, + strax.Option('peaklet_gap_threshold', default=700, infer_type=False, help="No hits for this many ns triggers a new peak"), - strax.Option('peak_left_extension', default=30, + strax.Option('peak_left_extension', default=30, infer_type=False, help="Include this many ns left of hits in peaks"), - strax.Option('peak_right_extension', default=200, + strax.Option('peak_right_extension', default=200, infer_type=False, help="Include this many ns right of hits in peaks"), - strax.Option('peak_min_pmts', default=2, + strax.Option('peak_min_pmts', default=2, infer_type=False, help="Minimum number of contributing PMTs needed to define a peak"), strax.Option('peak_split_gof_threshold', # See https://xe1t-wiki.lngs.infn.it/doku.php?id= @@ -29,40 +30,43 @@ default=( None, # Reserved ((0.5, 1.0), (6.0, 0.4)), - ((2.5, 1.0), (5.625, 0.4))), + ((2.5, 1.0), (5.625, 0.4))), infer_type=False, help='Natural breaks goodness of fit/split threshold to split ' 'a peak. Specify as tuples of (log10(area), threshold).'), - strax.Option('peak_split_filter_wing_width', default=70, + strax.Option('peak_split_filter_wing_width', default=70, infer_type=False, help='Wing width of moving average filter for ' 'low-split natural breaks'), - strax.Option('peak_split_min_area', default=40., + strax.Option('peak_split_min_area', default=40., infer_type=False, help='Minimum area to evaluate natural breaks criterion. ' 'Smaller peaks are not split.'), - strax.Option('peak_split_iterations', default=20, + strax.Option('peak_split_iterations', default=20, infer_type=False, help='Maximum number of recursive peak splits to do.'), - strax.Option('diagnose_sorting', track=False, default=False, + strax.Option('diagnose_sorting', track=False, default=False, infer_type=False, help="Enable runtime checks for sorting and disjointness"), - strax.Option('gain_model', + strax.Option('gain_model', infer_type=False, help='PMT gain model. 
Specify as ' '(str(model_config), str(version), nT-->boolean'), - strax.Option('tight_coincidence_window_left', default=50, + strax.Option('tight_coincidence_window_left', default=50, infer_type=False, help="Time range left of peak center to call " "a hit a tight coincidence (ns)"), - strax.Option('tight_coincidence_window_right', default=50, + strax.Option('tight_coincidence_window_right', default=50, infer_type=False, help="Time range right of peak center to call " "a hit a tight coincidence (ns)"), strax.Option('n_tpc_pmts', type=int, help='Number of TPC PMTs'), - strax.Option('saturation_correction_on', default=True, + strax.Option('saturation_correction_on', default=True, infer_type=False, help='On off switch for saturation correction'), - strax.Option('saturation_reference_length', default=100, + strax.Option('saturation_reference_length', default=100, infer_type=False, help="Maximum number of reference sample used " "to correct saturated samples"), - strax.Option('saturation_min_reference_length', default=20, + strax.Option('saturation_min_reference_length', default=20, infer_type=False, help="Minimum number of reference sample used " "to correct saturated samples"), - strax.Option('peaklet_max_duration', default=int(10e6), + strax.Option('peaklet_max_duration', default=int(10e6), infer_type=False, help="Maximum duration [ns] of a peaklet"), + strax.Option('channel_map', track=False, type=immutabledict, + help="immutabledict mapping subdetector to (min, max) " + "channel number."), *HITFINDER_OPTIONS, ) class Peaklets(strax.Plugin): @@ -92,7 +96,7 @@ class Peaklets(strax.Plugin): parallel = 'process' compressor = 'zstd' - __version__ = '0.4.1' + __version__ = '0.6.0' def infer_dtype(self): return dict(peaklets=strax.peak_dtype( @@ -121,6 +125,8 @@ def setup(self): self.config['hit_min_amplitude']) else: # int or array self.hit_thresholds = self.config['hit_min_amplitude'] + + self.channel_range = self.config['channel_map']['tpc'] def compute(self, records, start, end): r = records @@ -234,11 +240,15 @@ def compute(self, records, start, end): peaklet_max_times = ( peaklets['time'] + np.argmax(peaklets['data'], axis=1) * peaklets['dt']) - peaklets['tight_coincidence'] = get_tight_coin( + tight_coincidence_channel = get_tight_coin( hit_max_times, + hitlets['channel'], peaklet_max_times, self.config['tight_coincidence_window_left'], - self.config['tight_coincidence_window_right']) + self.config['tight_coincidence_window_right'], + self.channel_range) + + peaklets['tight_coincidence'] = tight_coincidence_channel if self.config['diagnose_sorting'] and len(r): assert np.diff(r['time']).min(initial=1) >= 0, "Records not sorted" @@ -473,18 +483,18 @@ def _peak_saturation_correction_inner(channel_saturated, records, p, @export @strax.takes_config( - strax.Option('n_he_pmts', track=False, default=752, + strax.Option('n_he_pmts', track=False, default=752, infer_type=False, help="Maximum channel of the he channels"), - strax.Option('he_channel_offset', track=False, default=500, + strax.Option('he_channel_offset', track=False, default=500, infer_type=False, help="Minimum channel number of the he channels"), - strax.Option('le_to_he_amplification', default=20, track=True, + strax.Option('le_to_he_amplification', default=20, track=True, infer_type=False, help="Difference in amplification between low energy and high " "energy channels"), - strax.Option('peak_min_pmts_he', default=2, + strax.Option('peak_min_pmts_he', default=2, infer_type=False, child_option=True, parent_option_name='peak_min_pmts', 
track=True, help="Minimum number of contributing PMTs needed to define a peak"), - strax.Option('saturation_correction_on_he', default=False, + strax.Option('saturation_correction_on_he', default=False, infer_type=False, child_option=True, parent_option_name='saturation_correction_on', track=True, help='On off switch for saturation correction for High Energy' @@ -522,6 +532,8 @@ def setup(self): self.config['hit_min_amplitude_he']) else: # int or array self.hit_thresholds = self.config['hit_min_amplitude_he'] + + self.channel_range = self.config['channel_map']['he'] def compute(self, records_he, start, end): result = super().compute(records_he, start, end) @@ -530,13 +542,21 @@ def compute(self, records_he, start, end): @export @strax.takes_config( - strax.Option('s1_max_rise_time', default=110, - help="Maximum S1 rise time for < 100 PE [ns]"), - strax.Option('s1_max_rise_time_post100', default=200, + strax.Option('s1_risetime_area_parameters', default=(50, 80, 12), type=(list, tuple), + help="norm, const, tau in the empirical boundary in the risetime-area plot"), + strax.Option('s1_risetime_aft_parameters', default=(-1, 2.6), type=(list, tuple), + help=("Slope and offset in exponential of emperical boundary in the rise time-AFT " + "plot. Specified as (slope, offset)")), + strax.Option('s1_flatten_threshold_aft', default=(0.6, 100), type=(tuple, list), + help=("Threshold for AFT, above which we use a flatted boundary for rise time" + "Specified values: (AFT boundary, constant rise time).")), + strax.Option('n_top_pmts', default=straxen.n_top_pmts, type=int, + help="Number of top PMTs"), + strax.Option('s1_max_rise_time_post100', default=200, type=(int, float), help="Maximum S1 rise time for > 100 PE [ns]"), - strax.Option('s1_min_coincidence', default=2, + strax.Option('s1_min_coincidence', default=2, type=int, help="Minimum tight coincidence necessary to make an S1"), - strax.Option('s2_min_pmts', default=4, + strax.Option('s2_min_pmts', default=4, type=int, help="Minimum number of PMTs contributing to an S2")) class PeakletClassification(strax.Plugin): """Classify peaklets as unknown, S1, or S2.""" @@ -545,27 +565,60 @@ class PeakletClassification(strax.Plugin): parallel = True dtype = (strax.peak_interval_dtype + [('type', np.int8, 'Classification of the peak(let)')]) - __version__ = '0.2.1' - def compute(self, peaklets): - peaks = peaklets + __version__ = '3.0.3' + + @staticmethod + def upper_rise_time_area_boundary(area, norm, const, tau): + """ + Function which determines the upper boundary for the rise-time + for a given area. + """ + return norm*np.exp(-area/tau) + const + + @staticmethod + def upper_rise_time_aft_boundary(aft, slope, offset, aft_boundary, flat_threshold): + """ + Function which computes the upper rise time boundary as a function + of area fraction top. + """ + res = 10**(slope * aft + offset) + res[aft >= aft_boundary] = flat_threshold + return res + def compute(self, peaklets): ptype = np.zeros(len(peaklets), dtype=np.int8) - # Properties needed for classification. Bit annoying these computations - # are duplicated in peak_basics curently... 
- rise_time = -peaks['area_decile_from_midpoint'][:, 1] - n_channels = (peaks['area_per_channel'] > 0).sum(axis=1) + # Properties needed for classification: + rise_time = -peaklets['area_decile_from_midpoint'][:, 1] + n_channels = (peaklets['area_per_channel'] > 0).sum(axis=1) + n_top = self.config['n_top_pmts'] + area_top = peaklets['area_per_channel'][:, :n_top].sum(axis=1) + area_total = peaklets['area_per_channel'].sum(axis=1) + area_fraction_top = area_top/area_total + + is_large_s1 = (peaklets['area'] >= 100) + is_large_s1 &= (rise_time <= self.config['s1_max_rise_time_post100']) + is_large_s1 &= peaklets['tight_coincidence'] >= self.config['s1_min_coincidence'] + + is_small_s1 = peaklets["area"] < 100 + is_small_s1 &= rise_time < self.upper_rise_time_area_boundary( + peaklets["area"], + *self.config["s1_risetime_area_parameters"], + ) + + is_small_s1 &= rise_time < self.upper_rise_time_aft_boundary( + area_fraction_top, + *self.config["s1_risetime_aft_parameters"], + *self.config["s1_flatten_threshold_aft"], + ) - is_s1 = ( - (rise_time <= self.config['s1_max_rise_time']) - | ((rise_time <= self.config['s1_max_rise_time_post100']) - & (peaks['area'] > 100))) - is_s1 &= peaks['tight_coincidence'] >= self.config['s1_min_coincidence'] - ptype[is_s1] = 1 + is_small_s1 &= peaklets['tight_coincidence'] >= self.config['s1_min_coincidence'] + + ptype[is_large_s1 | is_small_s1] = 1 is_s2 = n_channels >= self.config['s2_min_pmts'] - is_s2[is_s1] = False + is_s2[is_large_s1 | is_small_s1] = False ptype[is_s2] = 2 return dict(type=ptype, @@ -593,20 +646,21 @@ def compute(self, peaklets_he): @export @strax.takes_config( - strax.Option('s2_merge_max_duration', default=50_000, + strax.Option('s2_merge_max_duration', default=50_000, infer_type=False, help="Do not merge peaklets at all if the result would be a peak " "longer than this [ns]"), strax.Option('s2_merge_gap_thresholds', default=((1.7, 2.65e4), (4.0, 2.6e3), (5.0, 0.)), + infer_type=False, help="Points to define maximum separation between peaklets to allow " "merging [ns] depending on log10 area of the merged peak\n" "where the gap size of the first point is the maximum gap to allow merging" "and the area of the last point is the maximum area to allow merging. " "The format is ((log10(area), max_gap), (..., ...), (..., ...))" ), - strax.Option('gain_model', + strax.Option('gain_model', infer_type=False, help='PMT gain model. Specify as ' '(str(model_config), str(version), nT-->boolean'), - strax.Option('merge_without_s1', default=True, + strax.Option('merge_without_s1', default=True, infer_type=False, help="If true, S1s will be igored during the merging. 
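The two boundary functions in the new PeakletClassification can be evaluated directly to see which small peaklets survive as S1 candidates: the area boundary is norm*exp(-area/tau) + const and the AFT boundary is 10**(slope*aft + offset), flattened to a constant rise time above the AFT threshold. A short standalone sketch using the default parameters from the options above (toy peaklet values, not real data; the real plugin additionally requires the tight-coincidence cut):

    import numpy as np

    def rise_time_area_bound(area, norm=50, const=80, tau=12):
        return norm * np.exp(-area / tau) + const

    def rise_time_aft_bound(aft, slope=-1, offset=2.6,
                            aft_boundary=0.6, flat_threshold=100):
        res = 10 ** (slope * aft + offset)
        res[aft >= aft_boundary] = flat_threshold
        return res

    area = np.array([5., 50.])     # PE
    aft = np.array([0.3, 0.7])     # area fraction top
    rise = np.array([90., 120.])   # ns

    small_s1 = area < 100
    small_s1 &= rise < rise_time_area_bound(area)
    small_s1 &= rise < rise_time_aft_bound(aft)
    print(small_s1)  # [ True False]: only peaklets below both empirical boundaries remain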
" "It's now possible for a S1 to be inside a S2 post merging"), ) @@ -618,7 +672,7 @@ class MergedS2s(strax.OverlapWindowPlugin): depends_on = ('peaklets', 'peaklet_classification', 'lone_hits') data_kind = 'merged_s2s' provides = 'merged_s2s' - __version__ = '0.3.1' + __version__ = '0.4.1' def setup(self): self.to_pe = straxen.get_correction_from_cmt(self.run_id, @@ -660,8 +714,8 @@ def compute(self, peaklets, lone_hits): merged_s2s = strax.merge_peaks( peaklets, start_merge_at, end_merge_at, - max_buffer=int(self.config['s2_merge_max_duration'] - // peaklets['dt'].min())) + max_buffer=int(self.config['s2_merge_max_duration']//np.gcd.reduce(peaklets['dt'])), + ) merged_s2s['type'] = 2 # Updated time and length of lone_hits and sort again: @@ -789,9 +843,9 @@ def compute(self, peaklets_he): @export @strax.takes_config( - strax.Option('diagnose_sorting', track=False, default=False, + strax.Option('diagnose_sorting', track=False, default=False, infer_type=False, help="Enable runtime checks for sorting and disjointness"), - strax.Option('merge_without_s1', default=True, + strax.Option('merge_without_s1', default=True, infer_type=False, help="If true, S1s will be igored during the merging. " "It's now possible for a S1 to be inside a S2 post merging"), ) @@ -805,7 +859,7 @@ class Peaks(strax.Plugin): data_kind = 'peaks' provides = 'peaks' parallel = True - save_when = strax.SaveWhen.NEVER + save_when = strax.SaveWhen.EXPLICIT __version__ = '0.1.2' @@ -853,33 +907,53 @@ def compute(self, peaklets_he, merged_s2s_he): @numba.jit(nopython=True, nogil=True, cache=True) -def get_tight_coin(hit_max_times, peak_max_times, left, right): - """Calculates the tight coincidence +def get_tight_coin(hit_max_times, hit_channel, peak_max_times, left, right, + channels=(0, 493)): + """Calculates the tight coincidence based on PMT channels. Defined by number of hits within a specified time range of the the peak's maximum amplitude. Imitates tight_coincidence variable in pax: github.com/XENON1T/pax/blob/master/pax/plugins/peak_processing/BasicProperties.py + + :param hit_max_times: Time of the hit amplitude in ns. + :param hit_channel: PMT channels of the hits + :param peak_max_times: Time of the peaks maximum in ns. + :param left: Left boundary in which we search for the tight + coincidence in ns. + :param right: Right boundary in which we search for the tight + coincidence in ns. + :param channel_range: (min/max) channel for the corresponding detector. + + :returns: n_coin_channel of length peaks containing the + tight coincidence. 
""" left_hit_i = 0 - n_coin = np.zeros(len(peak_max_times), dtype=np.int16) + n_coin_channel = np.zeros(len(peak_max_times), dtype=np.int16) + start_ch, end_ch = channels + channels_seen = np.zeros(end_ch-start_ch+1, dtype=np.bool_) # loop over peaks for p_i, p_t in enumerate(peak_max_times): - + channels_seen[:] = 0 # loop over hits starting from the last one we left at for left_hit_i in range(left_hit_i, len(hit_max_times)): # if the hit is in the window, its a tight coin d = hit_max_times[left_hit_i] - p_t if (-left <= d) & (d <= right): - n_coin[p_i] += 1 + channels_seen[hit_channel[left_hit_i]-start_ch] = 1 # stop the loop when we know we're outside the range if d > right: + n_coin_channel[p_i] = np.sum(channels_seen) break + + # Add channel information in case there are no hits beyond + # the last peak: + n_coin_channel[p_i] = np.sum(channels_seen) - return n_coin + return n_coin_channel @numba.njit(cache=True, nogil=True) diff --git a/straxen/plugins/position_reconstruction.py b/straxen/plugins/position_reconstruction.py index 19cb51de3..50324c1b5 100644 --- a/straxen/plugins/position_reconstruction.py +++ b/straxen/plugins/position_reconstruction.py @@ -1,205 +1,286 @@ -"""Position reconstruction for Xenon-nT""" - -import os -import tempfile -import tarfile -import numpy as np -import strax -import straxen -from warnings import warn -export, __all__ = strax.exporter() - -DEFAULT_POSREC_ALGO_OPTION = tuple([strax.Option("default_reconstruction_algorithm", - help="default reconstruction algorithm that provides (x,y)", - default="mlp", - )]) - -@export -@strax.takes_config( - strax.Option('min_reconstruction_area', - help='Skip reconstruction if area (PE) is less than this', - default=10), - strax.Option('n_top_pmts', default=straxen.n_top_pmts, - help="Number of top PMTs") -) - -class PeakPositionsBaseNT(strax.Plugin): - """ - Base class for reconstructions. - This class should only be used when subclassed for the different - algorithms. Provides x_algorithm, y_algorithm for all peaks > than - min-reconstruction area based on the top array. - """ - depends_on = ('peaks',) - algorithm = None - compressor = 'zstd' - # Using parallel = 'process' is not allowed as we cannot Pickle - # self.model during multiprocessing (to fix?) - parallel = True - __version__ = '0.0.0' - - def infer_dtype(self): - if self.algorithm is None: - raise NotImplementedError(f'Base class should not be used without ' - f'algorithm as done in {__class__.__name__}') - dtype = [('x_' + self.algorithm, np.float32, - f'Reconstructed {self.algorithm} S2 X position (cm), uncorrected'), - ('y_' + self.algorithm, np.float32, - f'Reconstructed {self.algorithm} S2 Y position (cm), uncorrected')] - dtype += strax.time_fields - return dtype - - def setup(self): - self.model_file = self._get_model_file_name() - if self.model_file is None: - warn(f'No file provided for {self.algorithm}. Setting all values ' - f'for {self.provides} to None.') - # No further setup required - return - - # Load the tensorflow model - import tensorflow as tf - if os.path.exists(self.model_file): - print(f"Path is local. 
Loading {self.algorithm} TF model locally " - f"from disk.") - else: - downloader = straxen.MongoDownloader() - try: - self.model_file = downloader.download_single(self.model_file) - except straxen.mongo_storage.CouldNotLoadError as e: - raise RuntimeError(f'Model files {self.model_file} is not found') from e - with tempfile.TemporaryDirectory() as tmpdirname: - tar = tarfile.open(self.model_file, mode="r:gz") - tar.extractall(path=tmpdirname) - self.model = tf.keras.models.load_model(tmpdirname) - - def compute(self, peaks): - result = np.ones(len(peaks), dtype=self.dtype) - result['time'], result['endtime'] = peaks['time'], strax.endtime(peaks) - - result['x_' + self.algorithm] *= float('nan') - result['y_' + self.algorithm] *= float('nan') - - if self.model_file is None: - # This plugin is disabled since no model is provided - return result - - # Keep large peaks only - peak_mask = peaks['area'] > self.config['min_reconstruction_area'] - if not np.sum(peak_mask): - # Nothing to do, and .predict crashes on empty arrays - return result - - # Getting actual position reconstruction - _in = peaks['area_per_channel'][peak_mask, 0:self.config['n_top_pmts']] - with np.errstate(divide='ignore', invalid='ignore'): - _in = _in / np.max(_in, axis=1).reshape(-1, 1) - _in = _in.reshape(-1, self.config['n_top_pmts']) - _out = self.model.predict(_in) - - # writing output to the result - result['x_' + self.algorithm][peak_mask] = _out[:, 0] - result['y_' + self.algorithm][peak_mask] = _out[:, 1] - return result - - def _get_model_file_name(self): - - config_file = f'{self.algorithm}_model' - model_from_config = self.config.get(config_file, 'No file') - if model_from_config == 'No file': - raise ValueError(f'{__class__.__name__} should have {config_file} ' - f'provided as an option.') - if isinstance(model_from_config, str) and os.path.exists(model_from_config): - # Allow direct path specification - return model_from_config - if model_from_config is None: - # Allow None to be specified (disables processing for given posrec) - return model_from_config - - # Use CMT - model_file = straxen.get_correction_from_cmt(self.run_id, model_from_config) - return model_file - -@export -@strax.takes_config( - strax.Option('mlp_model', - help='Neural network model.' - 'If CMT, specify as (mlp_model, ONLINE, True)' - 'Set to None to skip the computation of this plugin.', - default=('mlp_model', "ONLINE", True) - ) -) -class PeakPositionsMLP(PeakPositionsBaseNT): - """Multilayer Perceptron (MLP) neural net for position reconstruction""" - provides = "peak_positions_mlp" - algorithm = "mlp" - - -@export -@strax.takes_config( - strax.Option('gcn_model', - help='Neural network model.' - 'If CMT, specify as (gcn_model, ONLINE, True)' - 'Set to None to skip the computation of this plugin.', - default=('gcn_model', "ONLINE", True) - ) -) -class PeakPositionsGCN(PeakPositionsBaseNT): - """Graph Convolutional Network (GCN) neural net for position reconstruction""" - provides = "peak_positions_gcn" - algorithm = "gcn" - __version__ = '0.0.1' - - -@export -@strax.takes_config( - strax.Option('cnn_model', - help='Neural network model.' 
- 'If CMT, specify as (cnn_model, ONLINE, True)' - 'Set to None to skip the computation of this plugin.', - default=('cnn_model', "ONLINE", True) - ) -) -class PeakPositionsCNN(PeakPositionsBaseNT): - """Convolutional Neural Network (CNN) neural net for position reconstruction""" - provides = "peak_positions_cnn" - algorithm = "cnn" - __version__ = '0.0.1' - - -@export -@strax.takes_config( - *DEFAULT_POSREC_ALGO_OPTION -) -class PeakPositionsNT(strax.MergeOnlyPlugin): - """ - Merge the reconstructed algorithms of the different algorithms - into a single one that can be used in Event Basics. - - Select one of the plugins to provide the 'x' and 'y' to be used - further down the chain. Since we already have the information - needed here, there is no need to wait until events to make the - decision. - - Since the computation is trivial as it only combined the three - input plugins, don't save this plugins output. - """ - provides = "peak_positions" - depends_on = ("peak_positions_cnn", "peak_positions_mlp", "peak_positions_gcn") - save_when = strax.SaveWhen.NEVER - __version__ = '0.0.0' - - def infer_dtype(self): - dtype = strax.merged_dtype([self.deps[d].dtype_for(d) for d in self.depends_on]) - dtype += [('x', np.float32, 'Reconstructed S2 X position (cm), uncorrected'), - ('y', np.float32, 'Reconstructed S2 Y position (cm), uncorrected')] - return dtype - - def compute(self, peaks): - result = {dtype: peaks[dtype] for dtype in peaks.dtype.names} - algorithm = self.config['default_reconstruction_algorithm'] - if not 'x_' + algorithm in peaks.dtype.names: - raise ValueError - for xy in ('x', 'y'): - result[xy] = peaks[f'{xy}_{algorithm}'] - return result +"""Position reconstruction for Xenon-nT""" + +import numpy as np +import strax +import straxen +from warnings import warn + + +export, __all__ = strax.exporter() + + +DEFAULT_POSREC_ALGO = "mlp" + + +@export +@strax.takes_config( + strax.Option('min_reconstruction_area', + help='Skip reconstruction if area (PE) is less than this', + default=10, infer_type=False,), + strax.Option('n_top_pmts', default=straxen.n_top_pmts, infer_type=False, + help="Number of top PMTs") +) +class PeakPositionsBaseNT(strax.Plugin): + """ + Base class for reconstructions. + This class should only be used when subclassed for the different + algorithms. Provides x_algorithm, y_algorithm for all peaks > than + min-reconstruction area based on the top array. + """ + depends_on = ('peaks',) + algorithm = None + compressor = 'zstd' + parallel = True # can set to "process" after #82 + __version__ = '0.0.0' + + def infer_dtype(self): + if self.algorithm is None: + raise NotImplementedError(f'Base class should not be used without ' + f'algorithm as done in {__class__.__name__}') + dtype = [('x_' + self.algorithm, np.float32, + f'Reconstructed {self.algorithm} S2 X position (cm), uncorrected'), + ('y_' + self.algorithm, np.float32, + f'Reconstructed {self.algorithm} S2 Y position (cm), uncorrected')] + dtype += strax.time_fields + return dtype + + def get_tf_model(self): + """ + Simple wrapper to have several tf_model_mlp, tf_model_cnn, .. + point to this same function in the compute method + """ + model = getattr(self, f'tf_model_{self.algorithm}', None) + if model is None: + warn(f'Setting model to None for {self.__class__.__name__} will ' + f'set only nans as output for {self.algorithm}') + if isinstance(model, str): + raise ValueError(f'open files from tf:// protocol! 
Got {model} ' + f'instead, see tests/test_posrec.py for examples.') + return model + + def compute(self, peaks): + result = np.ones(len(peaks), dtype=self.dtype) + result['time'], result['endtime'] = peaks['time'], strax.endtime(peaks) + + result['x_' + self.algorithm] *= float('nan') + result['y_' + self.algorithm] *= float('nan') + model = self.get_tf_model() + + if model is None: + # This plugin is disabled since no model is provided + return result + + # Keep large peaks only + peak_mask = peaks['area'] > self.config['min_reconstruction_area'] + if not np.sum(peak_mask): + # Nothing to do, and .predict crashes on empty arrays + return result + + # Getting actual position reconstruction + area_per_channel_top = peaks['area_per_channel'][ + peak_mask, + 0:self.config['n_top_pmts']] + with np.errstate(divide='ignore', invalid='ignore'): + area_per_channel_top = ( + area_per_channel_top / + np.max(area_per_channel_top, axis=1).reshape(-1, 1) + ) + area_per_channel_top = area_per_channel_top.reshape(-1, + self.config['n_top_pmts'] + ) + output = model.predict(area_per_channel_top) + + # writing output to the result + result['x_' + self.algorithm][peak_mask] = output[:, 0] + result['y_' + self.algorithm][peak_mask] = output[:, 1] + return result + + +@export +class PeakPositionsMLP(PeakPositionsBaseNT): + """Multilayer Perceptron (MLP) neural net for position reconstruction""" + provides = "peak_positions_mlp" + algorithm = "mlp" + + tf_model_mlp = straxen.URLConfig( + default=f'tf://' + f'resource://' + f'cmt://{algorithm}_model' + f'?version=ONLINE' + f'&run_id=plugin.run_id' + f'&fmt=abs_path', + help='MLP model. Should be opened using the "tf" descriptor. ' + 'Set to "None" to skip computation', + cache=3, + ) + + +@export +class PeakPositionsGCN(PeakPositionsBaseNT): + """Graph Convolutional Network (GCN) neural net for position reconstruction""" + provides = "peak_positions_gcn" + algorithm = "gcn" + __version__ = '0.0.1' + + tf_model_gcn = straxen.URLConfig( + default=f'tf://' + f'resource://' + f'cmt://{algorithm}_model' + f'?version=ONLINE' + f'&run_id=plugin.run_id' + f'&fmt=abs_path', + help='GCN model. Should be opened using the "tf" descriptor. ' + 'Set to "None" to skip computation', + cache=3, + ) + + +@export +class PeakPositionsCNN(PeakPositionsBaseNT): + """Convolutional Neural Network (CNN) neural net for position reconstruction""" + provides = "peak_positions_cnn" + algorithm = "cnn" + __version__ = '0.0.1' + + tf_model_cnn = straxen.URLConfig( + default=f'tf://' + f'resource://' + f'cmt://{algorithm}_model' + f'?version=ONLINE' + f'&run_id=plugin.run_id' + f'&fmt=abs_path', + cache=3, + ) + + +@export +class PeakPositionsNT(strax.MergeOnlyPlugin): + """ + Merge the reconstructed algorithms of the different algorithms + into a single one that can be used in Event Basics. + + Select one of the plugins to provide the 'x' and 'y' to be used + further down the chain. Since we already have the information + needed here, there is no need to wait until events to make the + decision. + + Since the computation is trivial as it only combined the three + input plugins, don't save this plugins output. 
+ """ + provides = "peak_positions" + depends_on = ("peak_positions_cnn", "peak_positions_mlp", "peak_positions_gcn") + save_when = strax.SaveWhen.NEVER + __version__ = '0.0.0' + + default_reconstruction_algorithm = straxen.URLConfig( + default=DEFAULT_POSREC_ALGO, + help="default reconstruction algorithm that provides (x,y)" + ) + + def infer_dtype(self): + dtype = strax.merged_dtype([self.deps[d].dtype_for(d) for d in self.depends_on]) + dtype += [('x', np.float32, 'Reconstructed S2 X position (cm), uncorrected'), + ('y', np.float32, 'Reconstructed S2 Y position (cm), uncorrected')] + return dtype + + def compute(self, peaks): + result = {dtype: peaks[dtype] for dtype in peaks.dtype.names} + algorithm = self.config['default_reconstruction_algorithm'] + for xy in ('x', 'y'): + result[xy] = peaks[f'{xy}_{algorithm}'] + return result + + +@export +@strax.takes_config( + strax.Option('recon_alg_included', + help='The list of all reconstruction algorithm considered.', + default=('_mlp', '_gcn', '_cnn'), infer_type=False, + ) +) +class S2ReconPosDiff(strax.Plugin): + """ + Plugin that provides position reconstruction difference for S2s in events, see note: + https://xe1t-wiki.lngs.infn.it/doku.php?id=xenon:shengchao:sr0:reconstruction_quality + """ + + __version__ = '0.0.3' + parallel = True + depends_on = 'event_basics' + provides = 's2_recon_pos_diff' + save_when = strax.SaveWhen.EXPLICIT + + def infer_dtype(self): + dtype = [ + ('s2_recon_avg_x', np.float32, + 'Mean value of x for main S2'), + ('alt_s2_recon_avg_x', np.float32, + 'Mean value of x for alternatice S2'), + ('s2_recon_avg_y', np.float32, + 'Mean value of y for main S2'), + ('alt_s2_recon_avg_y', np.float32, + 'Mean value of y for alternatice S2'), + ('s2_recon_pos_diff', np.float32, + 'Reconstructed position difference for main S2'), + ('alt_s2_recon_pos_diff', np.float32, + 'Reconstructed position difference for alternative S2'), + ] + dtype += strax.time_fields + return dtype + + def compute(self, events): + + result = np.zeros(len(events), dtype = self.dtype) + result['time'] = events['time'] + result['endtime'] = strax.endtime(events) + # Computing position difference + self.compute_pos_diff(events, result) + return result + + def cal_avg_and_std(self, values, axis = 1): + average = np.mean(values, axis = axis) + std = np.std(values, axis = axis) + return average, std + + def eval_recon(self, data, name_x_list, name_y_list): + """ + This function reads the name list based on s2/alt_s2 and all recon algorithm registered + Each row consists the reconstructed x/y and their average and standard deviation is calculated + """ + x_avg, x_std = self.cal_avg_and_std(np.array(data[name_x_list].tolist())) #lazy fix to delete field name in array, otherwise np.mean will complain + y_avg, y_std = self.cal_avg_and_std(np.array(data[name_y_list].tolist())) + r_std = np.sqrt(x_std**2 + y_std**2) + res = x_avg, y_avg, r_std + return res + + def compute_pos_diff(self, events, result): + + alg_list = self.config['recon_alg_included'] + for peak_type in ['s2', 'alt_s2']: + # Selecting S2s for pos diff + # - must exist (index != -1) + # - must have positive AFT + # - must contain all alg info + cur_s2_bool = (events[peak_type + '_index'] !=- 1) + cur_s2_bool &= (events[peak_type + '_area_fraction_top'] > 0) + for name in self.config['recon_alg_included']: + cur_s2_bool &= ~np.isnan(events[peak_type+'_x'+name]) + cur_s2_bool &= ~np.isnan(events[peak_type+'_y'+name]) + + # default value is nan, it will be ovewrite if the event satisfy the 
requirments + result[peak_type + '_recon_pos_diff'][:] = np.nan + result[peak_type + '_recon_avg_x'][:] = np.nan + result[peak_type + '_recon_avg_y'][:] = np.nan + + if np.any(cur_s2_bool): + name_x_list = [] + name_y_list = [] + for alg in alg_list: + name_x_list.append(peak_type + '_x' + alg) + name_y_list.append(peak_type + '_y' + alg) + + # Calculating average x,y, and position difference + x_avg, y_avg, r_std = self.eval_recon(events[cur_s2_bool], name_x_list, name_y_list) + result[peak_type + '_recon_pos_diff'][cur_s2_bool] = r_std + result[peak_type + '_recon_avg_x'][cur_s2_bool] = x_avg + result[peak_type + '_recon_avg_y'][cur_s2_bool] = y_avg diff --git a/straxen/plugins/pulse_processing.py b/straxen/plugins/pulse_processing.py index a0840db5e..fda8cc61e 100644 --- a/straxen/plugins/pulse_processing.py +++ b/straxen/plugins/pulse_processing.py @@ -13,7 +13,7 @@ # These are also needed in peaklets, since hitfinding is repeated HITFINDER_OPTIONS = tuple([ strax.Option( - 'hit_min_amplitude', track=True, + 'hit_min_amplitude', track=True, infer_type=False, default=('hit_thresholds_tpc', 'ONLINE', True), help='Minimum hit amplitude in ADC counts above baseline. ' 'Specify as a tuple of length n_tpc_pmts, or a number,' @@ -26,7 +26,7 @@ HITFINDER_OPTIONS_he = tuple([ strax.Option( 'hit_min_amplitude_he', - default=('hit_thresholds_he', 'ONLINE', True), track=True, + default=('hit_thresholds_he', 'ONLINE', True), track=True, infer_type=False, help='Minimum hit amplitude in ADC counts above baseline. ' 'Specify as a tuple of length n_tpc_pmts, or a number,' 'or a string like "pmt_commissioning_initial" which means calling' @@ -41,50 +41,50 @@ @export @strax.takes_config( strax.Option('hev_gain_model', - default=('disabled', None), + default=('disabled', None), infer_type=False, help='PMT gain model used in the software high-energy veto.' 'Specify as (model_type, model_config)'), strax.Option( 'baseline_samples', - default=40, + default=40, infer_type=False, help='Number of samples to use at the start of the pulse to determine ' 'the baseline'), # Tail veto options strax.Option( 'tail_veto_threshold', - default=0, + default=0, infer_type=False, help=("Minimum peakarea in PE to trigger tail veto." "Set to None, 0 or False to disable veto.")), strax.Option( 'tail_veto_duration', - default=int(3e6), + default=int(3e6), infer_type=False, help="Time in ns to veto after large peaks"), strax.Option( 'tail_veto_resolution', - default=int(1e3), + default=int(1e3), infer_type=False, help="Time resolution in ns for pass-veto waveform summation"), strax.Option( 'tail_veto_pass_fraction', - default=0.05, + default=0.05, infer_type=False, help="Pass veto if maximum amplitude above max * fraction"), strax.Option( 'tail_veto_pass_extend', - default=3, + default=3, infer_type=False, help="Extend pass veto by this many samples (tail_veto_resolution!)"), strax.Option( 'max_veto_value', - default=None, + default=None, infer_type=False, help="Optionally pass a HE peak that exceeds this absolute area. 
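Back in S2ReconPosDiff above, the per-event aggregation is simply the mean and spread of (x, y) over the registered algorithms, with the position difference taken as r_std = sqrt(x_std**2 + y_std**2). A toy sketch of that reduction (made-up positions for a single S2):

    import numpy as np

    # x, y reconstructed by e.g. mlp, gcn, cnn for one S2 (toy numbers)
    x = np.array([[1.0, 1.2, 0.8]])
    y = np.array([[-2.0, -2.1, -1.9]])

    x_avg, x_std = x.mean(axis=1), x.std(axis=1)
    y_avg, y_std = y.mean(axis=1), y.std(axis=1)
    r_std = np.sqrt(x_std**2 + y_std**2)
    print(x_avg, y_avg, r_std)   # ~[1.] [-2.] [0.18]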
" "(if performing a hard veto, can keep a few statistics.)"), # PMT pulse processing options strax.Option( 'pmt_pulse_filter', - default=None, + default=None, infer_type=False, help='Linear filter to apply to pulses, will be normalized.'), strax.Option( 'save_outside_hits', - default=(3, 20), + default=(3, 20), infer_type=False, help='Save (left, right) samples besides hits; cut the rest'), strax.Option( @@ -93,12 +93,12 @@ strax.Option( 'check_raw_record_overlaps', - default=True, track=False, + default=True, track=False, infer_type=False, help='Crash if any of the pulses in raw_records overlap with others ' 'in the same channel'), strax.Option( 'allow_sloppy_chunking', - default=False, track=False, + default=False, track=False, infer_type=False, help=('Use a default baseline for incorrectly chunked fragments. ' 'This is a kludge for improperly converted XENON1T data.')), @@ -185,7 +185,6 @@ def compute(self, raw_records, start, end): # Do not trust in DAQ + strax.baseline to leave the # out-of-bounds samples to zero. - # TODO: better to throw an error if something is nonzero strax.zero_out_of_bounds(r) strax.baseline(r, @@ -242,7 +241,7 @@ def compute(self, raw_records, start, end): @export @strax.takes_config( - strax.Option('n_he_pmts', track=False, default=752, + strax.Option('n_he_pmts', track=False, default=752, infer_type=False, help="Maximum channel of the he channels"), strax.Option('record_length', default=110, track=False, type=int, help="Number of samples per raw_record"), diff --git a/straxen/plugins/veto_events.py b/straxen/plugins/veto_events.py index 6783d2581..5f6d83168 100644 --- a/straxen/plugins/veto_events.py +++ b/straxen/plugins/veto_events.py @@ -12,11 +12,11 @@ @strax.takes_config( - strax.Option('event_left_extension_nv', default=0, + strax.Option('event_left_extension_nv', default=0, infer_type=False, help="Extends events this many ns to the left"), - strax.Option('event_resolving_time_nv', default=300, + strax.Option('event_resolving_time_nv', default=300, infer_type=False, help="Resolving time for fixed window coincidence [ns]."), - strax.Option('event_min_hits_nv', default=3, + strax.Option('event_min_hits_nv', default=3, infer_type=False, help="Minimum number of fully confined hitlets to define an event."), strax.Option('channel_map', track=False, type=immutabledict, help="immutabledict mapping subdetector to (min, max) " @@ -29,12 +29,8 @@ class nVETOEvents(strax.OverlapWindowPlugin): depends_on = 'hitlets_nv' provides = 'events_nv' data_kind = 'events_nv' - compressor = 'zstd' - # Needed in case we make again an muVETO child. - ends_with = '_nv' - __version__ = '0.0.2' events_seen = 0 @@ -224,12 +220,12 @@ def _make_event(hitlets: np.ndarray, @strax.takes_config( - strax.Option('position_max_time_nv', default=20, + strax.Option('position_max_time_nv', default=20, infer_type=False, help="Time [ns] within an event use to compute the azimuthal angle of the " "event."), strax.Option('nveto_pmt_position_map', help="nVeto PMT position mapfile", - default='nveto_pmt_position.csv'), + default='nveto_pmt_position.csv', infer_type=False,), ) class nVETOEventPositions(strax.Plugin): """ @@ -239,13 +235,8 @@ class nVETOEventPositions(strax.Plugin): depends_on = ('events_nv', 'hitlets_nv') data_kind = 'events_nv' provides = 'event_positions_nv' - - loop_over = 'events_nv' compressor = 'zstd' - # Needed in case we make again an muVETO child. 
- ends_with = '_nv' - __version__ = '0.1.0' def infer_dtype(self): @@ -281,7 +272,7 @@ def compute(self, events_nv, hitlets_nv): self.pmt_properties) event_angles['angle'] = angle compute_positions(event_angles, events_nv, hits_in_events, self.pmt_properties) - strax.copy_to_buffer(events_nv, event_angles, f'_copy_events{self.ends_with}') + strax.copy_to_buffer(events_nv, event_angles, f'_copy_events_nv') return event_angles @@ -436,13 +427,13 @@ def first_hitlets(hitlets_per_event: np.ndarray, @strax.takes_config( - strax.Option('event_left_extension_mv', default=0, + strax.Option('event_left_extension_mv', default=0, infer_type=False, child_option=True, parent_option_name='event_left_extension_nv', help="Extends events this many ns to the left"), - strax.Option('event_resolving_time_mv', default=300, + strax.Option('event_resolving_time_mv', default=300, infer_type=False, child_option=True, parent_option_name='event_resolving_time_nv', help="Resolving time for fixed window coincidence [ns]."), - strax.Option('event_min_hits_mv', default=3, + strax.Option('event_min_hits_mv', default=3, infer_type=False, child_option=True, parent_option_name='event_min_hits_nv', help="Minimum number of fully confined hitlets to define an event."), ) @@ -454,9 +445,6 @@ class muVETOEvents(nVETOEvents): data_kind = 'events_mv' compressor = 'zstd' - - # Needed in case we make again an muVETO child. - ends_with = '_mv' child_plugin = True __version__ = '0.0.1' @@ -473,3 +461,99 @@ def get_window_size(self): def compute(self, hitlets_mv, start, end): return super().compute(hitlets_mv, start, end) + + +@strax.takes_config( + strax.Option('hardware_delay_nv', default=0, type=int, + help="Hardware delay to be added to the set electronics offset."), +) +class nVETOEventsSync(strax.Plugin): + """ + Plugin which computes time stamps which are synchronized with the + TPC. Uses delay set in the DAQ. + """ + depends_on = 'events_nv' + provides = 'events_sync_nv' + save_when = strax.SaveWhen.EXPLICIT + __version__ = '0.0.1' + + def infer_dtype(self): + dtype = [] + dtype += strax.time_fields + dtype += [(('Time of the event synchronized according to the total digitizer delay.', + 'time_sync'), np.int64), + (('Endtime of the event synchronized according to the total digitizer delay.', + 'endtime_sync'), np.int64), + ] + return dtype + + def setup(self): + self.total_delay = get_delay(self.run_id) + self.total_delay += self.config['hardware_delay_nv'] + + def compute(self, events_nv): + events_sync_nv = np.zeros(len(events_nv), self.dtype) + events_sync_nv['time'] = events_nv['time'] + events_sync_nv['endtime'] = events_nv['endtime'] + events_sync_nv['time_sync'] = events_nv['time'] + self.total_delay + events_sync_nv['endtime_sync'] = events_nv['endtime'] + self.total_delay + return events_sync_nv + + +def get_delay(run_id): + """ + Function which returns the total delay between TPC and veto for a + given run_id. Returns nan if + """ + try: + import utilix + except ModuleNotFoundError: + return np.nan + + delay = np.nan + if straxen.utilix_is_configured(): + run_db = utilix.DB() + run_meta = run_db.get_doc(run_id) + delay = _get_delay(run_meta) + + return delay + + +def _get_delay(run_meta): + """ + Loops over registry entries for correct entries and computes delay. 
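    A worked example with hypothetical register values (the numbers below
    are made up purely to show the unit conversion):

    >>> _get_delay({'daq_config': {'registers': [
    ...     {'board': 'tpc', 'reg': '8034', 'val': '1f'},
    ...     {'board': 'neutron_veto', 'reg': '8170', 'val': '10'}]}})
    364

    i.e. 2 * 0x1f * 10 ns = 620 ns for the TPC minus 16 * 0x10 ns = 256 ns
    for the nVETO.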
+ """ + delay_nveto = 0 + delay_tpc = 0 + for item in run_meta['daq_config']['registers']: + if (item['reg'] == '8034') and (item['board'] == 'tpc'): + delay_tpc = item['val'] + delay_tpc = int('0x'+delay_tpc, 16) + delay_tpc = 2*delay_tpc*10 + if (item['reg'] == '8170') and (item['board'] == 'neutron_veto'): + delay_nveto = item['val'] + delay_nveto = int('0x'+delay_nveto, 16) + delay_nveto = 16*delay_nveto # Delay is specified as multiple of 16 ns + delay = delay_tpc - delay_nveto + return delay + + +@strax.takes_config( + strax.Option('hardware_delay_mv', default=0, type=int, + help="Hardware delay to be added to the set electronics offset."), +) +class mVETOEventSync(nVETOEventsSync): + """ + Plugin which computes synchronized timestamps for the muon-veto with + respect to the TPC. + """ + depends_on = 'events_mv' + provides = 'events_sync_mv' + __version__ = '0.0.1' + child_plugin = True + + def setup(self): + self.total_delay = self.config['hardware_delay_mv'] + + def compute(self, events_mv): + return super().compute(events_mv) diff --git a/straxen/plugins/veto_hitlets.py b/straxen/plugins/veto_hitlets.py index ff09ba3e9..669651d49 100644 --- a/straxen/plugins/veto_hitlets.py +++ b/straxen/plugins/veto_hitlets.py @@ -1,4 +1,3 @@ -import numba import numpy as np from immutabledict import immutabledict @@ -6,52 +5,39 @@ import straxen from straxen.get_corrections import is_cmt_option +from straxen.plugins.veto_pulse_processing import MV_PREAMBLE, NV_HIT_OPTIONS export, __all__ = strax.exporter() -MV_PREAMBLE = 'Muno-Veto Plugin: Same as the corresponding nVETO-PLugin.\n' - @export @strax.takes_config( - strax.Option( - 'save_outside_hits_nv', - default=(3, 15), track=True, - help='Save (left, right) samples besides hits; cut the rest'), - strax.Option( - 'hit_min_amplitude_nv', - default=('hit_thresholds_nv', 'ONLINE', True), track=True, - help='Minimum hit amplitude in ADC counts above baseline. ' - 'Specify as a tuple of length n_nveto_pmts, or a number, ' - 'or a string like "pmt_commissioning_initial" which means calling ' - 'hitfinder_thresholds.py, ' - 'or a tuple like (correction=str, version=str, nT=boolean), ' - 'which means we are using cmt.'), + *NV_HIT_OPTIONS, strax.Option( 'min_split_nv', - default=0.063, track=True, + default=0.063, track=True, infer_type=False, help='Minimum height difference pe/sample between local minimum and maximum, ' 'that a pulse get split.'), strax.Option( 'min_split_ratio_nv', - default=0.75, track=True, + default=0.75, track=True, infer_type=False, help='Min ratio between local maximum and minimum to split pulse (zero to switch this ' 'off).'), strax.Option( 'entropy_template_nv', - default='flat', track=True, + default='flat', track=True, infer_type=False, help='Template data is compared with in conditional entropy. Can be either "flat" or an ' 'template array.'), strax.Option( 'entropy_square_data_nv', - default=False, track=True, + default=False, track=True, infer_type=False, help='Parameter which decides if data is first squared before normalized and compared to ' 'the template.'), strax.Option('channel_map', track=False, type=immutabledict, help="immutabledict mapping subdetector to (min, max) " "channel number."), strax.Option('gain_model_nv', - default=("to_pe_model_nv", "ONLINE", True), + default=("to_pe_model_nv", "ONLINE", True), infer_type=False, help='PMT gain model. 
Specify as (model_type, model_config, nT = True)'), ) class nVETOHitlets(strax.Plugin): @@ -71,7 +57,7 @@ class nVETOHitlets(strax.Plugin): Note: Hitlets are getting chopped if extended in not recorded regions. """ - __version__ = '0.1.0' + __version__ = '0.1.1' parallel = 'process' rechunk_on_save = True @@ -81,7 +67,6 @@ class nVETOHitlets(strax.Plugin): provides = 'hitlets_nv' data_kind = 'hitlets_nv' - ends_with = '_nv' dtype = strax.hitlet_dtype() @@ -164,11 +149,11 @@ def remove_switched_off_channels(hits, to_pe): @strax.takes_config( strax.Option( 'save_outside_hits_mv', - default=(2, 5), track=True, + default=(2, 5), track=True, infer_type=False, child_option=True, parent_option_name='save_outside_hits_nv', help='Save (left, right) samples besides hits; cut the rest'), strax.Option( - 'hit_min_amplitude_mv', + 'hit_min_amplitude_mv', infer_type=False, default=('hit_thresholds_mv', 'ONLINE', True), track=True, help='Minimum hit amplitude in ADC counts above baseline. ' 'Specify as a tuple of length n_mveto_pmts, or a number, ' @@ -178,30 +163,30 @@ def remove_switched_off_channels(hits, to_pe): 'which means we are using cmt.'), strax.Option( 'min_split_mv', - default=100, track=True, + default=100, track=True, infer_type=False, child_option=True, parent_option_name='min_split_nv', help='Minimum height difference pe/sample between local minimum and maximum, ' 'that a pulse get split.'), strax.Option( 'min_split_ratio_mv', - default=0, track=True, + default=0, track=True, infer_type=False, child_option=True, parent_option_name='min_split_ratio_nv', help='Min ratio between local maximum and minimum to split pulse (zero to switch this ' 'off).'), strax.Option( 'entropy_template_mv', - default='flat', track=True, + default='flat', track=True, infer_type=False, child_option=True, parent_option_name='entropy_template_nv', help='Template data is compared with in conditional entropy. Can be either "flat" or a ' 'template array.'), strax.Option( 'entropy_square_data_mv', - default=False, track=True, + default=False, track=True, infer_type=False, child_option=True, parent_option_name='entropy_square_data_nv', help='Parameter which decides if data is first squared before normalized and compared to ' 'the template.'), strax.Option('gain_model_mv', - default=("to_pe_model_mv", "ONLINE", True), + default=("to_pe_model_mv", "ONLINE", True), infer_type=False, child_option=True, parent_option_name='gain_model_nv', help='PMT gain model. Specify as (model_type, model_config)'), ) diff --git a/straxen/plugins/veto_pulse_processing.py b/straxen/plugins/veto_pulse_processing.py index fb54420a1..8d0c37855 100644 --- a/straxen/plugins/veto_pulse_processing.py +++ b/straxen/plugins/veto_pulse_processing.py @@ -8,21 +8,13 @@ export, __all__ = strax.exporter() MV_PREAMBLE = 'Muno-Veto Plugin: Same as the corresponding nVETO-PLugin.\n' - - -@export -@strax.takes_config( +NV_HIT_OPTIONS = ( strax.Option( 'save_outside_hits_nv', - default=(3, 15), track=True, + default=(3, 15), track=True, infer_type=False, help='Save (left, right) samples besides hits; cut the rest'), strax.Option( - 'baseline_samples_nv', - default=('baseline_samples_nv', 'ONLINE', True), track=True, - help='Number of samples to use at the start of the pulse to determine ' - 'the baseline'), - strax.Option( - 'hit_min_amplitude_nv', + 'hit_min_amplitude_nv', infer_type=False, default=('hit_thresholds_nv', 'ONLINE', True), track=True, help='Minimum hit amplitude in ADC counts above baseline. 
' 'Specify as a tuple of length n_nveto_pmts, or a number, ' @@ -30,9 +22,20 @@ 'hitfinder_thresholds.py, ' 'or a tuple like (correction=str, version=str, nT=boolean), ' 'which means we are using cmt.'), +) + + +@export +@strax.takes_config( + *NV_HIT_OPTIONS, + strax.Option( + 'baseline_samples_nv', infer_type=False, + default=('baseline_samples_nv', 'ONLINE', True), track=True, + help='Number of samples to use at the start of the pulse to determine ' + 'the baseline'), strax.Option( 'min_samples_alt_baseline_nv', - default=None, track=True, + default=None, track=True, infer_type=False, help='Min. length of pulse before alternative baselineing via ' 'pulse median is applied.'), ) @@ -55,7 +58,6 @@ class nVETOPulseProcessing(strax.Plugin): depends_on = 'raw_records_coin_nv' provides = 'records_nv' data_kind = 'records_nv' - ends_with = '_nv' def setup(self): if isinstance(self.config['baseline_samples_nv'], int): @@ -172,17 +174,17 @@ def _correct_baseline(records): @strax.takes_config( strax.Option( 'save_outside_hits_mv', - default=(2, 5), track=True, + default=(2, 5), track=True, infer_type=False, child_option=True, parent_option_name='save_outside_hits_nv', help='Save (left, right) samples besides hits; cut the rest'), strax.Option( 'baseline_samples_mv', - default=100, track=True, + default=100, track=True, infer_type=False, child_option=True, parent_option_name='baseline_samples_nv', help='Number of samples to use at the start of the pulse to determine ' 'the baseline'), strax.Option( - 'hit_min_amplitude_mv', + 'hit_min_amplitude_mv', infer_type=False, default=('hit_thresholds_mv', 'ONLINE', True), track=True, help='Minimum hit amplitude in ADC counts above baseline. ' 'Specify as a tuple of length n_mveto_pmts, or a number, ' @@ -192,7 +194,7 @@ def _correct_baseline(records): 'which means we are using cmt.'), strax.Option( 'check_raw_record_overlaps', - default=True, track=False, + default=True, track=False, infer_type=False, help='Crash if any of the pulses in raw_records overlap with others ' 'in the same channel'), ) diff --git a/straxen/plugins/veto_veto_regions.py b/straxen/plugins/veto_veto_regions.py index bafbc1ef8..f6a3e171e 100644 --- a/straxen/plugins/veto_veto_regions.py +++ b/straxen/plugins/veto_veto_regions.py @@ -31,15 +31,14 @@ class nVETOVetoRegions(strax.OverlapWindowPlugin): tagged as vetoed. An event must surpass all three criteria to trigger a veto. 
""" - __version__ = '0.0.1' + __version__ = '0.0.2' - depends_on = 'events_nv' + depends_on = ('events_nv', 'events_sync_nv') provides = 'veto_regions_nv' data_kind = 'veto_regions_nv' save_when = strax.SaveWhen.NEVER dtype = strax.time_fields - ends_with = '_nv' def get_window_size(self): return 10 * (self.config['veto_left_extension_nv'] + self.config['veto_right_extension_nv']) @@ -112,8 +111,8 @@ def _create_veto_intervals(events, if not satisfies_veto_trigger: continue - res[offset]['time'] = ev['time'] - left_extension - res[offset]['endtime'] = ev['endtime'] + right_extension + res[offset]['time'] = ev['time_sync'] - left_extension + res[offset]['endtime'] = ev['endtime_sync'] + right_extension offset += 1 return res[:offset] @@ -144,13 +143,12 @@ class muVETOVetoRegions(nVETOVetoRegions): __doc__ = MV_PREAMBLE + nVETOVetoRegions.__doc__ __version__ = '0.0.1' - depends_on = 'events_mv' + depends_on = ('events_mv', 'events_sync_mv') provides = 'veto_regions_mv' data_kind = 'veto_regions_mv' save_when = strax.SaveWhen.NEVER dtype = strax.time_fields - ends_with = '_mv' child_plugin = True def get_window_size(self): diff --git a/straxen/plugins/x1t_cuts.py b/straxen/plugins/x1t_cuts.py index 1b923957f..1fa0cc4c4 100644 --- a/straxen/plugins/x1t_cuts.py +++ b/straxen/plugins/x1t_cuts.py @@ -236,11 +236,16 @@ class SR1Cuts(strax.MergeOnlyPlugin): class FiducialEvents(strax.Plugin): depends_on = ['event_info', 'cut_fiducial_cylinder_1t'] data_kind = 'fiducial_events' + __version__ = '0.0.1' def infer_dtype(self): dtype = [self.deps[d].dtype_for(d) for d in self.depends_on] - dtype.sort() - return strax.merged_dtype(dtype) + dtype = strax.merged_dtype(dtype) + return dtype def compute(self, events): - return events[events['cut_fiducial_cylinder_1t']] + fiducial_events = events[events['cut_fiducial_cylinder_1t']] + result = np.zeros(len(fiducial_events), dtype=self.dtype) + # Cast the fiducual events dtype into the expected format + strax.copy_to_buffer(fiducial_events, result, '_fiducial_copy') + return result diff --git a/straxen/rucio.py b/straxen/rucio.py index 2ab9b7045..1248567a1 100644 --- a/straxen/rucio.py +++ b/straxen/rucio.py @@ -1,22 +1,27 @@ -import socket -import re -import json -from bson import json_util -import os import glob import hashlib -import time -from utilix import xent_collection +import json +import os +import re +import socket +from warnings import warn +import numpy as np import strax +from bson import json_util +from utilix import xent_collection -export, __all__ = strax.exporter() +try: + import admix + from rucio.common.exception import DataIdentifierNotFound + HAVE_ADMIX = True +except ImportError: + HAVE_ADMIX = False -class TooMuchDataError(Exception): - pass +export, __all__ = strax.exporter() -class DownloadError(Exception): +class TooMuchDataError(Exception): pass @@ -25,14 +30,14 @@ class RucioFrontend(strax.StorageFrontend): """ Uses the rucio client for the data find. """ - local_rses = {'UC_DALI_USERDISK': r'.rcc.'} + local_rses = {'UC_DALI_USERDISK': r'.rcc.', + 'SDSC_USERDISK': r'.sdsc.' 
+ } local_did_cache = None - local_rucio_path = None - - # Some attributes to set if we have the remote backend - _did_client = None - _id_not_found_error = None - _rse_client = None + path = None + local_prefixes = {'UC_DALI_USERDISK': '/dali/lgrandi/rucio/', + 'SDSC_USERDISK': '/expanse/lustre/projects/chi135/shockley/rucio', + } def __init__(self, include_remote=False, @@ -59,12 +64,6 @@ def __init__(self, f"I'm not sure what to do with that.") local_rse = rse - # if there is no local host and we don't want to include the - # remote ones, we can't do anything - if local_rse is None and not include_remote: - raise RuntimeError(f"Could not find a local RSE for hostname {hostname}, " - f"and include_remote is False.") - self.local_rse = local_rse self.include_remote = include_remote @@ -74,11 +73,15 @@ def __init__(self, # rucio backend to read from that path rucio_prefix = self.get_rse_prefix(local_rse) self.backends.append(RucioLocalBackend(rucio_prefix)) - self.local_rucio_path = rucio_prefix + self.path = rucio_prefix if include_remote: - self._set_remote_imports() - self.backends.append(RucioRemoteBackend(staging_dir, download_heavy=download_heavy)) + if not HAVE_ADMIX: + self.log.warning("You passed use_remote=True to rucio fronted, " + "but you don't have access to admix/rucio! Using local backed only.") + else: + self.backends.append(RucioRemoteBackend(staging_dir, + download_heavy=download_heavy)) def __repr__(self): # List the relevant attributes @@ -89,29 +92,10 @@ def __repr__(self): representation += f', {attr}: {getattr(self, attr)}' return representation - def _set_remote_imports(self): - try: - from rucio.client.rseclient import RSEClient - from rucio.client.didclient import DIDClient - from rucio.common.exception import DataIdentifierNotFound - self._did_client = DIDClient() - self._id_not_found_error = DataIdentifierNotFound - self._rse_client = RSEClient() - except (ModuleNotFoundError, RuntimeError) as e: - raise ImportError('Cannot work with Rucio remote backend') from e - def find_several(self, keys, **kwargs): - if not len(keys): - return [] - - ret = [] - for key in keys: - did = key_to_rucio_did(key) - if self.did_is_local(did): - ret.append(('RucioLocalBackend', did)) - else: - ret.append(False) - return ret + # for performance, dont do find_several with this plugin + # we basically do the same query we would do in the RunDB plugin + return np.zeros_like(keys, dtype=bool).tolist() def _find(self, key: strax.DataKey, write, allow_incomplete, fuzzy_for, fuzzy_for_options): did = key_to_rucio_did(key) @@ -122,13 +106,11 @@ def _find(self, key: strax.DataKey, write, allow_incomplete, fuzzy_for, fuzzy_fo if self.did_is_local(did): return "RucioLocalBackend", did elif self.include_remote: - # only do this part if we include the remote backend try: - # check if the DID exists - scope, name = did.split(':') - self._did_client.get_did(scope, name) - return "RucioRemoteBackend", did - except self._id_not_found_error: + rules = admix.rucio.list_rules(did, state="OK") + if len(rules): + return "RucioRemoteBackend", did + except DataIdentifierNotFound: pass if fuzzy_for or fuzzy_for_options: @@ -140,15 +122,20 @@ def _find(self, key: strax.DataKey, write, allow_incomplete, fuzzy_for, fuzzy_fo raise strax.DataNotAvailable + def find(self, key: strax.DataKey, + write=False, + check_broken=False, + allow_incomplete=False, + fuzzy_for=tuple(), fuzzy_for_options=tuple()): + return super().find(key, write, check_broken, allow_incomplete, fuzzy_for, fuzzy_for_options) + def 
get_rse_prefix(self, rse): - if self._rse_client is not None: - rse_info = self._rse_client.get_rse(rse) - prefix = rse_info['protocols'][0]['prefix'] - elif self.local_rse == 'UC_DALI_USERDISK': - # If rucio is not loaded but we are on dali, look here: - prefix = '/dali/lgrandi/rucio/' + if HAVE_ADMIX: + prefix = admix.rucio.get_rse_prefix(rse) + elif self.local_rse in self.local_prefixes: + prefix = self.local_prefixes[self.local_rse] else: - raise ValueError(f'We are not on dali and cannot load rucio') + raise ValueError(f'We are not on dali nor expanse and thus cannot load rucio') return prefix def did_is_local(self, did): @@ -161,7 +148,7 @@ def did_is_local(self, did): """ try: md = self._get_backend("RucioLocalBackend").get_metadata(did) - except (strax.DataNotAvailable, strax.DataCorrupted): + except (strax.DataNotAvailable, strax.DataCorrupted, KeyError): return False return self._all_chunk_stored(md, did) @@ -175,7 +162,7 @@ def _all_chunk_stored(self, md: dict, did: str) -> bool: for chunk in md.get('chunks', []): if chunk.get('filename'): _did = f"{scope}:{chunk['filename']}" - ch_path = rucio_path(self.local_rucio_path, _did) + ch_path = rucio_path(self.path, _did) if not os.path.exists(ch_path): return False return True @@ -244,10 +231,13 @@ class RucioRemoteBackend(strax.FileSytemBackend): # datatypes we don't want to download since they're too heavy heavy_types = ['raw_records', 'raw_records_nv', 'raw_records_he'] + # for caching RSE locations + dset_cache = {} + def __init__(self, staging_dir, download_heavy=False, **kwargs): """ :param staging_dir: Path (a string) where to save data. Must be a writable location. - :param *args: Passed to strax.FileSystemBackend + :param download_heavy: Whether or not to allow downloads of the heaviest data (raw_records*, less aqmon and MV) :param **kwargs: Passed to strax.FileSystemBackend """ @@ -261,33 +251,22 @@ def __init__(self, staging_dir, download_heavy=False, **kwargs): except OSError: raise PermissionError(f"You told the rucio backend to download data to {staging_dir}, " f"but that path is not writable by your user") - super().__init__(**kwargs) self.staging_dir = staging_dir self.download_heavy = download_heavy - # Do it only when we actually load rucio - from rucio.client.downloadclient import DownloadClient - self.download_client = DownloadClient() - - def get_metadata(self, dset_did, rse='UC_OSG_USERDISK', **kwargs): - base_dir = os.path.join(self.staging_dir, did_to_dirname(dset_did)) - - # define where the metadata will go (or where it already might be) - number, dtype, hsh = parse_did(dset_did) - metadata_file = f"{dtype}-{hsh}-metadata.json" - metadata_path = os.path.join(base_dir, metadata_file) - - # download if it doesn't exist - if not os.path.exists(metadata_path): - metadata_did = f'{dset_did}-metadata.json' - did_dict = dict(did=metadata_did, - base_dir=base_dir, - no_subdir=True, - rse=rse - ) - print(f"Downloading {metadata_did}") - self._download([did_dict]) + def _get_metadata(self, dset_did, **kwargs): + if dset_did in self.dset_cache: + rse = self.dset_cache[dset_did] + else: + rses = admix.rucio.get_rses(dset_did) + rse = admix.downloader.determine_rse(rses) + self.dset_cache[dset_did] = rse + + metadata_did = f'{dset_did}-metadata.json' + downloaded = admix.download(metadata_did, rse=rse, location=self.staging_dir) + assert len(downloaded) == 1, f"{metadata_did} should be a single file. We found {len(downloaded)}." 
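For reference, the `_get_metadata` hunk above (and `_read_chunk` further down) memoise which RSE served a dataset in `dset_cache`, so later downloads for the same dataset reuse that RSE instead of re-querying Rucio. A minimal sketch of that pattern using the same admix calls as the diff; `_rse_for` and the module-level cache are hypothetical names, not straxen API:

    import admix

    dset_cache = {}  # dataset DID -> chosen RSE

    def _rse_for(dset_did: str) -> str:
        """Pick an RSE once per dataset and reuse it for later downloads."""
        if dset_did not in dset_cache:
            rses = admix.rucio.get_rses(dset_did)
            dset_cache[dset_did] = admix.downloader.determine_rse(rses)
        return dset_cache[dset_did]

    # The metadata file and every chunk of the dataset then share one RSE, e.g.:
    # admix.download(f'{dset_did}-metadata.json', rse=_rse_for(dset_did), location='./strax_data')
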
+ metadata_path = downloaded[0] # check again if not os.path.exists(metadata_path): raise FileNotFoundError(f"No metadata found at {metadata_path}") @@ -295,10 +274,10 @@ def get_metadata(self, dset_did, rse='UC_OSG_USERDISK', **kwargs): with open(metadata_path, mode='r') as f: return json.loads(f.read()) - def _read_chunk(self, dset_did, chunk_info, dtype, compressor, rse="UC_OSG_USERDISK"): + def _read_chunk(self, dset_did, chunk_info, dtype, compressor): base_dir = os.path.join(self.staging_dir, did_to_dirname(dset_did)) chunk_file = chunk_info['filename'] - chunk_path = os.path.join(base_dir, chunk_file) + chunk_path = os.path.abspath(os.path.join(base_dir, chunk_file)) if not os.path.exists(chunk_path): number, datatype, hsh = parse_did(dset_did) if datatype in self.heavy_types and not self.download_heavy: @@ -307,16 +286,20 @@ def _read_chunk(self, dset_did, chunk_info, dtype, compressor, rse="UC_OSG_USERD "doing, pass download_heavy=True to the Rucio " "frontend. If not, check your context and/or ask " "someone if this raw data is needed locally.") - raise DownloadError(error_msg) + warn(error_msg) + raise strax.DataNotAvailable scope, name = dset_did.split(':') chunk_did = f"{scope}:{chunk_file}" - print(f"Downloading {chunk_did}") - did_dict = dict(did=chunk_did, - base_dir=base_dir, - no_subdir=True, - rse=rse, - ) - self._download([did_dict]) + if dset_did in self.dset_cache: + rse = self.dset_cache[dset_did] + else: + rses = admix.rucio.get_rses(dset_did) + rse = admix.downloader.determine_rse(rses) + self.dset_cache[dset_did] = rse + + downloaded = admix.download(chunk_did, rse=rse, location=self.staging_dir) + assert len(downloaded) == 1, f"{chunk_did} should be a single file. We found {len(downloaded)}." + assert chunk_path == downloaded[0] # check again if not os.path.exists(chunk_path): @@ -327,28 +310,6 @@ def _read_chunk(self, dset_did, chunk_info, dtype, compressor, rse="UC_OSG_USERD def _saver(self, dirname, metadata, **kwargs): raise NotImplementedError("Cannot save directly into rucio (yet), upload with admix instead") - def _download(self, did_dict_list): - # need to pass a list of dicts - # let's try 3 times - success = False - _try = 1 - while _try <= 3 and not success: - if _try > 1: - for did_dict in did_dict_list: - did_dict['rse'] = None - try: - self.download_client.download_dids(did_dict_list) - success = True - except KeyboardInterrupt: - raise - except Exception: - sleep = 3**_try - print(f"Download try #{_try} failed. Sleeping for {sleep} seconds and trying again...") - time.sleep(sleep) - _try += 1 - if not success: - raise DownloadError(f"Error downloading from rucio.") - class RucioSaver(strax.Saver): """ @@ -360,8 +321,11 @@ def __init__(self, *args, **kwargs): def rucio_path(root_dir, did): - """Convert target to path according to rucio convention. - See the __hash method here: https://github.com/rucio/rucio/blob/1.20.15/lib/rucio/rse/protocols/protocol.py""" + """ + Convert target to path according to rucio convention. 
+ See the __hash method here: + https://github.com/rucio/rucio/blob/1.20.15/lib/rucio/rse/protocols/protocol.py + """ scope, filename = did.split(':') # disable bandit rucio_md5 = hashlib.md5(did.encode('utf-8')).hexdigest() # nosec @@ -392,19 +356,8 @@ def key_to_rucio_did(key: strax.DataKey) -> str: return f'xnt_{key.run_id}:{key.data_type}-{key.lineage_hash}' -def key_to_rucio_meta(key: strax.DataKey) -> str: - return f'{str(key.data_type)}-{key.lineage_hash}-metadata.json' - - def read_md(path: str) -> json: with open(path, mode='r') as f: md = json.loads(f.read(), object_hook=json_util.object_hook) return md - - -def list_datasets(scope): - from rucio.client.client import Client - rucio_client = Client() - datasets = [d for d in rucio_client.list_dids(scope, {'type': 'dataset'}, type='dataset')] - return datasets diff --git a/straxen/rundb.py b/straxen/rundb.py index e7d9e09e1..254ddd570 100644 --- a/straxen/rundb.py +++ b/straxen/rundb.py @@ -57,7 +57,6 @@ def __init__(self, :param new_data_path: Path where new files are to be written. Defaults to None: do not write new data New files will be registered in the runs db! - TODO: register under hostname alias (e.g. 'dali') :param reader_ini_name_is_mode: If True, will overwrite the 'mode' field with 'reader.ini.name'. :param rucio_path: What is the base path where Rucio is mounted @@ -122,7 +121,6 @@ def __init__(self, self.available_query.append({'host': host_alias}) if self.rucio_path is not None: - # TODO replace with rucio backend in the rucio module self.backends.append(RucioLocalBackend(self.rucio_path)) # When querying for rucio, add that it should be dali-userdisk self.available_query.append({'host': 'rucio-catalogue', @@ -136,8 +134,6 @@ def _data_query(self, key): 'data': { '$elemMatch': { 'type': key.data_type, - # TODO remove the meta.lineage since this doc - # entry is deprecated. '$and': [{'$or': [ {'meta.lineage': key.lineage}, {'did': @@ -182,7 +178,8 @@ def _find(self, key: strax.DataKey, if doc is not None: datum = doc['data'][0] error_message = f'Expected {rucio_key} got data on {datum["location"]}' - assert datum.get('did', '') == rucio_key, error_message + if datum.get('did', '') != rucio_key: + raise RuntimeError(error_message) backend_name = 'RucioLocalBackend' backend_key = key_to_rucio_did(key) return backend_name, backend_key @@ -209,7 +206,6 @@ def _find(self, key: strax.DataKey, 'host': self.hostname, 'type': key.data_type, 'protocol': strax.FileSytemBackend.__name__, - # TODO: duplication with metadata stuff elsewhere? 'meta': {'lineage': key.lineage} }}}) @@ -218,7 +214,6 @@ def _find(self, key: strax.DataKey, datum = doc['data'][0] if datum['host'] == 'rucio-catalogue': - # TODO this is due to a bad query in _data_query. We aren't rucio. 
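For context, `rucio_path` in the straxen/rucio.py hunk above maps a DID to a deterministic on-disk location via an MD5 of the DID string. A rough self-contained sketch, assuming the two-level subdirectory layout of the linked rucio protocol code (treat the exact layout as an assumption); `sketch_rucio_path` is an illustrative name:

    import hashlib
    import os

    def sketch_rucio_path(root_dir: str, did: str) -> str:
        """Deterministic path for a DID, following the rucio hashing convention."""
        scope, filename = did.split(':')
        rucio_md5 = hashlib.md5(did.encode('utf-8')).hexdigest()  # nosec
        # two-level subdirectory layout: an assumption based on the linked protocol code
        t1, t2 = rucio_md5[0:2], rucio_md5[2:4]
        return os.path.join(root_dir, scope, t1, t2, filename)

    # DID naming as produced by key_to_rucio_did above:
    print(sketch_rucio_path('/dali/lgrandi/rucio/',
                            'xnt_009000:peak_basics-abcdef123456'))
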
raise strax.DataNotAvailable if write and not self._can_overwrite(key): @@ -267,20 +262,6 @@ def find_several(self, keys: typing.List[strax.DataKey], **kwargs): return [results_dict.get(k.run_id, False) for k in keys] - def _list_available(self, key: strax.DataKey, - allow_incomplete, fuzzy_for, fuzzy_for_options): - if fuzzy_for or fuzzy_for_options or allow_incomplete: - # The RunDB frontend can do neither fuzzy nor incomplete - warnings.warn('RunDB cannot do fuzzy or incomplete') - - q = self._data_query(key) - q.update(self.number_query()) - - cursor = self.collection.find( - q, - projection=[self.runid_field]) - return [x[self.runid_field] for x in cursor] - def _scan_runs(self, store_fields): query = self.number_query() projection = strax.to_str_tuple(list(store_fields)) @@ -291,8 +272,9 @@ def _scan_runs(self, store_fields): cursor = self.collection.find( filter=query, projection=projection) - for doc in tqdm(cursor, desc='Fetching run info from MongoDB', - total=cursor.count()): + for doc in strax.utils.tqdm( + cursor, desc='Fetching run info from MongoDB', + total=self.collection.count_documents(query)): del doc['_id'] if self.reader_ini_name_is_mode: doc['mode'] = \ @@ -301,7 +283,7 @@ def _scan_runs(self, store_fields): def run_metadata(self, run_id, projection=None): if run_id.startswith('_'): - # Superruns are currently not supprorted.. + # Superruns are currently not supported.. raise strax.DataNotAvailable if self.runid_field == 'name': diff --git a/straxen/scada.py b/straxen/scada.py index 4cd947b0a..2252e4839 100644 --- a/straxen/scada.py +++ b/straxen/scada.py @@ -18,16 +18,13 @@ from configparser import NoOptionError import sys -if any('jupyter' in arg for arg in sys.argv): - # In some cases we are not using any notebooks, - # Taken from 44952863 on stack overflow thanks! - from tqdm import tqdm_notebook as tqdm -else: - from tqdm import tqdm - export, __all__ = strax.exporter() +# Fancy tqdm style in notebooks +tqdm = strax.utils.tqdm + + @export class SCADAInterface: @@ -153,12 +150,12 @@ def get_scada_values(self, self._get_token() # Now loop over specified parameters and get the values for those. - iterator = enumerate(parameters.items()) - if self._use_progress_bar: - # wrap using progress bar - iterator = tqdm(iterator, total=len(parameters), desc='Load parameters') - - for ind, (k, p) in iterator: + for ind, (k, p) in tqdm( + enumerate(parameters.items()), + total=len(parameters), + desc='Load parameters', + disable=not self._use_progress_bar, + ): try: temp_df = self._query_single_parameter(start, end, k, p, @@ -175,7 +172,7 @@ def get_scada_values(self, f' {p} does not match the previous timestamps.') except ValueError as e: warnings.warn(f'Was not able to load parameters for "{k}". The reason was: "{e}".' - f'Continue without {k}.') + f'Continue without {k}.') temp_df = pd.DataFrame(columns=(k,)) if ind: @@ -321,31 +318,25 @@ def _query_single_parameter(self, offset = 1 else: offset = 0 - ntries = 0 - max_tries = 40000 # This corresponds to ~23 years - while ntries < max_tries: - temp_df = self._query(query, - self.SCData_URL, - start=(start // 10**9) + offset, - end=(end // 10**9), - query_type_lab=query_type_lab, - seconds_interval=every_nth_value, - raise_error_message=False # No valid value in query range... - ) # +1 since it is end before exclusive - if temp_df.empty: - # In case WebInterface does not return any data, e.g. 
if query range too small - break - times = (temp_df['timestampseconds'].values * 10**9).astype(' otherwise kwargs are + passed directly to straxen.get_resource. + """ + if fmt == 'abs_path': + downloader = straxen.MongoDownloader() + return downloader.download_single(name) + return straxen.get_resource(name, fmt=fmt) + + +@URLConfig.register('fsspec') +def read_file(path: str, **kwargs): + """Support fetching files from arbitrary filesystems + """ + with fsspec.open(path, **kwargs) as f: + content = f.read() + return content + + +@URLConfig.register('json') +def read_json(content: str, **kwargs): + """Load json string as a python object + """ + return json.loads(content) + + +@URLConfig.register('take') +def get_key(container: Container, take=None, **kwargs): + """ return a single element of a container + """ + if take is None: + return container + if not isinstance(take, list): + take = [take] + + # support for multiple keys for + # nested objects + for t in take: + container = container[t] + + return container + + +@URLConfig.register('format') +def format_arg(arg: str, **kwargs): + """apply pythons builtin format function to a string""" + return arg.format(**kwargs) + + +@URLConfig.register('itp_map') +def load_map(some_map, method='WeightedNearestNeighbors', **kwargs): + """Make an InterpolatingMap""" + return straxen.InterpolatingMap(some_map, method=method, **kwargs) + + +@URLConfig.register('bodega') +def load_value(name: str, bodega_version=None): + """Load a number from BODEGA file""" + if bodega_version is None: + raise ValueError('Provide version see e.g. tests/test_url_config.py') + nt_numbers = straxen.get_resource("XENONnT_numbers.json", fmt="json") + return nt_numbers[name][bodega_version]["value"] + + +@URLConfig.register('tf') +def open_neural_net(model_path: str, **kwargs): + # Nested import to reduce loading time of import straxen and it not + # base requirement + import tensorflow as tf + if not os.path.exists(model_path): + raise FileNotFoundError(f'No file at {model_path}') + with tempfile.TemporaryDirectory() as tmpdirname: + tar = tarfile.open(model_path, mode="r:gz") + tar.extractall(path=tmpdirname) + return tf.keras.models.load_model(tmpdirname) diff --git a/tests/test_basics.py b/tests/test_basics.py index 2ddfd95c9..042869c45 100644 --- a/tests/test_basics.py +++ b/tests/test_basics.py @@ -6,7 +6,6 @@ import shutil import uuid - test_run_id_1T = '180423_1021' @@ -40,17 +39,22 @@ def test_run_selection(self): assert run_id == test_run_id_1T def test_processing(self): - st = self.st - df = st.get_df(self.run_id, 'event_info') + df = self.st.get_df(self.run_id, 'event_info') assert len(df) > 0 assert 'cs1' in df.columns assert df['cs1'].sum() > 0 assert not np.all(np.isnan(df['x'].values)) + def test_event_info_double(self): + df = self.st.get_df(self.run_id, 'event_info_double') + assert 'cs2_a' in df.columns + assert df['cs2_a'].sum() > 0 + assert len(df) > 0 + def test_get_livetime_sec(self): st = self.st - events = st.get_array(self.run_id, 'peaks') + events = st.get_array(self.run_id, 'events') straxen.get_livetime_sec(st, test_run_id_1T, things=events) def test_mini_analysis(self): @@ -60,3 +64,34 @@ def count_rr(raw_records): n = self.st.count_rr(self.run_id) assert n > 100 + + @staticmethod + def _extract_latest_comment(context, + test_for_target='raw_records', + **context_kwargs, + ): + if context == 'xenonnt_online' and not straxen.utilix_is_configured(): + return + st = getattr(straxen.contexts, context)(**context_kwargs) + assert hasattr(st, 
'extract_latest_comment'), "extract_latest_comment not added to context?" + st.extract_latest_comment() + assert st.runs is not None, "No registry build?" + assert 'comments' in st.runs.keys() + st.select_runs(available=test_for_target) + if context == 'demo': + assert len(st.runs) + assert f'{test_for_target}_available' in st.runs.keys() + + def test_extract_latest_comment_nt(self, **opt): + """Run the test for nt (but only 2000 runs""" + self._extract_latest_comment(context='xenonnt_online', + _minimum_run_number=10_000, + _maximum_run_number=12_000, + **opt) + + def test_extract_latest_comment_demo(self): + self._extract_latest_comment(context='demo') + + def test_extract_latest_comment_lone_hits(self): + """Run the test for some target that is not in the default availability check""" + self.test_extract_latest_comment_nt(test_for_target='lone_hits') diff --git a/tests/test_cmt.py b/tests/test_cmt.py index ad8805871..9567b320e 100644 --- a/tests/test_cmt.py +++ b/tests/test_cmt.py @@ -1,171 +1,157 @@ -"""Testing functions for the CMT services""" - -import strax -import straxen -import utilix -import numpy as np -from warnings import warn -from .test_basics import test_run_id_1T -from .test_plugins import test_run_id_nT -from straxen.common import aux_repo - -def test_connect_to_db(): - """ - Test connection to db - """ - if not straxen.utilix_is_configured(): - warn('Cannot do test becaus ' - 'no have access to the database.') - return - - username=None - password=None - mongo_url=None - is_nt=True - mongo_kwargs = {'url': mongo_url, - 'user': username, - 'password': password, - 'database': 'corrections'} - corrections_collection = utilix.rundb.xent_collection(**mongo_kwargs) - client = corrections_collection.database.client - cmt = strax.CorrectionsInterface(client, database_name='corrections') - df = cmt.read('global_xenonnt') - mes = 'Return empty dataframe when reading DB. Please check' - assert not df.empty, mes - -def test_1T_elife(): - """ - Test elife from CMT DB against historical data(aux file) - """ - if not straxen.utilix_is_configured(): - warn('Cannot do test becaus ' - 'no have access to the database.') - return - - elife_conf = ('elife_xenon1t', 'ONLINE', False) - elife_cmt = straxen.get_correction_from_cmt(test_run_id_1T, elife_conf) - elife_file = elife_conf=aux_repo + '3548132b55f81a43654dba5141366041e1daaf01/strax_files/elife.npy' - x = straxen.get_resource(elife_file, fmt='npy') - run_index = np.where(x['run_id'] == int(test_run_id_1T))[0] - elife = x[run_index[0]]['e_life'] - mes = 'Elife values do not match. Please check' - assert elife_cmt == elife, mes - -def test_cmt_conf_option(option='mlp_model', version='ONLINE', is_nT=True): - """ - Test CMT conf options - If wrong conf is passed it would raise an error accordingly - """ - if not straxen.utilix_is_configured(): - warn('Cannot do test becaus ' - 'no have access to the database.') - return - - conf = option, version, is_nT - correction = straxen.get_correction_from_cmt(test_run_id_nT, conf) - assert isinstance(correction, (float, int, str, np.ndarray)) - -def test_mc_wrapper_elife(run_id='009000', - cmt_id='016000', - mc_id='mc_0', - ): - """ - Test that for two different run ids, we get different elifes using - the MC wrapper. - :param run_id: First run-id (used for normal query) - :param cmt_id: Second run-id used as a CMT id (should not be the - same as run_id! otherwise the values might actually be the same - and the test does not work). 
- :return: None - """ - if not straxen.utilix_is_configured(): - return - assert np.abs(int(run_id) - int(cmt_id)) > 500, 'runs must be far apart' - - # First for the run-id let's get the value - elife = straxen.get_correction_from_cmt( - run_id, - ("elife", "ONLINE", True)) - - # Now, we repeat the same query using the MC wrapper, this should - # give us a different result since we are now asking for a very - # different run-number. - mc_elife_diff = straxen.get_correction_from_cmt( - mc_id, - ('cmt_run_id', cmt_id, "elife", "ONLINE", True) - ) - - # Repeat the query from above to verify, let's see if we are getting - # the same results as for `elife` above - mc_elife_same = straxen.get_correction_from_cmt( - mc_id, - ('cmt_run_id', run_id, "elife", "ONLINE", True) - ) - - assert elife != mc_elife_diff - assert elife == mc_elife_same - - -def test_mc_wrapper_gains(run_id='009000', - cmt_id='016000', - mc_id='mc_0', - execute=True, - ): - """ - Test that for two different run ids, we get different gains using - the MC wrapper. - :param run_id: First run-id (used for normal query) - :param cmt_id: Second run-id used as a CMT id (should not be the - same as run_id! otherwise the values might actually be the same - and the test does not work). - :param execute: Execute this test (this is set to False since the - test takes 9 minutes which is too long. We can activate this if - the testing time due to faster CMT queries is reduced). - :return: None - """ - if not straxen.utilix_is_configured() or not execute: - return - - assert np.abs(int(run_id) - int(cmt_id)) > 500, 'runs must be far apart' - - # First for the run-id let's get the value - gains = straxen.get_correction_from_cmt( - run_id, - ('to_pe_model', 'ONLINE', True)) - - # Now, we repeat the same query using the MC wrapper, this should - # give us a different result since we are now asking for a very - # different run-number. - mc_gains_diff = straxen.get_correction_from_cmt( - mc_id, - ('cmt_run_id', cmt_id, 'to_pe_model', 'ONLINE', True)) - - # Repeat the query from above to verify, let's see if we are getting - # the same results as for `gains` above - mc_gains_same = straxen.get_correction_from_cmt( - mc_id, - ('cmt_run_id', run_id, 'to_pe_model', 'ONLINE', True)) - - assert not np.all(gains == mc_gains_diff) - assert np.all(gains == mc_gains_same) - - -def test_is_cmt_option(): - """ - Catches if we change the CMT option structure. - The example dummy_option works at least before Jun 13 2021 - """ - dummy_option = ('hit_thresholds_tpc', 'ONLINE', True) - assert is_cmt_option(dummy_option), 'Structure of CMT options changed!' - - -def is_cmt_option(config): - """ - Check if the input configuration is cmt style. 
- """ - is_cmt = (isinstance(config, tuple) - and len(config) == 3 - and isinstance(config[0], str) - and isinstance(config[1], (str, int, float)) - and isinstance(config[2], bool)) - return is_cmt +"""Testing functions for the CMT services""" + +import strax +import straxen +import utilix +import numpy as np +from warnings import warn +from .test_basics import test_run_id_1T +from straxen.test_utils import nt_test_run_id as test_run_id_nT +from straxen.common import aux_repo +import unittest + + +@unittest.skipIf(not straxen.utilix_is_configured(), "No db access, cannot test!") +def test_connect_to_db(): + """ + Test connection to db + """ + corrections_collection = utilix.rundb.xent_collection(database='corrections') + client = corrections_collection.database.client + cmt = strax.CorrectionsInterface(client, database_name='corrections') + df = cmt.read('global_xenonnt') + mes = 'Return empty dataframe when reading DB. Please check' + assert not df.empty, mes + + +@unittest.skipIf(not straxen.utilix_is_configured(), "No db access, cannot test!") +def test_1T_elife(): + """ + Test elife from CMT DB against historical data(aux file) + """ + elife_conf = ('elife_xenon1t', 'ONLINE', False) + elife_cmt = straxen.get_correction_from_cmt(test_run_id_1T, elife_conf) + elife_file = aux_repo + '3548132b55f81a43654dba5141366041e1daaf01/strax_files/elife.npy' + x = straxen.get_resource(elife_file, fmt='npy') + run_index = np.where(x['run_id'] == int(test_run_id_1T))[0] + elife = x[run_index[0]]['e_life'] + mes = 'Elife values do not match. Please check' + assert elife_cmt == elife, mes + + +@unittest.skipIf(not straxen.utilix_is_configured(), "No db access, cannot test!") +def test_cmt_conf_option(option='mlp_model', version='ONLINE', is_nT=True): + """ + Test CMT conf options + If wrong conf is passed it would raise an error accordingly + """ + conf = option, version, is_nT + correction = straxen.get_correction_from_cmt(test_run_id_nT, conf) + assert isinstance(correction, (float, int, str, np.ndarray)) + + +@unittest.skipIf(not straxen.utilix_is_configured(), "No db access, cannot test!") +def test_mc_wrapper_elife(run_id='009000', + cmt_id='016000', + mc_id='mc_0', + ): + """ + Test that for two different run ids, we get different elifes using + the MC wrapper. + :param run_id: First run-id (used for normal query) + :param cmt_id: Second run-id used as a CMT id (should not be the + same as run_id! otherwise the values might actually be the same + and the test does not work). + :return: None + """ + assert np.abs(int(run_id) - int(cmt_id)) > 500, 'runs must be far apart' + + # First for the run-id let's get the value + elife = straxen.get_correction_from_cmt( + run_id, + ("elife", "ONLINE", True)) + + # Now, we repeat the same query using the MC wrapper, this should + # give us a different result since we are now asking for a very + # different run-number. 
+ mc_elife_diff = straxen.get_correction_from_cmt( + mc_id, + ('cmt_run_id', cmt_id, "elife", "ONLINE", True) + ) + + # Repeat the query from above to verify, let's see if we are getting + # the same results as for `elife` above + mc_elife_same = straxen.get_correction_from_cmt( + mc_id, + ('cmt_run_id', run_id, "elife", "ONLINE", True) + ) + + assert elife != mc_elife_diff + assert elife == mc_elife_same + + +@unittest.skipIf(not straxen.utilix_is_configured(), "No db access, cannot test!") +def test_mc_wrapper_gains(run_id='009000', + cmt_id='016000', + mc_id='mc_0', + execute=True, + ): + """ + Test that for two different run ids, we get different gains using + the MC wrapper. + :param run_id: First run-id (used for normal query) + :param cmt_id: Second run-id used as a CMT id (should not be the + same as run_id! otherwise the values might actually be the same + and the test does not work). + :param execute: Execute this test (this is set to False since the + test takes 9 minutes which is too long. We can activate this if + the testing time due to faster CMT queries is reduced). + :return: None + """ + if not execute: + return + assert np.abs(int(run_id) - int(cmt_id)) > 500, 'runs must be far apart' + + # First for the run-id let's get the value + gains = straxen.get_correction_from_cmt( + run_id, + ('to_pe_model', 'ONLINE', True)) + + # Now, we repeat the same query using the MC wrapper, this should + # give us a different result since we are now asking for a very + # different run-number. + mc_gains_diff = straxen.get_correction_from_cmt( + mc_id, + ('cmt_run_id', cmt_id, 'to_pe_model', 'ONLINE', True)) + + # Repeat the query from above to verify, let's see if we are getting + # the same results as for `gains` above + mc_gains_same = straxen.get_correction_from_cmt( + mc_id, + ('cmt_run_id', run_id, 'to_pe_model', 'ONLINE', True)) + + assert not np.all(gains == mc_gains_diff) + assert np.all(gains == mc_gains_same) + + +def test_is_cmt_option(): + """ + Catches if we change the CMT option structure. + The example dummy_option works at least before Jun 13 2021 + """ + dummy_option = ('hit_thresholds_tpc', 'ONLINE', True) + assert straxen.is_cmt_option(dummy_option), 'Structure of CMT options changed!' + + dummy_url_config = 'cmt://correction?version=ONLINE&run_id=plugin.run_id' + assert straxen.is_cmt_option(dummy_url_config), 'Structure of CMT options changed!' + + +def test_replace_url_version(): + """Tests the replace_url_version function which is important in apply_cmt_version""" + url = 'cmt://elife?version=ONLINE&run_id=plugin.run_id' + url_check = 'cmt://elife?version=v1&run_id=plugin.run_id' + url_test = straxen.corrections_services.replace_url_version(url, 'v1') + if url_check != url_test: + raise AssertionError("replace_url_version did not do its job! 
" + f"it returns:\n{url_test}\nwhen it should return:\n{url_check}" + ) diff --git a/tests/test_common.py b/tests/test_common.py new file mode 100644 index 000000000..70a924450 --- /dev/null +++ b/tests/test_common.py @@ -0,0 +1,38 @@ +from straxen import rotate_perp_wires, tpc_r, aux_repo, get_resource +import numpy as np +from unittest import TestCase + + +class TestRotateWires(TestCase): + """Test that the rotate wires function works or raises usefull errors""" + + def test_rotate_wires(self): + """Use xy and see that we don't break""" + x_obs = np.linspace(-tpc_r, -tpc_r, 10) + y_obs = np.linspace(-tpc_r, -tpc_r, 10) + rotate_perp_wires(x_obs, y_obs) + with self.assertRaises(ValueError): + rotate_perp_wires(x_obs, y_obs[::2]) + + +class TestGetResourceFmt(TestCase): + """ + Replicate bug with ignored formatting + github.com/XENONnT/straxen/issues/741 + """ + json_file = aux_repo + '/01809798105f0a6c9efbdfcb5755af087824c234/sim_files/placeholder_map.json' # noqa + + def test_format(self): + """ + We did not do this correctly before, so let's make sure to do it right this time + """ + json_as_text = get_resource(self.json_file, fmt='text') + self.assertIsInstance(json_as_text, str) + # load it from memory + json_as_text_from_mem = get_resource(self.json_file, fmt='text') + self.assertEqual(json_as_text, json_as_text_from_mem) + + # Now let's check out if we do a JSON file + json_as_dict = get_resource(self.json_file, fmt='json') + self.assertIsInstance(json_as_dict, dict) + self.assertEqual(json_as_dict, get_resource(self.json_file, fmt='json')) diff --git a/tests/test_contexts.py b/tests/test_contexts.py index 049db13ed..8e45ada53 100644 --- a/tests/test_contexts.py +++ b/tests/test_contexts.py @@ -1,88 +1,163 @@ -"""For all of the context, do a quick check to see that we are able to search -a field (i.e. can build the dependencies in the context correctly) -See issue #233 and PR #236""" -from straxen.contexts import xenon1t_dali, xenon1t_led, fake_daq, demo -from straxen.contexts import xenonnt_led, xenonnt_online, xenonnt -import straxen -import tempfile -import os - - -## -# XENONnT -## - - -def test_xenonnt_online(): - st = xenonnt_online(_database_init=False, use_rucio=False) - st.search_field('time') - - -def test_xennonnt(): - if straxen.utilix_is_configured(): - st = xenonnt(_database_init=False, use_rucio=False) - st.search_field('time') - - -def test_xennonnt_latest(cmt_version='latest'): - if straxen.utilix_is_configured(): - st = xenonnt(cmt_version, _database_init=False, use_rucio=False) - st.search_field('time') - - -def test_xenonnt_led(): - st = xenonnt_led(_database_init=False, use_rucio=False) - st.search_field('time') - - -def test_nt_is_nt_online(): - if not straxen.utilix_is_configured(): - # Cannot contact CMT without the database - return - # Test that nT and nT online are the same - st_online = xenonnt_online(_database_init=False, use_rucio=False) - - st = xenonnt(_database_init=False, use_rucio=False) - for plugin in st._plugin_class_registry.keys(): - print(f'Checking {plugin}') - nt_key = st.key_for('0', plugin) - nt_online_key = st_online.key_for('0', plugin) - assert str(nt_key) == str(nt_online_key) - - -## -# XENON1T -## - - -def test_xenon1t_dali(): - st = xenon1t_dali() - st.search_field('time') - - -def test_demo(): - """ - Test the demo context. 
Since we download the folder to the current - working directory, make sure we are in a tempfolder where we - can write the data to - """ - with tempfile.TemporaryDirectory() as temp_dir: - try: - print("Temporary directory is ", temp_dir) - os.chdir(temp_dir) - st = demo() - st.search_field('time') - # On windows, you cannot delete the current process' - # working directory, so we have to chdir out first. - finally: - os.chdir('..') - - -def test_fake_daq(): - st = fake_daq() - st.search_field('time') - - -def test_xenon1t_led(): - st = xenon1t_led() - st.search_field('time') +"""For all of the context, do a quick check to see that we are able to search +a field (i.e. can build the dependencies in the context correctly) +See issue #233 and PR #236""" +from straxen.contexts import xenon1t_dali, xenon1t_led, fake_daq, demo +from straxen.contexts import xenonnt_led, xenonnt_online, xenonnt +import straxen +import tempfile +import os +import unittest + +## +# XENONnT +## + + +def test_xenonnt_online(): + st = xenonnt_online(_database_init=False) + st.search_field('time') + + +@unittest.skipIf(not straxen.utilix_is_configured(), "No db access, cannot test!") +def test_xennonnt(): + st = xenonnt(_database_init=False) + st.search_field('time') + + +def test_xenonnt_led(): + st = xenonnt_led(_database_init=False) + st.search_field('time') + + +@unittest.skipIf(not straxen.utilix_is_configured(), "No db access, cannot test!") +def test_nt_is_nt_online(): + # Test that nT and nT online are the same + st_online = xenonnt_online(_database_init=False) + + st = xenonnt(_database_init=False) + for plugin in st._plugin_class_registry.keys(): + print(f'Checking {plugin}') + nt_key = st.key_for('0', plugin) + nt_online_key = st_online.key_for('0', plugin) + assert str(nt_key) == str(nt_online_key) + + +@unittest.skipIf(not straxen.utilix_is_configured(), "No db access, cannot test!") +def test_offline(): + """ + Let's try and see which CMT versions are compatible with this straxen + version + """ + cmt = straxen.CorrectionsManagementServices() + cmt_versions = list(cmt.global_versions)[::-1] + print(cmt_versions) + success_for = [] + for global_version in cmt_versions: + try: + xenonnt(global_version) + success_for.append(global_version) + except straxen.CMTVersionError: + pass + print(f'This straxen version works with {success_for} but is ' + f'incompatible with {set(cmt_versions)-set(success_for)}') + + test = unittest.TestCase() + # We should always work for one offline and the online version + test.assertTrue(len(success_for) >= 2) + + +## +# XENON1T +## + + +def test_xenon1t_dali(): + st = xenon1t_dali() + st.search_field('time') + + +def test_demo(): + """ + Test the demo context. Since we download the folder to the current + working directory, make sure we are in a tempfolder where we + can write the data to + """ + with tempfile.TemporaryDirectory() as temp_dir: + try: + print("Temporary directory is ", temp_dir) + os.chdir(temp_dir) + st = demo() + st.search_field('time') + # On windows, you cannot delete the current process' + # working directory, so we have to chdir out first. 
+ finally: + os.chdir('..') + + +def test_fake_daq(): + st = fake_daq() + st.search_field('time') + + +def test_xenon1t_led(): + st = xenon1t_led() + st.search_field('time') + +## +# WFSim +## + + +# Simulation contexts are only tested when special flags are set + +@unittest.skipIf('ALLOW_WFSIM_TEST' not in os.environ, + "if you want test wfsim context do `export 'ALLOW_WFSIM_TEST'=1`") +class TestSimContextNT(unittest.TestCase): + @staticmethod + def context(*args, **kwargs): + kwargs.setdefault('cmt_version', 'global_ONLINE') + return straxen.contexts.xenonnt_simulation(*args, **kwargs) + + def test_nt_sim_context_main(self): + st = self.context(cmt_run_id_sim='008000') + st.search_field('time') + + @unittest.skipIf(not straxen.utilix_is_configured(), "No db access, cannot test!") + def test_nt_sim_context_alt(self): + """Some examples of how to run with a custom WFSim context""" + self.context(cmt_run_id_sim='008000', cmt_run_id_proc='008001') + self.context(cmt_run_id_sim='008000', + cmt_option_overwrite_sim={'elife': 1e6}) + + self.context(cmt_run_id_sim='008000', + overwrite_fax_file_sim={'elife': 1e6}) + + @unittest.skipIf(not straxen.utilix_is_configured(), "No db access, cannot test!") + def test_nt_diverging_context_options(self): + """ + Test diverging options. Idea is that you can use different + settings for processing and generating data, should have been + handled by RawRecordsFromWFsim but is now hacked into the + xenonnt_simulation context + + Just to show how convoluted this syntax for the + xenonnt_simulation context / CMT is... + """ + self.context(cmt_run_id_sim='008000', + cmt_option_overwrite_sim={'elife': ('elife_constant', 1e6, True)}, + cmt_option_overwrite_proc={'elife': ('elife_constant', 1e5, True)}, + overwrite_from_fax_file_proc=True, + overwrite_from_fax_file_sim=True, + _config_overlap={'electron_lifetime_liquid': 'elife'}, + ) + + def test_nt_sim_context_bad_inits(self): + with self.assertRaises(RuntimeError): + self.context(cmt_run_id_sim=None, cmt_run_id_proc=None,) + + +@unittest.skipIf('ALLOW_WFSIM_TEST' not in os.environ, + "if you want test wfsim context do `export 'ALLOW_WFSIM_TEST'=1`") +def test_sim_context(): + st = straxen.contexts.xenon1t_simulation() + st.search_field('time') diff --git a/tests/test_count_pulses.py b/tests/test_count_pulses.py index 8bc60bc68..f1b3eeaab 100644 --- a/tests/test_count_pulses.py +++ b/tests/test_count_pulses.py @@ -43,7 +43,7 @@ def _check_pulse_count(records): # Not sure how to check lone pulses other than duplicating logic # already in count_pulses, so just do a basic check: assert count['lone_pulse_area'][ch] <= count['pulse_area'][ch] - + # Check baseline values: # nan does not exist for integer: mean = straxen.NO_PULSE_COUNTS if np.isnan(np.mean(rc0['baseline'])) else int(np.mean(rc0['baseline'])) diff --git a/tests/test_daq_reader.py b/tests/test_daq_reader.py new file mode 100644 index 000000000..fb6b9c55c --- /dev/null +++ b/tests/test_daq_reader.py @@ -0,0 +1,168 @@ +import os +import shutil +import unittest +import strax +from straxen import download_test_data, get_resource +from straxen.plugins.daqreader import ArtificialDeadtimeInserted, \ + DAQReader, \ + ARTIFICIAL_DEADTIME_CHANNEL +from datetime import timezone, datetime +from time import sleep + + +class DummyDAQReader(DAQReader): + """Dummy version of DAQReader with different provides and different lineage""" + provides = ['raw_records', + 'raw_records_nv', + 'raw_records_aqmon', + ] + dummy_version = strax.Config( + default=None, + track=True, + 
help="Extra option to make sure that we are getting a different lineage if we want to", + ) + data_kind = dict(zip(provides, provides)) + rechunk_on_save = False + + def _path(self, chunk_i): + path = super()._path(chunk_i) + path_exists = os.path.exists(path) + print(f'looked for {chunk_i} on {path}. Is found = {path_exists}') + return path + + +class TestDAQReader(unittest.TestCase): + """ + Test DAQReader with a few chunks of amstrax data: + https://github.com/XAMS-nikhef/amstrax + + + This class is structured with three parts: + - A. The test(s) where we execute some tests to make sure the + DAQ-reader works well; + - B. Setup and teardown logic which downloads/removes test data if + we run this test so that we get a fresh sample of data every time + we run this test; + - C. Some utility functions for part A and B (like setting the + context etc). + """ + + run_id = '999999' + run_doc_name = 'rundoc_999999.json' + live_data_path = f'./live_data/{run_id}' + rundoc_file = 'https://raw.githubusercontent.com/XAMS-nikhef/amstrax_files/73681f112d748f6cd0e95045970dd29c44e983b0/data/rundoc_999999.json' # noqa + data_file = 'https://raw.githubusercontent.com/XAMS-nikhef/amstrax_files/73681f112d748f6cd0e95045970dd29c44e983b0/data/999999.tar' # noqa + + # # Part A. the actual tests + def test_make(self) -> None: + """ + Test if we can run the daq-reader without chrashing and if we + actually stored the data after making it. + """ + run_id = self.run_id + for target, plugin_class in self.st._plugin_class_registry.items(): + self.st.make(run_id, target) + sleep(0.5) # allow os to rename the file + if plugin_class.save_when >= strax.SaveWhen.TARGET: + self.assertTrue( + self.st.is_stored(run_id, target), + ) + + def test_insert_deadtime(self): + """ + In the DAQ reader, we need a mimimium quiet period to say where + we can start/end a chunk. Test that this information gets + propagated to the ARTIFICIAL_DEADTIME_CHANNEL (in + raw-records-aqmon) if we set this value to an unrealistic value + of 0.5 s. + """ + st = self.st.new_context() + st.set_config({'safe_break_in_pulses': int(0.5e9), + 'dummy_version': 'test_insert_deadtime', + }) + + with self.assertWarns(ArtificialDeadtimeInserted): + st.make(self.run_id, 'raw_records') + + rr_aqmon = st.get_array(self.run_id, 'raw_records_aqmon') + self.assertTrue(len(rr_aqmon)) + + def test_invalid_setting(self): + """The safe break in pulses cannot be longer than the chunk size""" + st = self.st.new_context() + st.set_config({'safe_break_in_pulses': int(3600e9), + 'dummy_version': 'test_invalid_setting', + }) + with self.assertRaises(ValueError): + st.make(self.run_id, 'raw_records') + + # # Part B. 
data-download and cleanup + @classmethod + def setUpClass(cls) -> None: + st = strax.Context() + st.register(DummyDAQReader) + st.storage = [strax.DataDirectory('./daq_test_data')] + st.set_config({'daq_input_dir': cls.live_data_path}) + cls.st = st + + @classmethod + def tearDownClass(cls) -> None: + path_live = f'live_data/{cls.run_id}' + if os.path.exists(path_live): + shutil.rmtree(path_live) + print(f'rm {path_live}') + + def setUp(self) -> None: + if not os.path.exists(self.live_data_path): + print(f'Fetch {self.live_data_path}') + self.download_test_data() + rd = self.get_metadata() + st = self.set_context_config(self.st, rd) + self.st = st + self.assertFalse(self.st.is_stored(self.run_id, 'raw_records')) + + def tearDown(self) -> None: + data_path = self.st.storage[0].path + if os.path.exists(data_path): + shutil.rmtree(data_path) + print(f'rm {data_path}') + + # # Part C. Some utility functions for A & B + def download_test_data(self): + download_test_data(self.data_file) + self.assertTrue(os.path.exists(self.live_data_path)) + + def get_metadata(self): + md = get_resource(self.rundoc_file, fmt='json') + # This is a flat dict but we need to have a datetime object, + # since this is only a test, let's just replace it with a + # placeholder + md['start'] = datetime.now() + return md + + @staticmethod + def set_context_config(st, run_doc): + """Update context with fields needed by the DAQ reader""" + daq_config = run_doc['daq_config'] + st.set_context_config(dict(forbid_creation_of=tuple())) + st.set_config({ + 'channel_map': + dict( + # (Minimum channel, maximum channel) + # Channels must be listed in a ascending order! + tpc=(0, 1), + nveto=(1, 2), + aqmon=(ARTIFICIAL_DEADTIME_CHANNEL, ARTIFICIAL_DEADTIME_CHANNEL + 1), + )}) + update_config = { + 'readout_threads': daq_config['processing_threads'], + 'record_length': daq_config['strax_fragment_payload_bytes'] // 2, + 'max_digitizer_sampling_time': 10, + 'run_start_time': run_doc['start'].replace(tzinfo=timezone.utc).timestamp(), + 'daq_chunk_duration': int(daq_config['strax_chunk_length'] * 1e9), + 'daq_overlap_chunk_duration': int(daq_config['strax_chunk_overlap'] * 1e9), + 'daq_compressor': daq_config.get('compressor', 'lz4'), + } + print(f'set config to {update_config}') + st.set_config(update_config) + return st diff --git a/tests/test_database_frontends.py b/tests/test_database_frontends.py new file mode 100644 index 000000000..e22ab2612 --- /dev/null +++ b/tests/test_database_frontends.py @@ -0,0 +1,204 @@ +import unittest +import strax +from strax.testutils import Records, Peaks +import straxen +import os +import shutil +import tempfile +import pymongo +import datetime + + +def mongo_uri_not_set(): + return 'TEST_MONGO_URI' not in os.environ + + +@unittest.skipIf(mongo_uri_not_set(), "No access to test database") +class TestRunDBFrontend(unittest.TestCase): + """ + Test the saving behavior of the context with the straxen.RunDB + + Requires write access to some pymongo server, the URI of witch is to be set + as an environment variable under: + + TEST_MONGO_URI + + At the moment this is just an empty database but you can also use some free + ATLAS mongo server. 
+ """ + _run_test = True + + @classmethod + def setUpClass(cls) -> None: + # Just to make sure we are running some mongo server, see test-class docstring + cls.test_run_ids = ['0', '1'] + cls.all_targets = ('peaks', 'records') + + uri = os.environ.get('TEST_MONGO_URI') + db_name = 'test_rundb' + cls.collection_name = 'test_rundb_coll' + client = pymongo.MongoClient(uri) + cls.database = client[db_name] + collection = cls.database[cls.collection_name] + cls.path = os.path.join(tempfile.gettempdir(), 'strax_data') + # assert cls.collection_name not in cls.database.list_collection_names() + + if not straxen.utilix_is_configured(): + # Bit of an ugly hack but there is no way to get around this + # function even though we don't need it + straxen.rundb.utilix.rundb.xent_collection = lambda *args, **kwargs: collection + + cls.rundb_sf = straxen.RunDB(readonly=False, + runid_field='number', + new_data_path=cls.path, + minimum_run_number=-1, + rucio_path='./strax_test_data', + ) + cls.rundb_sf.client = client + cls.rundb_sf.collection = collection + + cls.st = strax.Context(register=[Records, Peaks], + storage=[cls.rundb_sf], + use_per_run_defaults=False, + config=dict(bonus_area=0), + ) + + def setUp(self) -> None: + for run_id in self.test_run_ids: + self.collection.insert_one(_rundoc_format(run_id)) + assert not self.is_all_targets_stored + + def tearDown(self): + self.database[self.collection_name].drop() + if os.path.exists(self.path): + print(f'rm {self.path}') + shutil.rmtree(self.path) + + @property + def collection(self): + return self.database[self.collection_name] + + @property + def is_all_targets_stored(self) -> bool: + """This should always be False as one of the targets (records) is not stored in mongo""" + return all([all( + [self.st.is_stored(r, t) for t in self.all_targets]) + for r in self.test_run_ids]) + + def test_finding_runs(self): + rdb = self.rundb_sf + col = self.database[self.collection_name] + assert col.find_one() is not None + query = rdb.number_query() + assert col.find_one(query) is not None + runs = self.st.select_runs() + assert len(runs) == len(self.test_run_ids) + + def test_write_and_load(self): + assert not self.is_all_targets_stored + + # Make ALL the data + # NB: the context writes to ALL the storage frontends that are susceptible + for t in self.all_targets: + self.st.make(self.test_run_ids, t) + + for r in self.test_run_ids: + print(self.st.available_for_run(r)) + assert self.is_all_targets_stored + + # Double check that we can load data from mongo even if we cannot make it + self.st.context_config['forbid_creation_of'] = self.all_targets + peaks = self.st.get_array(self.test_run_ids, self.all_targets[-1]) + assert len(peaks) + runs = self.st.select_runs(available=self.all_targets) + assert len(runs) == len(self.test_run_ids) + + # Insert a new run number and check that it's not marked as available + self.database[self.collection_name].insert_one(_rundoc_format(3)) + self.st.runs = None # Reset + all_runs = self.st.select_runs() + available_runs = self.st.select_runs(available=self.all_targets) + assert len(available_runs) == len(self.test_run_ids) + assert len(all_runs) == len(self.test_run_ids) + 1 + + def test_lineage_changes(self): + st = strax.Context(register=[Records, Peaks], + storage=[self.rundb_sf], + use_per_run_defaults=True, + ) + lineages = [st.key_for(r, 'peaks').lineage_hash for r in self.test_run_ids] + assert len(set(lineages)) > 1 + with self.assertRaises(ValueError): + # Lineage changing per run is not allowed! 
+ st.select_runs(available='peaks') + + def test_fuzzy(self): + """See that fuzzy for does not work yet with the RunDB""" + fuzzy_st = self.st.new_context(fuzzy_for=self.all_targets) + with self.assertWarns(UserWarning): + fuzzy_st.is_stored(self.test_run_ids[0], self.all_targets[0]) + with self.assertWarns(UserWarning): + keys = [fuzzy_st.key_for(r, self.all_targets[0]) for r in self.test_run_ids] + self.rundb_sf.find_several(keys, fuzzy_for=self.all_targets) + + def test_invalids(self): + """Test a couble of invalid ways of passing arguments to the RunDB""" + with self.assertRaises(ValueError): + straxen.RunDB(runid_field='numbersdfgsd', ) + with self.assertRaises(ValueError): + r = self.test_run_ids[0] + keys = [self.st.key_for(r, t) for t in self.all_targets] + self.rundb_sf.find_several(keys, fuzzy_for=self.all_targets) + with self.assertRaises(strax.DataNotAvailable): + self.rundb_sf.find(self.st.key_for('_super-run', self.all_targets[0])) + + def test_rucio_format(self): + """Test that document retrieval works for rucio files in the RunDB""" + rucio_id = '999999' + target = self.all_targets[-1] + key = self.st.key_for(rucio_id, target) + self.assertFalse(rucio_id in self.test_run_ids) + rd = _rundoc_format(rucio_id) + did = straxen.rucio.key_to_rucio_did(key) + rd['data'] = [{'host': 'rucio-catalogue', + 'location': 'UC_DALI_USERDISK', + 'status': 'transferred', + 'did': did, + 'number': int(rucio_id), + 'type': target, + }] + self.database[self.collection_name].insert_one(rd) + + # Make sure we get the backend key using the _find option + self.assertTrue( + self.rundb_sf._find(key, + write=False, + allow_incomplete=False, + fuzzy_for=None, + fuzzy_for_options=None, + )[1] == did, + ) + with self.assertRaises(strax.DataNotAvailable): + # Although we did insert a document, we should get a data + # not available error as we did not actually save any data + # on the rucio folder + self.rundb_sf.find(key) + + +def _rundoc_format(run_id): + start = datetime.datetime.fromtimestamp(0) + datetime.timedelta(days=int(run_id)) + end = start + datetime.timedelta(days=1) + doc = { + 'comments': [{'comment': 'some testdoc', + 'date': start, + 'user': 'master user'}], + 'data': [], + 'detectors': ['tpc'], + + 'mode': 'test', + 'number': int(run_id), + 'source': 'none', + 'start': start, + 'end': end, + 'user': 'master user'} + return doc diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py new file mode 100644 index 000000000..9628c7a0a --- /dev/null +++ b/tests/test_deprecated.py @@ -0,0 +1,14 @@ +""" +These tests are for deprecated functions that we will remove in future releases + +This is as such a bit of a "to do list" of functions to remove from straxen +""" + +import straxen +import matplotlib.pyplot as plt + + +def test_tight_layout(): + plt.scatter([1], [2]) + straxen.quiet_tight_layout() + plt.clf() diff --git a/tests/test_holoviews_utils.py b/tests/test_holoviews_utils.py index 7cf7aae23..187a24eb2 100644 --- a/tests/test_holoviews_utils.py +++ b/tests/test_holoviews_utils.py @@ -1,22 +1,14 @@ import strax import straxen from straxen.holoviews_utils import nVETOEventDisplay - import holoviews as hv import panel as pn import numpy as np - from tempfile import TemporaryDirectory import os -dummy_map = np.zeros(120, dtype=[('x', np.int32), - ('y', np.int32), - ('z', np.int32), - ('channel', np.int32),]) -dummy_map['x'] = np.arange(0, 120) -dummy_map['y'] = np.arange(0, 120) -dummy_map['z'] = np.arange(0, 120) -dummy_map['channel'] = np.arange(2000, 2120, 1, dtype=np.int32) 
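As a usage note for the TestRunDBFrontend class in the tests/test_database_frontends.py hunk above: it requires a writable MongoDB whose URI is exported as TEST_MONGO_URI, and it seeds that database with run documents like those built by `_rundoc_format`. A minimal sketch of doing that seeding by hand, assuming a local mongod (the URI value is only an example, and the database/collection names are the ones used by the test class):

    import datetime
    import os
    import pymongo

    uri = os.environ.get('TEST_MONGO_URI', 'mongodb://localhost:27017/')  # example value
    collection = pymongo.MongoClient(uri)['test_rundb']['test_rundb_coll']

    # A minimal run document, shaped like the ones _rundoc_format builds
    start = datetime.datetime.fromtimestamp(0)
    collection.insert_one({'number': 0,
                           'detectors': ['tpc'],
                           'mode': 'test',
                           'source': 'none',
                           'data': [],
                           'start': start,
                           'end': start + datetime.timedelta(days=1),
                           'user': 'master user'})
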
+_dummy_map = straxen.test_utils._nveto_pmt_dummy_df.to_records() + def test_hitlets_to_hv_points(): hit = np.zeros(1, dtype=strax.hit_dtype) @@ -26,10 +18,11 @@ def test_hitlets_to_hv_points(): hit['channel'] = 2000 hit['area'] = 1 - nvd = nVETOEventDisplay(pmt_map=dummy_map) + nvd = nVETOEventDisplay(pmt_map=_dummy_map) points = nvd.hitlets_to_hv_points(hit, t_ref=0) - m = [hit[key] == points.data[key] for key in hit.dtype.names if key in points.data.columns.values] + m = [hit[key] == points.data[key] for key in hit.dtype.names if + key in points.data.columns.values] assert np.all(m), 'Data has not been converted corretly into hv.Points.' @@ -41,7 +34,7 @@ def test_hitlet_matrix(): hit['channel'] = 2000 hit['area'] = 1 - nvd = nVETOEventDisplay(pmt_map=dummy_map) + nvd = nVETOEventDisplay(pmt_map=_dummy_map) hit_m = nvd.plot_hitlet_matrix(hitlets=hit) with TemporaryDirectory() as d: @@ -54,7 +47,7 @@ def test_plot_nveto_pattern(): hit['channel'] = 2000 hit['area'] = 1 - nvd = nVETOEventDisplay(pmt_map=dummy_map) + nvd = nVETOEventDisplay(pmt_map=_dummy_map) pmt_plot = nvd.plot_nveto(hitlets=hit) with TemporaryDirectory() as d: # Have to store plot to make sure it is rendered @@ -76,7 +69,7 @@ def test_nveto_event_display(): event['endtime'] = hit['time'] + 40 event['area'] = hit['area'] - nvd = nVETOEventDisplay(event, hit, pmt_map=dummy_map, run_id='014986') + nvd = nVETOEventDisplay(event, hit, pmt_map=_dummy_map, run_id='014986') dispaly = nvd.plot_event_display() with TemporaryDirectory() as d: @@ -89,7 +82,7 @@ def test_array_to_df_and_make_sliders(): + straxen.veto_events.veto_event_positions_dtype()[2:]) evt = np.zeros(1, dtype) - nvd = nVETOEventDisplay(pmt_map=dummy_map) + nvd = nVETOEventDisplay(pmt_map=_dummy_map) df = straxen.convert_array_to_df(evt) nvd._make_sliders_and_tables(df) diff --git a/tests/test_itp_map.py b/tests/test_itp_map.py new file mode 100644 index 000000000..e72fa1e2b --- /dev/null +++ b/tests/test_itp_map.py @@ -0,0 +1,41 @@ +from unittest import TestCase, skipIf +from straxen import InterpolatingMap, utilix_is_configured, get_resource + + +class TestItpMaps(TestCase): + def open_map(self, map_name, fmt, method='WeightedNearestNeighbors'): + map_data = get_resource(map_name, fmt=fmt) + m = InterpolatingMap(map_data, method=method) + self.assertTrue(m is not None) + + @skipIf(not utilix_is_configured(), 'Cannot download maps without db access') + def test_lce_map(self): + self.open_map('XENONnT_s1_xyz_LCE_corrected_qes_MCva43fa9b_wires.json.gz', fmt='json.gz') + + def test_array_valued(self): + """See https://github.com/XENONnT/straxen/pull/757""" + _map = {'coordinate_system': [[-18.3, -31.7, -111.5], + [36.6, -0.0, -111.5], + [-18.3, 31.7, -111.5], + [-18.3, -31.7, -37.5], + [36.6, -0.0, -37.5], + [-18.3, 31.7, -37.5]], + 'description': 'Array_valued dummy map with lists', + 'name': 'Dummy map', + 'map': [[1.7, 11.1], + [1.7, 11.1], + [1.7, 11.0], + [3.3, 5.7], + [3.3, 5.8], + [3.3, 5.7]]} + itp_map = InterpolatingMap(_map) + + # Let's do something easy, check if one fixed point yields the + # same result if not, our interpolation map depends on the + # straxen version?! That's bad! 
+ map_at_random_point = itp_map([[0, 0, 0], [0, 0, -140]]) + self.assertAlmostEqual(map_at_random_point[0][0], 2.80609655) + self.assertAlmostEqual(map_at_random_point[0][1], 7.37967879) + + self.assertAlmostEqual(map_at_random_point[1][0], 2.17815179) + self.assertAlmostEqual(map_at_random_point[1][1], 9.47282782) diff --git a/tests/test_led_calibration.py b/tests/test_led_calibration.py index b8e003df3..2bfe57099 100644 --- a/tests/test_led_calibration.py +++ b/tests/test_led_calibration.py @@ -17,7 +17,7 @@ def test_ext_timings_nv(records): # and channel start at 0, convert to nv: records['pulse_length'] = records['length'] records['channel'] += 2000 - + st = straxen.contexts.xenonnt_led() plugin = st.get_single_plugin('1', 'ext_timings_nv') hits = strax.find_hits(records, min_amplitude=1) diff --git a/tests/test_mini_analyses.py b/tests/test_mini_analyses.py new file mode 100644 index 000000000..be0e8fced --- /dev/null +++ b/tests/test_mini_analyses.py @@ -0,0 +1,434 @@ +import os +import unittest +import platform +import numpy as np +import pandas +import strax +import straxen +from matplotlib.pyplot import clf as plt_clf +from straxen.test_utils import nt_test_context, nt_test_run_id + + +def is_py310(): + """Check python version""" + return platform.python_version_tuple()[:2] == ('3', '10') + +def test_pmt_pos_1t(): + """ + Test if we can get the 1T PMT positions + """ + pandas.DataFrame(straxen.pmt_positions(True)) + + +def test_pmt_pos_nt(): + """ + Test if we can get the nT PMT positions + """ + pandas.DataFrame(straxen.pmt_positions(False)) + + +class TestMiniAnalyses(unittest.TestCase): + """ + Generally, tests in this class run st. + + We provide minimal arguments to just probe if the + is not breaking when running, we are NOT + checking if plots et cetera make sense, just if the code is not + broken (e.g. because for changes in dependencies like matplotlib or + bokeh) + + NB! If this tests fails locally (but not on github-CI), please do: + `rm strax_test_data` + You might be an old version of test data. + """ + # They were added on 25/10/2021 and may be outdated by now + _expected_test_results = { + 'peak_basics': 40, + 'n_s1': 19, + 'run_live_time': 4.7516763, + 'event_basics': 20, + } + + @classmethod + def setUpClass(cls) -> None: + """ + Common setup for all the tests. We need some data which we + don't delete but reuse to prevent a lot of computations in this + class + """ + cls.st = nt_test_context() + # For al the WF plotting, we might need records, let's make those + cls.st.make(nt_test_run_id, 'records') + cls.first_peak = cls.st.get_array(nt_test_run_id, 'peak_basics')[0] + cls.first_event = cls.st.get_array(nt_test_run_id, 'event_basics')[0] + + def tearDown(self): + """After each test, clear a figure (if one was open)""" + plt_clf() + + def test_target_peaks(self, target='peak_basics', tol=2): + """ + Not a real mini analysis but let's see if the number of peaks + matches some pre-defined value. This is just to safeguard one + from accidentally adding some braking code. + """ + self.assertTrue(target + in self._expected_test_results, + f'No expectation for {target}?!') + data = self.st.get_array(nt_test_run_id, target) + message = (f'Got more/less data for {target}. 
If you changed something ' + f'on {target}, please update the numbers in ' + f'tests/test_mini_analyses.TestMiniAnalyses._expected_test_results') + if not straxen.utilix_is_configured(): + # If we do things with dummy maps, things might be slightly different + tol += 10 + self.assertTrue(np.abs(len(data) - self._expected_test_results[target]) < tol, message) + + def test_target_events(self): + """Test that the number of events is roughly right""" + self.test_target_peaks(target='event_basics') + + def test_plot_waveform(self, deep=False): + self.st.plot_waveform(nt_test_run_id, + time_within=self.first_peak, + deep=deep) + + def test_plot_waveform_deep(self): + self.test_plot_waveform(deep=True) + + def test_plot_hit_pattern(self): + self.st.plot_hit_pattern(nt_test_run_id, + time_within=self.first_peak, + xenon1t=False) + + def test_plot_records_matrix(self): + self._st_attr_for_one_peak('plot_records_matrix') + + def test_raw_records_matrix(self): + self._st_attr_for_one_peak('raw_records_matrix') + + def test_event_display_simple(self): + plot_all_positions = straxen.utilix_is_configured() + with self.assertRaises(NotImplementedError): + # old way of calling the simple display + self.st.event_display_simple(nt_test_run_id, + time_within=self.first_event, + ) + # New, correct way of calling the simple display + self.st.event_display(nt_test_run_id, + time_within=self.first_event, + xenon1t=False, + plot_all_positions=plot_all_positions, + simple_layout=True, + ) + + def test_single_event_plot(self): + plot_all_positions = straxen.utilix_is_configured() + straxen.analyses.event_display.plot_single_event( + self.st, + nt_test_run_id, + events=self.st.get_array(nt_test_run_id, 'events'), + event_number=self.first_event['event_number'], + xenon1t=False, + plot_all_positions=plot_all_positions, + ) + + def test_event_display_interactive(self): + self.st.event_display_interactive(nt_test_run_id, + time_within=self.first_event, + xenon1t=False, + ) + + def test_plot_peaks_aft_histogram(self): + self.st.plot_peaks_aft_histogram(nt_test_run_id) + + @unittest.skipIf(not straxen.utilix_is_configured(), "No db access, cannot test CMT.") + def test_event_scatter(self): + self.st.event_scatter(nt_test_run_id) + + def test_event_scatter_diff_options(self): + self.st.event_scatter(nt_test_run_id, + color_range=(0, 10), + color_dim='s1_area') + + def test_energy_spectrum(self): + self.st.plot_energy_spectrum(nt_test_run_id) + + def test_energy_spectrum_diff_options(self): + """Run st.plot_energy_spectrum with several options""" + self.st.plot_energy_spectrum(nt_test_run_id, + unit='kg_day_kev', + exposure_kg_sec=1) + self.st.plot_energy_spectrum(nt_test_run_id, + unit='tonne_day_kev', + exposure_kg_sec=1) + self.st.plot_energy_spectrum(nt_test_run_id, + unit='tonne_year_kev', + exposure_kg_sec=1, + geomspace=False) + with self.assertRaises(ValueError): + # Some units shouldn't be allowed + self.st.plot_energy_spectrum(nt_test_run_id, + unit='not_allowed_unit', + exposure_kg_sec=1) + + def test_peak_classification(self): + self.st.plot_peak_classification(nt_test_run_id) + + def _st_attr_for_one_peak(self, function_name): + """ + Utility function to prevent having to copy past the code + below for all the functions we are going to test for one peak + """ + f = getattr(self.st, function_name) + f(nt_test_run_id, time_within=self.first_peak) + + @unittest.skipIf(is_py310(), 'holoviews incompatible with py3.10') + def test_waveform_display(self): + """test st.waveform_display for one peak""" + 
self._st_attr_for_one_peak('waveform_display') + + def test_hvdisp_plot_pmt_pattern(self): + """test st.hvdisp_plot_pmt_pattern for one peak""" + self._st_attr_for_one_peak('hvdisp_plot_pmt_pattern') + + def test_hvdisp_plot_peak_waveforms(self): + """test st.hvdisp_plot_peak_waveforms for one peak""" + self._st_attr_for_one_peak('hvdisp_plot_peak_waveforms') + + def test_plot_pulses_tpc(self): + """ + Test that we can plot some TPC pulses and fail if raise a + ValueError if an invalid combination of parameters is given + """ + self.st.plot_pulses_tpc(nt_test_run_id, + time_within=self.first_peak, + max_plots=2, + plot_hits=True, + ignore_time_warning=False, + store_pdf=True, + ) + with self.assertRaises(ValueError): + # Raise an error if no time range is specified + self.st.plot_pulses_tpc(nt_test_run_id, + max_plots=2, + plot_hits=True, + ignore_time_warning=True, + store_pdf=True, + ) + + def test_plot_pulses_mv(self): + """Repeat above for mv""" + self.st.plot_pulses_mv(nt_test_run_id, + max_plots=2, + plot_hits=True, + ignore_time_warning=True, + ) + + def test_plot_pulses_nv(self): + """Repeat above for nv""" + self.st.plot_pulses_nv(nt_test_run_id, + max_plots=2, + plot_hits=True, + ignore_time_warning=True, + ) + + @unittest.skipIf(not straxen.utilix_is_configured(), + "No db access, cannot test!") + def test_event_display(self): + """Event display plot, needs CMT""" + self.st.event_display(nt_test_run_id, time_within=self.first_event) + + @unittest.skipIf(not straxen.utilix_is_configured(), + "No db access, cannot test!") + def test_event_display_no_rr(self): + """Make an event display without including records""" + self.st.event_display(nt_test_run_id, + time_within=self.first_event, + records_matrix=False, + event_time_limit=[self.first_event['time'], + self.first_event['endtime']], + ) + + def test_calc_livetime(self): + """Use straxen.get_livetime_sec""" + try: + live_time = straxen.get_livetime_sec(self.st, nt_test_run_id) + except strax.RunMetadataNotAvailable: + things = self.st.get_array(nt_test_run_id, 'peaks') + live_time = straxen.get_livetime_sec(self.st, nt_test_run_id, things=things) + assertion_statement = "Live-time calculation is wrong" + expected = self._expected_test_results['run_live_time'] + self.assertTrue(live_time == expected, assertion_statement) + + def test_df_wiki(self): + """We have a nice utility to write dataframes to the wiki""" + df = self.st.get_df(nt_test_run_id, 'peak_basics')[:10] + straxen.dataframe_to_wiki(df) + + @unittest.skipIf(straxen.utilix_is_configured(), + "Test for no DB access") + def test_daq_plot_errors_without_utilix(self): + """ + We should get a not implemented error if we call a function + in the daq_waveforms analyses + """ + with self.assertRaises(NotImplementedError): + straxen.analyses.daq_waveforms._get_daq_config( + 'som_run', run_collection=None) + + @unittest.skipIf(not straxen.utilix_is_configured(), + "No db access, cannot test!") + def test_daq_plot_errors(self): + """To other ways we should not be allowed to call daq_waveforms.XX""" + with self.assertRaises(ValueError): + straxen.analyses.daq_waveforms._get_daq_config('no_run') + with self.assertRaises(ValueError): + straxen.analyses.daq_waveforms._board_to_host_link({'boards': [{'no_boards': 0}]}, 1) + + @unittest.skipIf(not straxen.utilix_is_configured(), "No db access, cannot test!") + def test_event_plot_errors(self): + """ + Several Exceptions should be raised with these following bad + ways of calling the event display + """ + with 
self.assertRaises(ValueError): + # Wrong way of calling records matrix + self.st.event_display(nt_test_run_id, + records_matrix='records_are_bad') + with self.assertRaises(ValueError): + # A single event should not have three entries + straxen.analyses.event_display._event_display(events=[1, 2, 3], + context=self.st, + to_pe=None, + run_id='1' + ) + with self.assertRaises(ValueError): + # Can't pass empty axes like this to the inner script + straxen.analyses.event_display._event_display(axes=None, + events=[None], + context=self.st, + to_pe=None, + run_id=nt_test_run_id, + ) + with self.assertRaises(ValueError): + # Should raise a valueError + straxen.analyses.event_display.plot_single_event(context=None, + run_id=None, + events=[1, 2, 3], + event_number=None, + ) + with self.assertRaises(ValueError): + # Give to many recs to this inner script + straxen.analyses.event_display._scatter_rec(_event=None, + recs=list(range(10))) + + @unittest.skipIf(is_py310(), 'holoviews incompatible with py3.10') + def test_interactive_display(self): + """Run and save interactive display""" + fig = self.st.event_display_interactive(nt_test_run_id, + time_within=self.first_event, + xenon1t=False, + plot_record_matrix=True, + ) + save_as = 'test_display.html' + fig.save(save_as) + self.assertTrue(os.path.exists(save_as)) + os.remove(save_as) + self.assertFalse(os.path.exists(save_as)) + st = self.st.new_context() + st.event_display_interactive(nt_test_run_id, + time_within=self.first_event, + xenon1t=False, + plot_record_matrix=False, + only_main_peaks=True, + ) + + def test_bokeh_selector(self): + """Test the bokeh data selector""" + from straxen.analyses.bokeh_waveform_plot import DataSelectionHist + p = self.st.get_array(nt_test_run_id, 'peak_basics') + ds = DataSelectionHist('ds') + fig = ds.histogram2d(p, + p['area'], + p['area'], + bins=50, + hist_range=((0, 200), (0, 2000)), + log_color_scale=True, + clim=(10, None), + undeflow_color='white') + + import bokeh.plotting as bklt + save_as = 'test_data_selector.html' + bklt.save(fig, save_as) + self.assertTrue(os.path.exists(save_as)) + os.remove(save_as) + self.assertFalse(os.path.exists(save_as)) + # Also test if we can write it to the wiki + straxen.bokeh_to_wiki(fig) + straxen.bokeh_to_wiki(fig, save_as) + self.assertTrue(os.path.exists(save_as)) + os.remove(save_as) + self.assertFalse(os.path.exists(save_as)) + + @unittest.skipIf(not straxen.utilix_is_configured(), + "No db access, cannot test!") + def test_nt_daq_plot(self): + """Make an nt DAQ plot""" + self.st.daq_plot(nt_test_run_id, + time_within=self.first_peak, + vmin=0.1, + vmax=1, + ) + + @unittest.skipIf(not straxen.utilix_is_configured(), + "No db access, cannot test!") + def test_nt_daq_plot_grouped(self): + """Same as above grouped by ADC""" + self.st.plot_records_matrix(nt_test_run_id, + time_within=self.first_peak, + vmin=0.1, + vmax=1, + group_by='ADC ID', + ) + + def test_records_matrix_downsample(self): + """Test that downsampling works in the record matrix""" + self.st.records_matrix(nt_test_run_id, + time_within=self.first_event, + max_samples=20 + ) + + @unittest.skipIf(not straxen.utilix_is_configured(), + "No db access, cannot test!") + def test_load_corrected_positions(self): + """Test that we can do st.load_corrected_positions""" + self.st.load_corrected_positions(nt_test_run_id, + time_within=self.first_peak) + + @unittest.skipIf(not straxen.utilix_is_configured(), + "No db access, cannot test!") + def test_nv_event_display(self): + """ + Test NV event display for a single 
event. + """ + events_nv = self.st.get_array(nt_test_run_id, 'events_nv') + warning = ("Do 'rm ./strax_test_data' since your *_nv test " + "data in that folder is messing up this test.") + self.assertTrue(len(events_nv), warning) + self.st.make(nt_test_run_id, 'event_positions_nv') + self.st.plot_nveto_event_display(nt_test_run_id, + time_within=events_nv[0], + ) + with self.assertRaises(ValueError): + # If there is no data, we should raise a ValueError + self.st.plot_nveto_event_display(nt_test_run_id, + time_range=[-1000,-900], + ) + + +def test_plots(): + """Make some plots""" + c = np.ones(straxen.n_tpc_pmts) + straxen.plot_pmts(c) + straxen.plot_pmts(c, log_scale=True) diff --git a/tests/test_misc.py b/tests/test_misc.py index c217f96cf..08123ce21 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -1,4 +1,6 @@ -from straxen.misc import TimeWidgets +from straxen.misc import TimeWidgets, print_versions +import straxen +import unittest def test_widgets(): @@ -14,10 +16,12 @@ def test_widgets(): start_utc, end_utc = tw.get_start_end() h_in_ns_unix = 60*60*10**9 - unix_conversion_worked = start_utc - start == h_in_ns_unix or start_utc - start == 2 * h_in_ns_unix - assert unix_conversion_worked - unix_conversion_worked = start_utc - end == h_in_ns_unix or start_utc - end == 2 * h_in_ns_unix - assert unix_conversion_worked + assert (start_utc - start == h_in_ns_unix + or start_utc - start == 2 * h_in_ns_unix + or start_utc - start == 0 * h_in_ns_unix) + assert (start_utc - end == h_in_ns_unix + or start_utc - end == 2 * h_in_ns_unix + or start_utc - end == 0 * h_in_ns_unix) def test_change_in_fields(): @@ -41,3 +45,24 @@ def test_change_in_fields(): start00, _ = tw.get_start_end() assert start20 - start00 == minutes, 'Time field did not update its value!' + + +def test_print_versions(modules=('numpy', 'straxen', 'non_existing_module')): + for return_string in [True, False]: + for include_git in [True, False]: + res = print_versions(modules, + return_string=return_string, + include_git=include_git) + if return_string: + assert res is not None + + +class HitAmplitude(unittest.TestCase): + def test_non_existing(self): + with self.assertRaises(ValueError): + straxen.hit_min_amplitude('non existing key') + + @staticmethod + def test_get_hit_amplitude(): + straxen.hit_min_amplitude('pmt_commissioning_initial') + straxen.hit_min_amplitude('pmt_commissioning_initial_he') diff --git a/tests/test_mongo_downloader.py b/tests/test_mongo_downloader.py new file mode 100644 index 000000000..b4a8733e9 --- /dev/null +++ b/tests/test_mongo_downloader.py @@ -0,0 +1,121 @@ +import unittest +import straxen +import os +import pymongo + + +def mongo_uri_not_set(): + return 'TEST_MONGO_URI' not in os.environ + + +@unittest.skipIf(mongo_uri_not_set(), "No access to test database") +class TestMongoDownloader(unittest.TestCase): + """ + Test the saving behavior of the context with the mogno downloader + + Requires write access to some pymongo server, the URI of witch is to be set + as an environment variable under: + + TEST_MONGO_URI + + At the moment this is just an empty database but you can also use some free + ATLAS mongo server. 
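+
+    For a throw-away local database the variable could for instance be set
+    as (hypothetical example, any reachable MongoDB instance should do):
+        export TEST_MONGO_URI="mongodb://localhost:27017/"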
+ """ + _run_test = True + + def setUp(self): + # Just to make sure we are running some mongo server, see test-class docstring + if 'TEST_MONGO_URI' not in os.environ: + self._run_test = False + return + uri = os.environ.get('TEST_MONGO_URI') + db_name = 'test_rundb' + collection_name = 'fs.files' + client = pymongo.MongoClient(uri) + database = client[db_name] + collection = database[collection_name] + self.downloader = straxen.MongoDownloader(collection=collection, + readonly=True, + file_database=None, + _test_on_init=False, + ) + self.uploader = straxen.MongoUploader(collection=collection, + readonly=False, + file_database=None, + _test_on_init=False, + ) + self.collection = collection + + def tearDown(self): + self.collection.drop() + + def test_up_and_download(self): + with self.assertRaises(ConnectionError): + # Should be empty! + self.downloader.test_find() + file_name = 'test.txt' + self.assertFalse(self.downloader.md5_stored(file_name)) + self.assertEqual(self.downloader.compute_md5(file_name), '') + file_content = 'This is a test' + with open(file_name, 'w') as f: + f.write(file_content) + self.assertTrue(os.path.exists(file_name)) + self.uploader.upload_from_dict({file_name: os.path.abspath(file_name)}) + self.assertTrue(self.uploader.md5_stored(file_name)) + self.assertTrue(self.downloader.config_exists(file_name)) + path = self.downloader.download_single(file_name) + path_hr = self.downloader.download_single(file_name, human_readable_file_name=True) + abs_path = self.downloader.get_abs_path(file_name) + + for p in [path, path_hr, abs_path]: + self.assertTrue(os.path.exists(p)) + read_file = straxen.get_resource(path) + self.assertTrue(file_content == read_file) + os.remove(file_name) + self.assertFalse(os.path.exists(file_name)) + self.downloader.test_find() + self.downloader.download_all() + # Now the test on init should work, let's double try + straxen.MongoDownloader(collection=self.collection, + file_database=None, + _test_on_init=True, + ) + + def test_invalid_methods(self): + """ + The following examples should NOT work, let's make sure the + right errors are raised + """ + with self.assertRaises(ValueError): + straxen.MongoDownloader(collection=self.collection, + file_database='NOT NONE', + ) + with self.assertRaises(ValueError): + straxen.MongoDownloader(collection='invalid type', + ) + with self.assertRaises(PermissionError): + straxen.MongoUploader(readonly=True) + + with self.assertRaises(ValueError): + self.uploader.upload_from_dict("A string is not a dict") + + with self.assertRaises(straxen.mongo_storage.CouldNotLoadError): + self.uploader.upload_single('no_such_file', 'no_such_file') + + with self.assertWarns(UserWarning): + self.uploader.upload_from_dict({'something': 'no_such_file'}) + + with self.assertRaises(ValueError): + straxen.MongoDownloader(collection=self.collection, + file_database=None, + _test_on_init=False, + store_files_at=False, + ) + with self.assertRaises(ValueError): + self.downloader.download_single('no_existing_file') + + with self.assertRaises(ValueError): + self.downloader._check_store_files_at('some_str') + + with self.assertRaises(PermissionError): + self.downloader._check_store_files_at([]) diff --git a/tests/test_mongo_interactions.py b/tests/test_mongo_interactions.py index e547dc558..47676249a 100644 --- a/tests/test_mongo_interactions.py +++ b/tests/test_mongo_interactions.py @@ -1,110 +1,110 @@ -""" -Test certain interactions with the runsdatabase. -NB! this only works if one has access to the database. This does not -work e.g. 
on travis jobs and therefore the tests failing locally will -not show up in Pull Requests. -""" - -import straxen -import os -from warnings import warn -from .test_plugins import test_run_id_nT - - -def test_select_runs(check_n_runs=2): - """ - Test (if we have a connection) if we can perform strax.select_runs - on the last two runs in the runs collection - - :param check_n_runs: int, the number of runs we want to check - """ - - if not straxen.utilix_is_configured(): - warn('Makes no sense to test the select runs because we do not ' - 'have access to the database.') - return - assert check_n_runs >= 1 - st = straxen.contexts.xenonnt_online(use_rucio=False) - run_col = st.storage[0].collection - - # Find the latest run in the runs collection - last_run = run_col.find_one(projection={'number': 1}, - sort=[('number', -1)] - ).get('number') - - # Set this number as the minimum run number. This limits the - # amount of documents checked and therefore keeps the test short. - st.storage[0].minimum_run_number = int(last_run) - (check_n_runs - 1) - st.select_runs() - - -def test_downloader(): - """Test if we can download a small file from the downloader""" - if not straxen.utilix_is_configured(): - warn('Cannot download because utilix is not configured') - return - - downloader = straxen.MongoDownloader() - path = downloader.download_single('to_pe_nt.npy') - assert os.path.exists(path) - - -def _patch_om_init(take_only): - """ - temp patch since om = straxen.OnlineMonitor() does not work with utilix - """ - header = 'RunDB' - user = straxen.uconfig.get(header, 'pymongo_user') - pwd = straxen.uconfig.get(header, 'pymongo_password') - url = straxen.uconfig.get(header, 'pymongo_url').split(',')[-1] - uri = f"mongodb://{user}:{pwd}@{url}" - return straxen.OnlineMonitor(uri=uri, take_only=take_only) - - -def test_online_monitor(target='online_peak_monitor', max_tries=3): - """ - See if we can get some data from the online monitor before max_tries - - :param target: target to test getting from the online monitor - :param max_tries: number of queries max allowed to get a non-failing - run - """ - if not straxen.utilix_is_configured(): - warn('Cannot test online monitor because utilix is not configured') - return - st = straxen.contexts.xenonnt_online(use_rucio=False) - om = _patch_om_init(target) - st.storage = [om] - max_run = None - for i in range(max_tries): - query = {'provides_meta': True, 'data_type': target} - if max_run is not None: - # One run failed before, lets try a more recent one. - query.update({'number': {"$gt": int(max_run)}}) - some_run = om.db[om.col_name].find_one(query, - projection={'number': 1, - 'metadata': 1, - 'lineage_hash': 1, - }) - if some_run.get('lineage_hash', False): - if some_run['lineage_hash'] != st.key_for("0", target).lineage_hash: - # We are doing a new release, therefore there is no - # matching data. This makes sense. - return - if some_run is None or some_run.get('number', None) is None: - print(f'Found None') - continue - elif 'exception' in some_run.get('metadata', {}): - # Did find a run, but it is bad, we need to find another one - print(f'Found {some_run.get("number", "No number")} with errors') - max_run = some_run.get("number", -1) - continue - else: - # Correctly written - run_id = f'{some_run["number"]:06}' - break - else: - raise FileNotFoundError(f'No non-failing {target} found in the online ' - f'monitor after {max_tries}. 
Looked for:\n' - f'{st.key_for("0", target)}') - st.get_array(run_id, target, seconds_range=(0, 1), allow_incomplete=True) +""" +Test certain interactions with the runsdatabase. +NB! this only works if one has access to the database. This does not +work e.g. on travis jobs and therefore the tests failing locally will +not show up in Pull Requests. +""" +import straxen +import os +import unittest +from pymongo import ReadPreference +import warnings + + +@unittest.skipIf(not straxen.utilix_is_configured(), "No db access, cannot test!") +class TestSelectRuns(unittest.TestCase): + def test_select_runs(self, check_n_runs=2): + """ + Test (if we have a connection) if we can perform strax.select_runs + on the last two runs in the runs collection + + :param check_n_runs: int, the number of runs we want to check + """ + self.assertTrue(check_n_runs >= 1) + st = straxen.contexts.xenonnt_online(use_rucio=False) + run_col = st.storage[0].collection + + # Find the latest run in the runs collection + last_run = run_col.find_one(projection={'number': 1}, + sort=[('number', -1)] + ).get('number') + + # Set this number as the minimum run number. This limits the + # amount of documents checked and therefore keeps the test short. + st.storage[0].minimum_run_number = int(last_run) - (check_n_runs - 1) + st.select_runs() + + +@unittest.skipIf(not straxen.utilix_is_configured(), + "Cannot download because utilix is not configured") +class TestDownloader(unittest.TestCase): + def test_downloader(self): + """Test if we can download a small file from the downloader""" + downloader = straxen.MongoDownloader() + path = downloader.download_single('to_pe_nt.npy') + self.assertTrue(os.path.exists(path)) + + +def _patch_om_init(take_only): + """ + temp patch since om = straxen.OnlineMonitor() does not work with utilix + """ + header = 'RunDB' + user = straxen.uconfig.get(header, 'pymongo_user') + pwd = straxen.uconfig.get(header, 'pymongo_password') + url = straxen.uconfig.get(header, 'pymongo_url').split(',')[-1] + uri = f"mongodb://{user}:{pwd}@{url}" + return straxen.OnlineMonitor(uri=uri, take_only=take_only) + + +@unittest.skipIf(not straxen.utilix_is_configured(), "No db access, cannot test!") +def test_online_monitor(target='online_peak_monitor', max_tries=3): + """ + See if we can get some data from the online monitor before max_tries + + :param target: target to test getting from the online monitor + :param max_tries: number of queries max allowed to get a non-failing + run + """ + st = straxen.contexts.xenonnt_online(use_rucio=False) + straxen.get_mongo_uri() + om = _patch_om_init(target) + st.storage = [om] + max_run = None + for i in range(max_tries): + query = {'provides_meta': True, 'data_type': target} + if max_run is not None: + # One run failed before, lets try a more recent one. + query.update({'number': {"$gt": int(max_run)}}) + collection = om.db[om.col_name].with_options( + read_preference=ReadPreference.SECONDARY_PREFERRED) + some_run = collection.find_one(query, + projection={'number': 1, + 'metadata': 1, + 'lineage_hash': 1, + }) + if some_run is not None and some_run.get('lineage_hash', False): + if some_run['lineage_hash'] != st.key_for("0", target).lineage_hash: + # We are doing a new release, therefore there is no + # matching data. This makes sense. 
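+                # (Returning here keeps the test from failing right after
+                # a lineage change, before new data has been written.)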
+ return + if some_run is None or some_run.get('number', None) is None: + print(f'Found None') + continue + elif 'exception' in some_run.get('metadata', {}): + # Did find a run, but it is bad, we need to find another one + print(f'Found {some_run.get("number", "No number")} with errors') + max_run = some_run.get("number", -1) + continue + else: + # Correctly written + run_id = f'{some_run["number"]:06}' + break + else: + if collection.find_one() is not None: + raise FileNotFoundError(f'No non-failing {target} found in the online ' + f'monitor after {max_tries}. Looked for:\n' + f'{st.key_for("0", target)}') + warnings.warn(f'Did not find any data in {om.col_name}!') + return + st.get_array(run_id, target, seconds_range=(0, 1), allow_incomplete=True) diff --git a/tests/test_nveto_recorder.py b/tests/test_nveto_recorder.py index 162aeedc4..4c4d86d44 100644 --- a/tests/test_nveto_recorder.py +++ b/tests/test_nveto_recorder.py @@ -6,7 +6,6 @@ class TestMergeIntervals(unittest.TestCase): - def setUp(self): self.intervals = np.zeros(4, dtype=strax.time_fields) self.intervals['time'] = [2, 3, 7, 20] @@ -34,7 +33,6 @@ def test_merge_overlapping_intervals(self): class TestCoincidence(unittest.TestCase): - def setUp(self): self.intervals = np.zeros(8, dtype=strax.time_fields) self.intervals['time'] = [3, 6, 9, 12, 15, 18, 21, 38] @@ -134,3 +132,12 @@ def _test_coincidence(self, resolving_time, coincidence, pre_trigger, endtime_is_correct = np.all(coincidence['endtime'] == endtime_truth) print(coincidence['endtime'], endtime_truth) assert endtime_is_correct, 'Coincidence does not have the correct endtime' + + +def test_nv_for_dummy_rr(): + """Basic test to run the nv rr for dummy raw-records""" + st = straxen.test_utils.nt_test_context(deregister=()) + st.context_config['forbid_creation_of'] = tuple() + st.register(straxen.test_utils.DummyRawRecords) + st.make(straxen.test_utils.nt_test_run_id, 'hitlets_nv') + st.make(straxen.test_utils.nt_test_run_id, 'events_tagged') diff --git a/tests/test_peaklet_processing.py b/tests/test_peaklet_processing.py index f21a4fdbc..d14b2c35e 100644 --- a/tests/test_peaklet_processing.py +++ b/tests/test_peaklet_processing.py @@ -3,7 +3,9 @@ import hypothesis.strategies as strat import strax +from strax.testutils import fake_hits import straxen +from straxen.plugins.peaklet_processing import get_tight_coin @settings(deadline=None) @@ -35,13 +37,45 @@ def test_n_hits(): records['length'] = 5 records['pulse_length'] = 5 records['dt'] = 1 - records['channel'] = [0, 1] + records['channel'] = [0, 1] records['data'][0, :5] = [0, 1, 1, 0, 1] records['data'][1, :5] = [0, 1, 0, 0, 0] - + st = straxen.contexts.xenonnt_online() st.set_config({'hit_min_amplitude': 1}) p = st.get_single_plugin('0', 'peaklets') res = p.compute(records, 0, 999) peaklets = res['peaklets'] assert peaklets['n_hits'] == 3, f"Peaklet has the wrong number of hits!" 
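The hard-coded expectation of three hits in test_n_hits above follows directly from the fake waveforms: channel 0 contains the samples [0, 1, 1, 0, 1], i.e. two hits separated by a one-sample gap, while channel 1 contains [0, 1, 0, 0, 0], i.e. a single hit. A minimal sketch of the same counting, assuming strax.record_dtype() with its default record length and a hit threshold of 1 ADC count, could look like this:

import numpy as np
import strax

# Two dummy records of 5 samples each, mirroring the records built in test_n_hits.
records = np.zeros(2, dtype=strax.record_dtype())
records['length'] = 5
records['pulse_length'] = 5
records['dt'] = 1
records['channel'] = [0, 1]
records['data'][0, :5] = [0, 1, 1, 0, 1]
records['data'][1, :5] = [0, 1, 0, 0, 0]

hits = strax.find_hits(records, min_amplitude=1)
assert len(hits) == 3  # two hits in channel 0, one hit in channel 1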
+ + +@given(fake_hits, + strat.lists(elements=strat.integers(0, 9), min_size=20)) +@settings(deadline=None) +def test_tight_coincidence(hits, channel): + hits['area'] = 1 + hits['channel'] = channel[:len(hits)] # In case there are less channel then hits (unlikely) + gap_threshold = 10 + peaks = strax.find_peaks(hits, + adc_to_pe=np.ones(10), + right_extension=0, left_extension=0, + gap_threshold=gap_threshold, + min_channels=1, + min_area=0) + + peaks_max_time = peaks['time'] + peaks['length']//2 + hits_max_time = hits['time'] + hits['length']//2 + + left = 5 + right = 5 + tight_coin_channel = get_tight_coin(hits_max_time, + hits['channel'], + peaks_max_time, + left, + right, + ) + for ind, p_max_t in enumerate(peaks_max_time): + m_hits_in_peak = (hits_max_time >= (p_max_t - left)) + m_hits_in_peak &= (hits_max_time <= (p_max_t + right)) + n_channel = len(np.unique(hits[m_hits_in_peak]['channel'])) + assert n_channel == tight_coin_channel[ind], f'Wrong number of tight channel got {tight_coin_channel[ind]}, but expectd {n_channel}' # noqa diff --git a/tests/test_peaks.py b/tests/test_peaks.py index 1e412e0a4..c08bfc5ae 100644 --- a/tests/test_peaks.py +++ b/tests/test_peaks.py @@ -10,13 +10,13 @@ TEST_DATA_LENGTH = 3 R_TOL_DEFAULT = 1e-5 + def _not_close_to_0_or_1(x, rtol=R_TOL_DEFAULT): return not (np.isclose(x, 1, rtol=rtol) or np.isclose(x, 0, rtol=rtol)) class TestComputePeakBasics(unittest.TestCase): """Tests for peak basics plugin""" - def setUp(self, context=straxen.contexts.demo): self.st = context() self.n_top = self.st.config.get('n_top_pmts', 2) @@ -84,15 +84,15 @@ def create_unique_intervals(size, time_range=(0, 40), allow_zero_length=True): :param allow_zero_length: If true allow zero length intervals. """ strat = strategies.lists(elements=strategies.integers(*time_range), - min_size=size*2, - max_size=size*2 + min_size=size * 2, + max_size=size * 2 ).map(lambda x: _convert_to_interval(x, allow_zero_length)) return strat def _convert_to_interval(time_stamps, allow_zero_length): time_stamps = np.sort(time_stamps) - intervals = np.zeros(len(time_stamps)//2, strax.time_dt_fields) + intervals = np.zeros(len(time_stamps) // 2, strax.time_dt_fields) intervals['dt'] = 1 intervals['time'] = time_stamps[::2] intervals['length'] = time_stamps[1::2] - time_stamps[::2] diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 0284a4bc9..5c5685283 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -1,135 +1,15 @@ import tempfile +import unittest import strax -import numpy as np -from immutabledict import immutabledict import straxen -from straxen.common import pax_file, aux_repo -## -# Tools -## -# Let's make a dummy map for NVeto -nveto_pmt_dummy_df = {'channel': list(range(2000, 2120)), - 'x': list(range(120)), - 'y': list(range(120)), - 'z': list(range(120))} - -# Some configs are better obtained from the strax_auxiliary_files repo. -# Let's use small files, we don't want to spend a lot of time downloading -# some file. 
-testing_config_nT = dict( - nn_architecture= - aux_repo + 'f0df03e1f45b5bdd9be364c5caefdaf3c74e044e/fax_files/mlp_model.json', - nn_weights= - aux_repo + 'f0df03e1f45b5bdd9be364c5caefdaf3c74e044e/fax_files/mlp_model.h5', - gain_model=("to_pe_placeholder", True), - s2_xy_correction_map=pax_file('XENON1T_s2_xy_ly_SR0_24Feb2017.json'), - elife_conf=('elife_constant', 1e6), - baseline_samples_nv=10, - fdc_map=pax_file('XENON1T_FDC_SR0_data_driven_3d_correction_tf_nn_v0.json.gz'), - gain_model_nv=("adc_nv", True), - gain_model_mv=("adc_mv", True), - nveto_pmt_position_map=nveto_pmt_dummy_df, - s1_xyz_correction_map=pax_file('XENON1T_s1_xyz_lce_true_kr83m_SR0_pax-680_fdc-3d_v0.json'), - electron_drift_velocity=("electron_drift_velocity_constant", 1e-4), - s1_aft_map=aux_repo + 'ffdadba3439ae7922b19f5dd6479348b253c09b0/strax_files/s1_aft_UNITY_xyz_XENONnT.json', - s2_optical_map=aux_repo + '8a6f0c1a4da4f50546918cd15604f505d971a724/strax_files/s2_map_UNITY_xy_XENONnT.json', - s1_optical_map=aux_repo + '8a6f0c1a4da4f50546918cd15604f505d971a724/strax_files/s1_lce_UNITY_xyz_XENONnT.json', - electron_drift_time_gate=("electron_drift_time_gate_constant", 2700), - hit_min_amplitude='pmt_commissioning_initial', - hit_min_amplitude_nv=20, - hit_min_amplitude_mv=80, - hit_min_amplitude_he='pmt_commissioning_initial_he' -) - -testing_config_1T = dict( - hev_gain_model=('1T_to_pe_placeholder', False), - gain_model=('1T_to_pe_placeholder', False), - elife_conf=('elife_constant', 1e6), - electron_drift_velocity=("electron_drift_velocity_constant", 1e-4), - electron_drift_time_gate=("electron_drift_time_gate_constant", 1700), -) - -test_run_id_nT = '008900' -test_run_id_1T = '180423_1021' - - -@strax.takes_config( - strax.Option('secret_time_offset', default=0, track=False), - strax.Option('recs_per_chunk', default=10, track=False), - strax.Option('n_chunks', default=2, track=False, - help='Number of chunks for the dummy raw records we are writing here'), - strax.Option('channel_map', track=False, type=immutabledict, - help="frozendict mapping subdetector to (min, max) " - "channel number.") -) -class DummyRawRecords(strax.Plugin): - """ - Provide dummy raw records for the mayor raw_record types - """ - provides = ('raw_records', - 'raw_records_he', - 'raw_records_nv', - 'raw_records_aqmon', - 'raw_records_aux_mv', - 'raw_records_mv' - ) - parallel = 'process' - depends_on = tuple() - data_kind = immutabledict(zip(provides, provides)) - rechunk_on_save = False - dtype = {p: strax.raw_record_dtype() for p in provides} - - def setup(self): - self.channel_map_keys = {'he': 'he', - 'nv': 'nveto', - 'aqmon': 'aqmon', - 'aux_mv': 'aux_mv', - 's_mv': 'mv', - } # s_mv otherwise same as aux in endswith - - def source_finished(self): - return True - - def is_ready(self, chunk_i): - return chunk_i < self.config['n_chunks'] - - def compute(self, chunk_i): - t0 = chunk_i + self.config['secret_time_offset'] - if chunk_i < self.config['n_chunks'] - 1: - # One filled chunk - r = np.zeros(self.config['recs_per_chunk'], self.dtype['raw_records']) - r['time'] = t0 - r['length'] = r['dt'] = 1 - r['channel'] = np.arange(len(r)) - else: - # One empty chunk - r = np.zeros(0, self.dtype['raw_records']) - - res = {} - for p in self.provides: - rr = np.copy(r) - # Add detector specific channel offset: - for key, channel_key in self.channel_map_keys.items(): - if channel_key not in self.config['channel_map']: - # Channel map for 1T is different. 
- continue - if p.endswith(key): - s, e = self.config['channel_map'][channel_key] - rr['channel'] += s - res[p] = self.chunk(start=t0, end=t0 + 1, data=rr, data_type=p) - return res - - -# Don't concern ourselves with rr_aqmon et cetera -forbidden_plugins = tuple([p for p in - straxen.daqreader.DAQReader.provides - if p not in DummyRawRecords.provides]) +from straxen.test_utils import nt_test_run_id, DummyRawRecords, testing_config_1T, test_run_id_1T def _run_plugins(st, make_all=False, - run_id=test_run_id_nT, - **proces_kwargs): + run_id=nt_test_run_id, + from_scratch=False, + **process_kwargs): """ Try all plugins (except the DAQReader) for a given context (st) to see if we can really push some (empty) data from it and don't have any nasty @@ -137,16 +17,23 @@ def _run_plugins(st, """ with tempfile.TemporaryDirectory() as temp_dir: - st.storage = [strax.DataDirectory(temp_dir)] - - # As we use a temporary directory we should have a clean start - assert not st.is_stored(run_id, 'raw_records'), 'have RR???' + if from_scratch: + st.storage = [strax.DataDirectory(temp_dir)] + # As we use a temporary directory we should have a clean start + assert not st.is_stored(run_id, 'raw_records'), 'have RR???' + + # Don't concern ourselves with rr_aqmon et cetera + _forbidden_plugins = tuple([p for p in + straxen.daqreader.DAQReader.provides + if p not in + st._plugin_class_registry['raw_records'].provides]) + st.set_context_config({'forbid_creation_of': _forbidden_plugins}) # Create event info target = 'event_info' st.make(run_id=run_id, targets=target, - **proces_kwargs) + **process_kwargs) # The stuff should be there assert st.is_stored(run_id, target), f'Could not make {target}' @@ -155,11 +42,13 @@ def _run_plugins(st, return end_targets = set(st._get_end_targets(st._plugin_class_registry)) - for p in end_targets-set(forbidden_plugins): + for p in end_targets - set(_forbidden_plugins): + if 'raw' in p: + continue st.make(run_id, p) # Now make sure we can get some data for all plugins all_datatypes = set(st._plugin_class_registry.keys()) - for p in all_datatypes-set(forbidden_plugins): + for p in all_datatypes - set(_forbidden_plugins): should_be_stored = (st._plugin_class_registry[p].save_when == strax.SaveWhen.ALWAYS) if should_be_stored: @@ -168,23 +57,12 @@ def _run_plugins(st, print("Wonderful all plugins work (= at least they don't fail), bye bye") -def _update_context(st, max_workers, fallback_gains=None, nt=True): - # Change config to allow for testing both multiprocessing and lazy mode - st.set_context_config({'forbid_creation_of': forbidden_plugins}) +def _update_context(st, max_workers, nt=True): # Ignore strax-internal warnings st.set_context_config({'free_options': tuple(st.config.keys())}) - st.register(DummyRawRecords) - if nt and not straxen.utilix_is_configured(): - st.set_config(testing_config_nT) - del st._plugin_class_registry['peak_positions_mlp'] - del st._plugin_class_registry['peak_positions_cnn'] - del st._plugin_class_registry['peak_positions_gcn'] - st.register(straxen.PeakPositions1T) - print(f"Using {st._plugin_class_registry['peak_positions']} for posrec tests") - st.set_config({'gain_model': fallback_gains}) - - elif not nt: - if straxen.utilix_is_configured(): + if not nt: + st.register(DummyRawRecords) + if straxen.utilix_is_configured(warning_message=False): # Set some placeholder gain as this takes too long for 1T to load from CMT st.set_config({k: v for k, v in testing_config_1T.items() if k in ('hev_gain_model', 'gain_model')}) @@ -241,11 +119,6 @@ def 
_test_child_options(st, run_id): f'"{option_name}"!') -## -# Tests -## - - def test_1T(ncores=1): if ncores == 1: print('-- 1T lazy mode --') @@ -259,7 +132,7 @@ def test_1T(ncores=1): _plugin_class.save_when = strax.SaveWhen.ALWAYS # Run the test - _run_plugins(st, make_all=True, max_wokers=ncores, run_id=test_run_id_1T) + _run_plugins(st, make_all=True, max_workers=ncores, run_id=test_run_id_1T, from_scratch=True) # Test issue #233 st.search_field('cs1') @@ -274,25 +147,21 @@ def test_1T(ncores=1): def test_nT(ncores=1): if ncores == 1: print('-- nT lazy mode --') - st = straxen.contexts.xenonnt_online(_database_init=straxen.utilix_is_configured(), - use_rucio=False) - offline_gain_model = ("to_pe_placeholder", True) - _update_context(st, ncores, fallback_gains=offline_gain_model, nt=True) + init_database = straxen.utilix_is_configured(warning_message=False) + st = straxen.test_utils.nt_test_context( + _database_init=init_database, + use_rucio=False, + ) + _update_context(st, ncores, nt=True) # Lets take an abandoned run where we actually have gains for in the CMT - _run_plugins(st, make_all=True, max_wokers=ncores, run_id=test_run_id_nT) + _run_plugins(st, make_all=True, max_workers=ncores, run_id=nt_test_run_id) # Test issue #233 st.search_field('cs1') # Test of child plugins: - _test_child_options(st, test_run_id_nT) + _test_child_options(st, nt_test_run_id) print(st.context_config) def test_nT_mutlticore(): print('nT multicore') - test_nT(2) - -# Disable the test below as it saves some time in travis and gives limited new -# information as most development is on nT-plugins. -# def test_1T_mutlticore(): -# print('1T multicore') -# test_1T(2) + test_nT(3) diff --git a/tests/test_posrec_plugins.py b/tests/test_posrec_plugins.py new file mode 100644 index 000000000..fd04d531d --- /dev/null +++ b/tests/test_posrec_plugins.py @@ -0,0 +1,73 @@ +import os + +import strax +import straxen +import unittest +import numpy as np + + +@unittest.skipIf(not straxen.utilix_is_configured(), "No db access, cannot test!") +class TestPosRecAlgorithms(unittest.TestCase): + """ + Test several options for our posrec plugins + """ + @classmethod + def setUpClass(cls) -> None: + cls.target = 'peak_positions_mlp' + cls.config_name = 'tf_model_mlp' + cls.field = 'x_mlp' + cls.run_id = straxen.test_utils.nt_test_run_id + cls.st = straxen.test_utils.nt_test_context() + cls.st.make(cls.run_id, cls.target) + + def test_set_path(self): + """Test that we can reconstruct even if we set a hardcoded path""" + # Manually do a similar thing as the URL config does behind the + # scenes + + # Get current config + plugin = self.st.get_single_plugin(self.run_id, self.target) + cmt_config = plugin.config[self.config_name] + cmt_config_without_tf = cmt_config.replace('tf://', '') + + # Hack URLConfigs to give back intermediate results (this should be easier..) 
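+        # Stripping the leading 'tf://' makes the URL config stop short of
+        # building the TensorFlow model: as used here it then resolves to the
+        # downloaded file path, which is fed back in below as a hard-coded
+        # 'tf://<path>' option and should give identical reconstruction.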
+ st_fixed_path = self.st.new_context() + st_fixed_path.set_config({self.config_name: cmt_config_without_tf}) + plugin_fixed = st_fixed_path.get_single_plugin(self.run_id, self.target) + file_name = getattr(plugin_fixed, self.config_name) + self.assertTrue(os.path.exists(file_name)) + + # Now let's see if we can get the same results with both contexts + set_to_config = f'tf://{file_name}' + print(f'Setting option to {set_to_config}') + st_fixed_path.set_config({self.config_name: set_to_config}) + default_result = self.st.get_array(self.run_id, self.target)[self.field] + alt_result = st_fixed_path.get_array(self.run_id, self.target)[self.field] + self.assertTrue(np.all(np.isclose(default_result, alt_result))) + + def test_set_to_none(self): + """Test that we can set the config to None, giving only nan results""" + st_with_none = self.st.new_context() + st_with_none.set_config({self.config_name: None}) + alt_result = st_with_none.get_array(self.run_id, self.target) + self.assertTrue(np.all(np.isnan(alt_result[self.field]))) + + def test_bad_configs_raising_errors(self): + """Test that we get the right errors when we set invalid options""" + dummy_st = self.st.new_context() + dummy_st.set_config({self.config_name: 'some_path_without_tf_protocol'}) + + plugin = dummy_st.get_single_plugin(self.run_id, self.target) + with self.assertRaises(ValueError): + plugin.get_tf_model() + + dummy_st.set_config({self.config_name: 'tf://some_path_that_does_not_exists'}) + + plugin = dummy_st.get_single_plugin(self.run_id, self.target) + with self.assertRaises(FileNotFoundError): + plugin.get_tf_model() + + dummy_st.register(straxen.position_reconstruction.PeakPositionsBaseNT) + plugin_name = strax.camel_to_snake('PeakPositionsBaseNT') + with self.assertRaises(NotImplementedError): + dummy_st.get_single_plugin(self.run_id, plugin_name) diff --git a/tests/test_rucio.py b/tests/test_rucio.py index 4af401ee7..b52e84eb6 100644 --- a/tests/test_rucio.py +++ b/tests/test_rucio.py @@ -1,81 +1,78 @@ -import straxen -import unittest -import strax -import socket - - -class TestBasics(unittest.TestCase): - @classmethod - def setUpClass(cls) -> None: - """ - For testing purposes, slightly alter the RucioFrontend such that - we can run tests outside of dali too - """ - if not straxen.utilix_is_configured(): - return - if 'rcc' not in socket.getfqdn(): - # If we are not on RCC, for testing, add some dummy site - straxen.RucioFrontend.local_rses = {'UC_DALI_USERDISK': r'.rcc.', - 'test_rucio': f'{socket.getfqdn()}'} - straxen.RucioFrontend.get_rse_prefix = lambda *x: 'test_rucio' - - # Some non-existing keys that we will try finding in the test cases. 
- cls.test_keys = [ - strax.DataKey(run_id=run_id, - data_type='dtype', - lineage={'dtype': ['Plugin', '0.0.0.', {}],} - ) - for run_id in ('-1', '-2') - ] - - def test_load_context_defaults(self): - """Don't fail immediately if we start a context due to Rucio""" - if not straxen.utilix_is_configured(): - return - st = straxen.contexts.xenonnt_online(_minimum_run_number=10_000, - _maximum_run_number=10_010, - ) - st.select_runs() - - def test_find_local(self): - """Make sure that we don't find the non existing data""" - if not straxen.utilix_is_configured(): - return - rucio = straxen.RucioFrontend( - include_remote=False, - ) - self.assertRaises(strax.DataNotAvailable, - rucio.find, - self.test_keys[0] - ) - - def test_find_several_local(self): - """Let's try finding some keys (won't be available)""" - if not straxen.utilix_is_configured(): - return - rucio = straxen.RucioFrontend( - include_remote=False, - ) - print(rucio) - found = rucio.find_several(self.test_keys) - # We shouldn't find any of these - assert found == [False for _ in self.test_keys] - - def test_find_several_remote(self): - """ - Let's try running a find_several with the include remote. - This should fail but when no rucio is installed or else it - shouldn't find any data. - """ - if not straxen.utilix_is_configured(): - return - try: - rucio = straxen.RucioFrontend( - include_remote=True, - ) - except ImportError: - pass - else: - found = rucio.find_several(self.test_keys) - # We shouldn't find any of these - assert found == [False for _ in self.test_keys] +import straxen +import unittest +import strax +import socket + + +class TestBasics(unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + """ + For testing purposes, slightly alter the RucioFrontend such that + we can run tests outside of dali too + """ + if not straxen.utilix_is_configured(): + return + if 'rcc' not in socket.getfqdn(): + # If we are not on RCC, for testing, add some dummy site + straxen.RucioFrontend.local_rses = {'UC_DALI_USERDISK': r'.rcc.', + 'test_rucio': f'{socket.getfqdn()}'} + straxen.RucioFrontend.get_rse_prefix = lambda *x: 'test_rucio' + + # Some non-existing keys that we will try finding in the test cases. 
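+        # Negative run numbers with a made-up 'dtype' lineage should never
+        # exist in any storage frontend, so every lookup below must fail.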
+ cls.test_keys = [ + strax.DataKey(run_id=run_id, + data_type='dtype', + lineage={'dtype': ['Plugin', '0.0.0.', {}], } + ) + for run_id in ('-1', '-2') + ] + + @unittest.skipIf(not straxen.utilix_is_configured(), "No db access, cannot test!") + def test_load_context_defaults(self): + """Don't fail immediately if we start a context due to Rucio""" + st = straxen.contexts.xenonnt_online(_minimum_run_number=10_000, + _maximum_run_number=10_010, + ) + st.select_runs() + + @unittest.skipIf(not straxen.utilix_is_configured(), "No db access, cannot test!") + def test_find_local(self): + """Make sure that we don't find the non existing data""" + rucio = straxen.RucioFrontend(include_remote=False,) + self.assertRaises(strax.DataNotAvailable, + rucio.find, + self.test_keys[0] + ) + + @unittest.skipIf(not straxen.utilix_is_configured(), "No db access, cannot test!") + def test_find_several_local(self): + """Let's try finding some keys (won't be available)""" + rucio = straxen.RucioFrontend(include_remote=False,) + print(rucio) + found = rucio.find_several(self.test_keys) + # We shouldn't find any of these + assert found == [False for _ in self.test_keys] + + @unittest.skipIf(not straxen.utilix_is_configured(), "No db access, cannot test!") + def test_find_several_remote(self): + """ + Let's try running a find_several with the include remote. + This should fail but when no rucio is installed or else it + shouldn't find any data. + """ + try: + rucio = straxen.RucioFrontend(include_remote=True,) + except ImportError: + pass + else: + found = rucio.find_several(self.test_keys) + # We shouldn't find any of these + assert found == [False for _ in self.test_keys] + + @unittest.skipIf(not straxen.utilix_is_configured(), "No db access, cannot test!") + def test_find_local(self): + """Make sure that we don't find the non existing data""" + run_db = straxen.RunDB(rucio_path='./rucio_test') + with self.assertRaises(strax.DataNotAvailable): + run_db.find(self.test_keys[0]) diff --git a/tests/test_scada.py b/tests/test_scada.py index a23e97f20..f215ffed9 100644 --- a/tests/test_scada.py +++ b/tests/test_scada.py @@ -1,85 +1,175 @@ +import warnings +import pytz import numpy as np import straxen -import warnings +import unittest +import requests + + +class SCInterfaceTest(unittest.TestCase): + def setUp(self): + self.resources_available() + # Simple query test: + # Query 5 s of data: + self.start = 1609682275000000000 + # Add micro-second to check if query does not fail if inquery precsion > SC precision + self.start += 10**6 + self.end = self.start + 5*10**9 + + def test_wrong_querries(self): + parameters = {'SomeParameter': 'XE1T.CTPC.Board06.Chan011.VMon'} + + with self.assertRaises(ValueError): + # Runid but no context + df = self.sc.get_scada_values(parameters, + run_id='1', + every_nth_value=1, + query_type_lab=False, ) + + with self.assertRaises(ValueError): + # No time range specified + df = self.sc.get_scada_values(parameters, + every_nth_value=1, + query_type_lab=False, ) + + with self.assertRaises(ValueError): + # Start larger end + df = self.sc.get_scada_values(parameters, + start=2, + end=1, + every_nth_value=1, + query_type_lab=False, ) + + with self.assertRaises(ValueError): + # Start and/or end not in ns unix time + df = self.sc.get_scada_values(parameters, + start=1, + end=2, + every_nth_value=1, + query_type_lab=False, ) + + def test_pmt_names(self): + """ + Tests different query options for pmt list. 
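+        find_pmt_names is expected to map PMT numbers to the corresponding
+        slow-control parameter names (HV and, if requested, current).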
+ """ + pmts_dict = self.sc.find_pmt_names(pmts=12, current=True) + assert 'PMT12_HV' in pmts_dict.keys() + assert 'PMT12_I' in pmts_dict.keys() + assert pmts_dict['PMT12_HV'] == 'XE1T.CTPC.BOARD04.CHAN003.VMON' + + pmts_dict = self.sc.find_pmt_names(pmts=(12, 13)) + assert 'PMT12_HV' in pmts_dict.keys() + assert 'PMT13_HV' in pmts_dict.keys() + + with self.assertRaises(ValueError): + self.sc.find_pmt_names(pmts=12, current=False, hv=False) + + def test_token_expires(self): + self.sc.token_expires_in() + def test_convert_timezone(self): + parameters = {'SomeParameter': 'XE1T.CTPC.Board06.Chan011.VMon'} + df = self.sc.get_scada_values(parameters, + start=self.start, + end=self.end, + every_nth_value=1, + query_type_lab=False, ) -def test_query_sc_values(): - """ - Unity test for the SCADAInterface. Query a fixed range and check if - return is correct. - """ - - if not straxen.utilix_is_configured('scada', 'scdata_url'): - warnings.warn('Cannot test scada since we have no access to xenon secrets.') - return - - print('Testing SCADAInterface') - sc = straxen.SCADAInterface(use_progress_bar=False) - - print('Query single value:') - # Simple query test: - # Query 5 s of data: - start = 1609682275000000000 - # Add micro-second to check if query does not fail if inquery precsion > SC precision - start += 10**6 - end = start + 5 * 10**9 - parameters = {'SomeParameter': 'XE1T.CTPC.Board06.Chan011.VMon'} - - df = sc.get_scada_values(parameters, - start=start, - end=end, - every_nth_value=1, - query_type_lab=False,) - - assert df['SomeParameter'][0] // 1 == 1253, 'First values returned is not corrrect.' - assert np.all(np.isnan(df['SomeParameter'][1:])), 'Subsequent values are not correct.' - - # Test ffill option: - print('Testing forwardfill option:') - parameters = {'SomeParameter': 'XE1T.CRY_FCV104FMON.PI'} - df = sc.get_scada_values(parameters, - start=start, - end=end, - fill_gaps='forwardfill', - every_nth_value=1, - query_type_lab=False,) - assert np.all(np.isclose(df[:4], 2.079859)), 'First four values deviate from queried values.' - assert np.all(np.isclose(df[4:], 2.117820)), 'Last two values deviate from queried values.' - - print('Testing downsampling and averaging option:') - parameters = {'SomeParameter': 'XE1T.CRY_TE101_TCRYOBOTT_AI.PI'} - df_all = sc.get_scada_values(parameters, - start=start, - end=end, - fill_gaps='forwardfill', + df_strax = straxen.convert_time_zone(df, tz='strax') + assert df_strax.index.dtype.type is np.int64 + + df_etc = straxen.convert_time_zone(df, tz='Etc/GMT+0') + assert df_etc.index.dtype.tz is pytz.timezone('Etc/GMT+0') + + def test_query_sc_values(self): + """ + Unity test for the SCADAInterface. Query a fixed range and check if + return is correct. + """ + print('Testing SCADAInterface') + parameters = {'SomeParameter': 'XE1T.CTPC.Board06.Chan011.VMon'} + df = self.sc.get_scada_values(parameters, + start=self.start, + end=self.end, + every_nth_value=1, + query_type_lab=False, ) + + assert df['SomeParameter'][0] // 1 == 1253, 'First values returned is not corrrect.' + assert np.all(np.isnan(df['SomeParameter'][1:])), 'Subsequent values are not correct.' + + print('Testing forwardfill option:') + parameters = {'SomeParameter': 'XE1T.CRY_FCV104FMON.PI'} + df = self.sc.get_scada_values(parameters, + start=self.start, + end=self.end, + fill_gaps='forwardfill', + every_nth_value=1, + query_type_lab=False,) + assert np.all(np.isclose(df[:4], 2.079859)), 'First four values deviate from queried values.' 
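+        # With forwardfill, seconds without a fresh slow-control reading
+        # repeat the last stored value, so the short window above collapses
+        # into two plateaus (four values, then two).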
+ assert np.all(np.isclose(df[4:], 2.117820)), 'Last two values deviate from queried values.' + print('Testing interpolation option:') + self.sc.get_scada_values(parameters, + start=self.start, + end=self.end, + fill_gaps='interpolation', every_nth_value=1, query_type_lab=False,) - df = sc.get_scada_values(parameters, - start=start, - end=end, - fill_gaps='forwardfill', - down_sampling=True, - every_nth_value=2, - query_type_lab=False,) - - assert np.all(df_all[::2] == df), 'Downsampling did not return the correct values.' - - df = sc.get_scada_values(parameters, - start=start, - end=end, - fill_gaps='forwardfill', - every_nth_value=2, - query_type_lab=False,) - - # Compare average for each second value: - for ind, i in enumerate([0, 2, 4]): - assert np.isclose(np.mean(df_all[i:i + 2]), df['SomeParameter'][ind]), 'Averaging is incorrect.' - - # Testing lab query type: - df = sc.get_scada_values(parameters, - start=start, - end=end, - query_type_lab=True,) - - assert np.all(df['SomeParameter'] // 1 == -96), 'Not all values are correct for query type lab.' + print('Testing down sampling and averaging option:') + parameters = {'SomeParameter': 'XE1T.CRY_TE101_TCRYOBOTT_AI.PI'} + df_all = self.sc.get_scada_values(parameters, + start=self.start, + end=self.end, + fill_gaps='forwardfill', + every_nth_value=1, + query_type_lab=False, ) + + df = self.sc.get_scada_values(parameters, + start=self.start, + end=self.end, + fill_gaps='forwardfill', + down_sampling=True, + every_nth_value=2, + query_type_lab=False,) + + assert np.all(df_all[::2] == df), 'Downsampling did not return the correct values.' + + df = self.sc.get_scada_values(parameters, + start=self.start, + end=self.end, + fill_gaps='forwardfill', + every_nth_value=2, + query_type_lab=False,) + + # Compare average for each second value: + for ind, i in enumerate([0, 2, 4]): + is_correct = np.isclose(np.mean(df_all[i:i + 2]), df['SomeParameter'][ind]) + assert is_correct, 'Averaging is incorrect.' + + # Testing lab query type: + df = self.sc.get_scada_values(parameters, + start=self.start, + end=self.end, + query_type_lab=True,) + is_sorrect = np.all(df['SomeParameter'] // 1 == -96) + assert is_sorrect, 'Not all values are correct for query type lab.' + + @staticmethod + def test_average_scada(): + t = np.arange(0, 100, 10) + t_t, t_a = straxen.scada._average_scada(t / 1e9, t, 1) + assert len(t_a) == len(t), 'Scada deleted some of my 10 datapoints!' + + def resources_available(self): + """ + Exception to skip test if external requirements are not met. Otherwise define + Scada interface as self.sc. 
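+        Needs a 'scada' section with 'scdata_url' in the utilix config and
+        a token from the slow-control API; otherwise the test is skipped.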
+ """ + if not straxen.utilix_is_configured('scada','scdata_url',): + self.skipTest("Cannot test scada since we have no access to xenon secrets.)") + + try: + self.sc = straxen.SCADAInterface(use_progress_bar=False) + self.sc.get_new_token() + except requests.exceptions.SSLError: + self.skipTest("Cannot reach database since HTTPs certifcate expired.") diff --git a/tests/test_several.py b/tests/test_several.py deleted file mode 100644 index c5cb29566..000000000 --- a/tests/test_several.py +++ /dev/null @@ -1,280 +0,0 @@ -"""Test several functions distibuted over common.py, misc.py, scada.py""" -import straxen -import pandas -import os -import tempfile -from .test_basics import test_run_id_1T -import numpy as np -import strax -from matplotlib.pyplot import clf as plt_clf - - -def test_pmt_pos_1t(): - """ - Test if we can get the 1T PMT positions - """ - pandas.DataFrame(straxen.pmt_positions(True)) - - -def test_pmt_pos_nt(): - """ - Test if we can get the nT PMT positions - """ - pandas.DataFrame(straxen.pmt_positions(False)) - - -def test_secret(): - """ - Check something in the sectets. This should not work because we - don't have any. - """ - try: - straxen.get_secret('somethingnonexistent') - except ValueError: - # Good we got some message we cannot load something that does - # not exist, - pass - - -# If one of the test below fail, perhaps these values need to be updated. -# They were added on 27/11/2020 and may be outdated by now -EXPECTED_OUTCOMES_TEST_SEVERAL = { - 'n_peaks': 128, - 'n_s1': 8, - 'run_live_time': 0.17933107, - 'n_events': 2, -} - - -def test_several(): - """ - Test several other functions in straxen. Is kind of messy but saves - time as we won't load data many times - :return: - """ - with tempfile.TemporaryDirectory() as temp_dir: - try: - print("Temporary directory is ", temp_dir) - os.chdir(temp_dir) - - print("Downloading test data (if needed)") - st = straxen.contexts.demo() - st.make(test_run_id_1T, 'records') - # Ignore strax-internal warnings - st.set_context_config({'free_options': tuple(st.config.keys())}) - st.make(test_run_id_1T, 'records') - - print("Get peaks") - p = st.get_array(test_run_id_1T, 'peaks') - - # Do checks on there number of peaks - assertion_statement = ("Got /more peaks than expected, perhaps " - "the test is outdated or clustering has " - "really changed") - assert np.abs(len(p) - - EXPECTED_OUTCOMES_TEST_SEVERAL['n_peaks']) < 5, assertion_statement - - events = st.get_array(test_run_id_1T, 'event_info') - print('plot wf') - peak_i = 0 - st.plot_waveform(test_run_id_1T, time_range=(p[peak_i]['time'], strax.endtime(p[peak_i]))) - plt_clf() - - print('plot hit pattern') - peak_i = 1 - st.plot_hit_pattern(test_run_id_1T, time_range=(p[peak_i]['time'], strax.endtime(p[peak_i])), xenon1t=True) - plt_clf() - - print('plot (raw)records matrix') - peak_i = 2 - assert st.is_stored(test_run_id_1T, 'records'), "no records" - assert st.is_stored(test_run_id_1T, 'raw_records'), "no raw records" - st.plot_records_matrix(test_run_id_1T, time_range=(p[peak_i]['time'], - strax.endtime(p[peak_i]))) - - st.raw_records_matrix(test_run_id_1T, time_range=(p[peak_i]['time'], - strax.endtime(p[peak_i]))) - st.plot_waveform(test_run_id_1T, - time_range=(p[peak_i]['time'], - strax.endtime(p[peak_i])), - deep=True) - plt_clf() - - straxen.analyses.event_display.plot_single_event(st, - test_run_id_1T, - events, - xenon1t=True, - event_number=0, - records_matrix='raw') - st.event_display_simple(test_run_id_1T, - time_range=(events[0]['time'], - 
-                                                events[0]['endtime']),
-                                    xenon1t=True)
-            plt_clf()
-
-            st.event_display_interactive(test_run_id_1T, time_range=(events[0]['time'],
-                                                                     events[0]['endtime']),
-                                         xenon1t=True)
-            plt_clf()
-
-            print('plot aft')
-            st.plot_peaks_aft_histogram(test_run_id_1T)
-            plt_clf()
-
-            print('plot event scatter')
-            st.event_scatter(test_run_id_1T)
-            plt_clf()
-
-            print('plot event scatter')
-            st.plot_energy_spectrum(test_run_id_1T)
-            plt_clf()
-
-            print('plot peak clsassification')
-            st.plot_peak_classification(test_run_id_1T)
-            plt_clf()
-
-            print("plot holoviews")
-            peak_i = 3
-            st.waveform_display(test_run_id_1T,
-                                time_range=(p[peak_i]['time'],
-                                            strax.endtime(p[peak_i])))
-            st.hvdisp_plot_pmt_pattern(test_run_id_1T,
-                                       time_range=(p[peak_i]['time'],
-                                                   strax.endtime(p[peak_i])))
-            st.hvdisp_plot_peak_waveforms(test_run_id_1T,
-                                          time_range=(p[peak_i]['time'],
-                                                      strax.endtime(p[peak_i])))
-
-
-            print('Plot single pulse:')
-            st.plot_pulses_tpc(test_run_id_1T, max_plots=2, plot_hits=True, ignore_time_warning=True)
-
-            print("Check live-time")
-            live_time = straxen.get_livetime_sec(st, test_run_id_1T, things=p)
-            assertion_statement = "Live-time calculation is wrong"
-            assert live_time == EXPECTED_OUTCOMES_TEST_SEVERAL['run_live_time'], assertion_statement
-
-            print('Check the peak_basics')
-            df = st.get_df(test_run_id_1T, 'peak_basics')
-            assertion_statement = ("Got less/more S1s than expected, perhaps "
-                                   "the test is outdated or classification "
-                                   "has really changed.")
-            assert np.abs(np.sum(df['type'] == 1) -
-                          EXPECTED_OUTCOMES_TEST_SEVERAL['n_s1']) < 2, assertion_statement
-            df = df[:10]
-
-            print("Check that we can write nice wiki dfs")
-            straxen.dataframe_to_wiki(df)
-
-            print("Abuse the peaks to show that _average_scada works")
-            p = p[:10]
-            p_t, p_a = straxen.scada._average_scada(
-                p['time']/1e9,
-                p['time'],
-                1)
-            assert len(p_a) == len(p), 'Scada deleted some of my 10 peaks!'
-
-            print('Check the number of events')
-            events = st.get_array(test_run_id_1T, 'event_info_double')
-            assertion_statement = ("Got less/ore events than expected, "
-                                   "perhaps the test is outdated or something "
-                                   "changed in the processing.")
-            assert len(events) == EXPECTED_OUTCOMES_TEST_SEVERAL['n_events'], assertion_statement
-
-            print("Plot bokkeh:")
-            fig = st.event_display_interactive(test_run_id_1T,
-                                               time_range=(events[0]['time'],
-                                                           events[0]['endtime']),
-                                               xenon1t=True,
-                                               plot_record_matrix=True,
-                                               )
-            fig.save('test_display.html')
-
-            # Test data selector:
-            from straxen.analyses.bokeh_waveform_plot import DataSelectionHist
-            ds = DataSelectionHist('ds')
-            fig = ds.histogram2d(p,
-                                 p['area'],
-                                 p['area'],
-                                 bins=50,
-                                 hist_range=((0, 200), (0, 2000)),
-                                 log_color_scale=True,
-                                 clim=(10, None),
-                                 undeflow_color='white')
-
-            import bokeh.plotting as bklt
-            bklt.save(fig, 'test_data_selector.html')
-
-            # On windows, you cannot delete the current process'
-            # working directory, so we have to chdir out first.
-        finally:
-            os.chdir('..')
-
-
-def test_plots():
-    """Make some plots"""
-    c = np.ones(straxen.n_tpc_pmts)
-    straxen.plot_pmts(c)
-    straxen.plot_pmts(c, log_scale=True)
-
-
-def test_print_version():
-    straxen.print_versions(['strax', 'something_that_does_not_exist'])
-
-
-def test_nt_minianalyses():
-    """Number of tests to be run on nT like configs"""
-    if not straxen.utilix_is_configured():
-        return
-    with tempfile.TemporaryDirectory() as temp_dir:
-        try:
-            print("Temporary directory is ", temp_dir)
-            os.chdir(temp_dir)
-            from .test_plugins import DummyRawRecords, testing_config_nT, test_run_id_nT
-            st = straxen.contexts.xenonnt_online(use_rucio=False)
-            rundb = st.storage[0]
-            rundb.readonly = True
-            st.storage = [rundb, strax.DataDirectory(temp_dir)]
-
-            # We want to test the FDC map that only works with CMT
-            test_conf = testing_config_nT.copy()
-            del test_conf['fdc_map']
-
-            st.set_config(test_conf)
-            st.set_context_config(dict(forbid_creation_of=()))
-            st.register(DummyRawRecords)
-
-            rr = st.get_array(test_run_id_nT, 'raw_records')
-            st.make(test_run_id_nT, 'records')
-            st.make(test_run_id_nT, 'peak_basics')
-
-            st.daq_plot(test_run_id_nT,
-                        time_range=(rr['time'][0], strax.endtime(rr)[-1]),
-                        vmin=0.1,
-                        vmax=1,
-                        )
-
-            st.plot_records_matrix(test_run_id_nT,
-                                   time_range=(rr['time'][0],
-                                               strax.endtime(rr)[-1]),
-                                   vmin=0.1,
-                                   vmax=1,
-                                   group_by='ADC ID',
-                                   )
-            plt_clf()
-
-            st.make(test_run_id_nT, 'event_info')
-            st.load_corrected_positions(test_run_id_nT,
-                                        time_range=(rr['time'][0],
-                                                    strax.endtime(rr)[-1]),
-
-                                        )
-            # This would be nice to add but with empty events it does not work
-            # st.event_display(test_run_id_nT,
-            #                  time_range=(rr['time'][0],
-            #                              strax.endtime(rr)[-1]),
-            #                  )
-            # On windows, you cannot delete the current process'git p
-            # working directory, so we have to chdir out first.
-        finally:
-            os.chdir('..')
diff --git a/tests/test_testing_suite.py b/tests/test_testing_suite.py
index ead900937..86fc12698 100644
--- a/tests/test_testing_suite.py
+++ b/tests/test_testing_suite.py
@@ -1,16 +1,16 @@
-from straxen import _is_on_pytest
-
-
-def test_testing_suite():
-    """
-    Make sure that we are always on a pytest when this is called. E.g.
-    pytest tests/test_testing_suite.py
-
-    A counter example would e.g. be:
-    python tests/test_testing_suite.py
-    """
-    assert _is_on_pytest()
-
-
-if __name__ == '__main__':
-    test_testing_suite()
+from straxen import _is_on_pytest
+
+
+def test_testing_suite():
+    """
+    Make sure that we are always on a pytest when this is called. E.g.
+    pytest tests/test_testing_suite.py
+
+    A counter example would e.g. be:
+    python tests/test_testing_suite.py
+    """
+    assert _is_on_pytest()
+
+
+if __name__ == '__main__':
+    test_testing_suite()
diff --git a/tests/test_url_config.py b/tests/test_url_config.py
new file mode 100644
index 000000000..b0e165611
--- /dev/null
+++ b/tests/test_url_config.py
@@ -0,0 +1,176 @@
+import json
+import strax
+import straxen
+import fsspec
+from straxen.test_utils import nt_test_context, nt_test_run_id
+import unittest
+import pickle
+import random
+import numpy as np
+
+
+@straxen.URLConfig.register('random')
+def generate_random(_):
+    return random.random()
+
+
+@straxen.URLConfig.register('unpicklable')
+def return_lambda(_):
+    return lambda x: x
+
+
+@straxen.URLConfig.register('large-array')
+def large_array(_):
+    return np.ones(1_000_000).tolist()
+
+
+class ExamplePlugin(strax.Plugin):
+    depends_on = ()
+    dtype = strax.time_fields
+    provides = ('test_data',)
+    test_config = straxen.URLConfig(default=42,)
+    cached_config = straxen.URLConfig(default=666, cache=1)
+
+    def compute(self):
+        pass
+
+
+class TestURLConfig(unittest.TestCase):
+    def setUp(self):
+        st = nt_test_context()
+        st.register(ExamplePlugin)
+        self.st = st
+
+    def test_default(self):
+        p = self.st.get_single_plugin(nt_test_run_id, 'test_data')
+        self.assertEqual(p.test_config, 42)
+
+    def test_literal(self):
+        self.st.set_config({'test_config': 666})
+        p = self.st.get_single_plugin(nt_test_run_id, 'test_data')
+        self.assertEqual(p.test_config, 666)
+
+    @unittest.skipIf(not straxen.utilix_is_configured(), "No db access, cannot test CMT.")
+    def test_cmt_protocol(self):
+        self.st.set_config({'test_config': 'cmt://elife?version=v1&run_id=plugin.run_id'})
+        p = self.st.get_single_plugin(nt_test_run_id, 'test_data')
+        self.assertTrue(abs(p.test_config - 219203.49884000001) < 1e-2)
+
+    def test_json_protocol(self):
+        self.st.set_config({'test_config': 'json://[1,2,3]'})
+        p = self.st.get_single_plugin(nt_test_run_id, 'test_data')
+        self.assertEqual(p.test_config, [1, 2, 3])
+
+    def test_format_protocol(self):
+        self.st.set_config({'test_config': 'format://{run_id}?run_id=plugin.run_id'})
+        p = self.st.get_single_plugin(nt_test_run_id, 'test_data')
+        self.assertEqual(p.test_config, nt_test_run_id)
+
+    def test_fsspec_protocol(self):
+        with fsspec.open('memory://test_file.json', mode='w') as f:
+            json.dump({"value": 999}, f)
+        self.st.set_config(
+            {'test_config': 'take://json://fsspec://memory://test_file.json?take=value'})
+        p = self.st.get_single_plugin(nt_test_run_id, 'test_data')
+        self.assertEqual(p.test_config, 999)
+
+    def test_chained(self):
+        self.st.set_config({'test_config': 'take://json://[1,2,3]?take=0'})
+        p = self.st.get_single_plugin(nt_test_run_id, 'test_data')
+        self.assertEqual(p.test_config, 1)
+
+    def test_take_nested(self):
+        self.st.set_config({'test_config': 'take://json://{"a":[1,2,3]}?take=a&take=0'})
+        p = self.st.get_single_plugin(nt_test_run_id, 'test_data')
+        self.assertEqual(p.test_config, 1)
+
+    @unittest.skipIf(not straxen.utilix_is_configured(),
+                     "No db access, cannot test!")
+    def test_bodega_get(self):
+        """Just a didactic example"""
+        self.st.set_config({
+            'test_config':
+                'take://'
+                'resource://'
+                'XENONnT_numbers.json'
+                '?fmt=json'
+                '&take=g1'
+                '&take=v2'
+                '&take=value'})
+        p = self.st.get_single_plugin(nt_test_run_id, 'test_data')
+        # Either g1 is 0, bodega changed or someone broke URLConfigs
+        self.assertTrue(p.test_config)
+
+        st2 = self.st.new_context()
+        st2.set_config({'test_config': 'bodega://g1?bodega_version=v2'})
+        p2 = st2.get_single_plugin(nt_test_run_id, 'test_data')
+        self.assertEqual(p.test_config, p2.test_config)
+
+    def test_print_protocol_desc(self):
+        straxen.URLConfig.print_protocols()
+
+    def test_cache(self):
+        p = self.st.get_single_plugin(nt_test_run_id, 'test_data')
+
+        # sanity check that default value is not affected
+        self.assertEqual(p.cached_config, 666)
+        self.st.set_config({'cached_config': 'random://abc'})
+        p = self.st.get_single_plugin(nt_test_run_id, 'test_data')
+
+        # The value is randomly generated when accessed, so if it is equal
+        # when we access it again, it is coming from the cache.
+        cached_value = p.cached_config
+        self.assertEqual(cached_value, p.cached_config)
+
+        # Now change the config, which will generate a new number
+        self.st.set_config({'cached_config': 'random://dfg'})
+        p = self.st.get_single_plugin(nt_test_run_id, 'test_data')
+
+        # sanity check that the new value is still consistent, i.e. cached
+        self.assertEqual(p.cached_config, p.cached_config)
+
+        # test if the previous value is evicted, since the cache size is 1
+        self.assertNotEqual(cached_value, p.cached_config)
+
+        # verify that unpicklable objects in the cache do not affect the
+        # picklability of the plugin itself
+        self.st.set_config({'cached_config': 'unpicklable://dfg'})
+        p = self.st.get_single_plugin(nt_test_run_id, 'test_data')
+        with self.assertRaises(AttributeError):
+            pickle.dumps(p.cached_config)
+        pickle.dumps(p)
+
+    def test_cache_size(self):
+        """Test the cache helper functions"""
+        # make sure the value has a detectable size
+        self.st.set_config({'cached_config': 'large-array://dfg'})
+        p = self.st.get_single_plugin(nt_test_run_id, 'test_data')
+
+        # fetch the value so it is stored in the cache
+        value = p.cached_config
+
+        # cache should now have finite size
+        self.assertGreater(straxen.config_cache_size_mb(), 0.0)
+
+        # test if clearing the cache works as expected
+        straxen.clear_config_caches()
+        self.assertEqual(straxen.config_cache_size_mb(), 0.0)
+
+    def test_filter_kwargs(self):
+        all_kwargs = dict(a=1, b=2, c=3)
+
+        # test a function that takes only a subset of the kwargs
+        def func1(a=None, b=None):
+            return
+
+        filtered1 = straxen.filter_kwargs(func1, all_kwargs)
+        self.assertEqual(filtered1, dict(a=1, b=2))
+        func1(**filtered1)
+
+        # test a function that accepts wildcard kwargs
+        def func2(**kwargs):
+            return
+        filtered2 = straxen.filter_kwargs(func2, all_kwargs)
+        self.assertEqual(filtered2, all_kwargs)
+        func2(**filtered2)
diff --git a/tests/test_veto_hitlets.py b/tests/test_veto_hitlets.py
index 8ce0c6c25..ec5ba9b03 100644
--- a/tests/test_veto_hitlets.py
+++ b/tests/test_veto_hitlets.py
@@ -6,7 +6,6 @@
 class TestRemoveSwtichedOffChannels(unittest.TestCase):
-
     def setUp(self):
         self.channel_range = (10, 19)
         self.to_pe = np.zeros(20)
diff --git a/tests/test_veto_veto_regions.py b/tests/test_veto_veto_regions.py
index 51f30743c..05c3f4216 100644
--- a/tests/test_veto_veto_regions.py
+++ b/tests/test_veto_veto_regions.py
@@ -7,6 +7,10 @@ class TestCreateVetoIntervals(unittest.TestCase):
     def setUp(self):
         dtype = straxen.plugins.veto_events.veto_event_dtype('nveto_eventbumber')
         dtype += straxen.plugins.veto_events.veto_event_positions_dtype()[2:]
+        # Get dtype from EventSync plugin:
+        p = straxen.plugins.veto_events.nVETOEventsSync()
+        dtype_sync = p.infer_dtype()
+        dtype += dtype_sync[2:]
         self.dtype = dtype
         self.events = np.zeros(4, self.dtype)
@@ -14,7 +18,9 @@ def setUp(self):
         self.events['n_hits'] = 1
         self.events['n_contributing_pmt'] = 1
         self.events['time'] = [2, 5, 7, 20]
+        self.events['time_sync'] = [2, 5, 7, 20]
         self.events['endtime'] = [3, 7, 8, 22]
+        self.events['endtime_sync'] = [3, 7, 8, 22]
 
     def test_empty_inputs(self):
         events = np.zeros(0, self.dtype)
@@ -32,7 +38,7 @@ def test_concatenate_overlapping_intervals(self):
                                                                 right_extension=right_extension)
         assert len(vetos) == 2, 'Got the wrong number of veto intervals!'
 
-        time_is_correct = vetos[0]['time'] == self.events['time'][0]-left_extension
+        time_is_correct = vetos[0]['time'] == self.events['time'][0] - left_extension
         assert time_is_correct, 'First veto event has the wrong time!'
         time_is_correct = vetos[0]['endtime'] == self.events['endtime'][2] + right_extension
         assert time_is_correct, 'First veto event has the wrong endtime!'
@@ -61,12 +67,12 @@ def _test_threshold_type(events, field, threshold_type, threshold):
                                                                         left_extension=0,
                                                                         right_extension=0)
         print(events[field], thresholds, vetos)
-        assert len(vetos) == 0, f'Vetos for {threshold_type} threshold should be empty since it is below threshold!'
+        assert len(vetos) == 0, f'Vetos for {threshold_type} threshold should be empty since it is below threshold!'  # noqa
         events[field] = threshold
         vetos = straxen.plugins.veto_veto_regions.create_veto_intervals(events,
                                                                         **thresholds,
                                                                         left_extension=0,
                                                                         right_extension=0)
-        assert len(vetos) == 1, f'{threshold_type} threshold did not work, have a wrong number of vetos!'
+        assert len(vetos) == 1, f'{threshold_type} threshold did not work, have a wrong number of vetos!'  # noqa
         events[field] = 1
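
A note for readers of the new tests/test_url_config.py above: protocols are registered with straxen.URLConfig.register and are resolved right to left in the config string, so take://json://[1,2,3]?take=0 first parses the JSON literal and then takes element 0. The minimal sketch below illustrates that pattern; it is not part of this diff, and the 'triple' protocol name and DemoPlugin class are hypothetical, invented only for illustration.

import strax
import straxen


@straxen.URLConfig.register('triple')
def triple_it(value):
    # A chained protocol receives the result of the protocol to its right.
    return 3 * value


class DemoPlugin(strax.Plugin):
    # Minimal plugin whose option is resolved through the URLConfig machinery.
    depends_on = ()
    provides = ('demo_data',)
    dtype = strax.time_fields
    demo_config = straxen.URLConfig(default=0)

    def compute(self):
        pass


# Usage, mirroring TestURLConfig.test_chained above:
#   st = straxen.test_utils.nt_test_context()
#   st.register(DemoPlugin)
#   st.set_config({'demo_config': 'triple://take://json://[1,2,3]?take=2'})
#   st.get_single_plugin(nt_test_run_id, 'demo_data').demo_config  # -> 9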