diff --git a/.coveragerc b/.coveragerc index 1bf19c310aa..3ba0b9591e0 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,7 +1,7 @@ [run] omit = - xarray/tests/* - xarray/core/dask_array_compat.py - xarray/core/npcompat.py - xarray/core/pdcompat.py - xarray/core/pycompat.py + */xarray/tests/* + */xarray/core/dask_array_compat.py + */xarray/core/npcompat.py + */xarray/core/pdcompat.py + */xarray/core/pycompat.py diff --git a/.github/ISSUE_TEMPLATE/bug-report.md b/.github/ISSUE_TEMPLATE/bug-report.md new file mode 100644 index 00000000000..02bc5d0f7b0 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug-report.md @@ -0,0 +1,39 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: '' +assignees: '' + +--- + + + +**What happened**: + +**What you expected to happen**: + +**Minimal Complete Verifiable Example**: + +```python +# Put your MCVE code here +``` + +**Anything else we need to know?**: + +**Environment**: + +
Output of xr.show_versions() + + + + +
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md deleted file mode 100644 index c712cf27979..00000000000 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ /dev/null @@ -1,35 +0,0 @@ ---- -name: Bug report / Feature request -about: 'Post a problem or idea' -title: '' -labels: '' -assignees: '' - ---- - - - - -#### MCVE Code Sample - - -```python -# Your code here - -``` - -#### Expected Output - - -#### Problem Description - - - -#### Versions - -
Output of xr.show_versions() - - - - -
diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 00000000000..0ad7e5f3e13 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,8 @@ +blank_issues_enabled: false +contact_links: + - name: Usage question + url: https://github.com/pydata/xarray/discussions + about: | + Ask questions and discuss with other community members here. + If you have a question like "How do I concatenate a list of datasets?" then + please include a self-contained reproducible example if possible. diff --git a/.github/ISSUE_TEMPLATE/feature-request.md b/.github/ISSUE_TEMPLATE/feature-request.md new file mode 100644 index 00000000000..7021fe490aa --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature-request.md @@ -0,0 +1,22 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: '' +labels: '' +assignees: '' + +--- + + + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** +Add any other context about the feature request here. diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index a921bddaa23..09ef053bb39 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,6 +1,15 @@ - - [ ] Closes #xxxx - - [ ] Tests added - - [ ] Passes `isort -rc . && black . && mypy . && flake8` - - [ ] Fully documented, including `whats-new.rst` for all changes and `api.rst` for new API +- [ ] Closes #xxxx +- [ ] Tests added +- [ ] Passes `pre-commit run --all-files` +- [ ] User visible changes (including notable bug fixes) are documented in `whats-new.rst` +- [ ] New functions/methods are listed in `api.rst` + + + +

+ Overriding CI behaviors +

+ By default, the upstream dev CI is disabled on pull request and push events. You can override this behavior per commit by adding a [test-upstream] tag to the first line of the commit message. For documentation-only commits, you can skip the CI per commit by adding a [skip-ci] tag to the first line of the commit message.
+
diff --git a/.github/actions/detect-ci-trigger/action.yaml b/.github/actions/detect-ci-trigger/action.yaml new file mode 100644 index 00000000000..c255d0c57cc --- /dev/null +++ b/.github/actions/detect-ci-trigger/action.yaml @@ -0,0 +1,29 @@ +name: Detect CI Trigger +description: | + Detect a keyword used to control the CI in the subject line of a commit message. +inputs: + keyword: + description: | + The keyword to detect. + required: true +outputs: + trigger-found: + description: | + true if the keyword has been found in the subject line of the commit message + value: ${{ steps.detect-trigger.outputs.CI_TRIGGERED }} +runs: + using: "composite" + steps: + - name: detect trigger + id: detect-trigger + run: | + bash $GITHUB_ACTION_PATH/script.sh ${{ github.event_name }} ${{ inputs.keyword }} + shell: bash + - name: show detection result + run: | + echo "::group::final summary" + echo "commit message: ${{ steps.detect-trigger.outputs.COMMIT_MESSAGE }}" + echo "trigger keyword: ${{ inputs.keyword }}" + echo "trigger found: ${{ steps.detect-trigger.outputs.CI_TRIGGERED }}" + echo "::endgroup::" + shell: bash diff --git a/.github/actions/detect-ci-trigger/script.sh b/.github/actions/detect-ci-trigger/script.sh new file mode 100644 index 00000000000..c98175a5a08 --- /dev/null +++ b/.github/actions/detect-ci-trigger/script.sh @@ -0,0 +1,47 @@ +#!/usr/bin/env bash +event_name="$1" +keyword="$2" + +echo "::group::fetch a sufficient number of commits" +echo "skipped" +# git log -n 5 2>&1 +# if [[ "$event_name" == "pull_request" ]]; then +# ref=$(git log -1 --format='%H') +# git -c protocol.version=2 fetch --deepen=2 --no-tags --prune --progress -q origin $ref 2>&1 +# git log FETCH_HEAD +# git checkout FETCH_HEAD +# else +# echo "nothing to do." +# fi +# git log -n 5 2>&1 +echo "::endgroup::" + +echo "::group::extracting the commit message" +echo "event name: $event_name" +if [[ "$event_name" == "pull_request" ]]; then + ref="HEAD^2" +else + ref="HEAD" +fi + +commit_message="$(git log -n 1 --pretty=format:%s "$ref")" + +if [[ $(echo $commit_message | wc -l) -le 1 ]]; then + echo "commit message: '$commit_message'" +else + echo -e "commit message:\n--- start ---\n$commit_message\n--- end ---" +fi +echo "::endgroup::" + +echo "::group::scanning for the keyword" +echo "searching for: '$keyword'" +if echo "$commit_message" | grep -qF "$keyword"; then + result="true" +else + result="false" +fi +echo "keyword detected: $result" +echo "::endgroup::" + +echo "::set-output name=COMMIT_MESSAGE::$commit_message" +echo "::set-output name=CI_TRIGGERED::$result" diff --git a/.github/stale.yml b/.github/stale.yml index f4835b5eeec..f4057844d01 100644 --- a/.github/stale.yml +++ b/.github/stale.yml @@ -56,4 +56,4 @@ limitPerRun: 1 # start with a small number # issues: # exemptLabels: -# - confirmed \ No newline at end of file +# - confirmed diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml new file mode 100644 index 00000000000..fdc61f2f4f7 --- /dev/null +++ b/.github/workflows/ci-additional.yaml @@ -0,0 +1,191 @@ +name: CI Additional +on: + push: + branches: + - "*" + pull_request: + branches: + - "*" + workflow_dispatch: # allows you to trigger manually + +jobs: + detect-ci-trigger: + name: detect ci trigger + runs-on: ubuntu-latest + if: github.event_name == 'push' || github.event_name == 'pull_request' + outputs: + triggered: ${{ steps.detect-trigger.outputs.trigger-found }} + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 2 + - uses: 
./.github/actions/detect-ci-trigger + id: detect-trigger + with: + keyword: "[skip-ci]" + + test: + name: ${{ matrix.os }} ${{ matrix.env }} + runs-on: ${{ matrix.os }} + needs: detect-ci-trigger + if: needs.detect-ci-trigger.outputs.triggered == 'false' + defaults: + run: + shell: bash -l {0} + strategy: + fail-fast: false + matrix: + os: ["ubuntu-latest"] + env: + [ + "py37-bare-minimum", + "py37-min-all-deps", + "py37-min-nep18", + "py38-all-but-dask", + "py38-backend-api-v2", + "py38-flaky", + ] + steps: + - name: Cancel previous runs + uses: styfle/cancel-workflow-action@0.6.0 + with: + access_token: ${{ github.token }} + - uses: actions/checkout@v2 + with: + fetch-depth: 0 # Fetch all history for all branches and tags. + + - name: Set environment variables + run: | + if [[ ${{ matrix.env }} == "py38-backend-api-v2" ]] ; + then + echo "CONDA_ENV_FILE=ci/requirements/environment.yml" >> $GITHUB_ENV + echo "XARRAY_BACKEND_API=v2" >> $GITHUB_ENV + + elif [[ ${{ matrix.env }} == "py38-flaky" ]] ; + then + echo "CONDA_ENV_FILE=ci/requirements/environment.yml" >> $GITHUB_ENV + echo "PYTEST_EXTRA_FLAGS=--run-flaky --run-network-tests" >> $GITHUB_ENV + + else + echo "CONDA_ENV_FILE=ci/requirements/${{ matrix.env }}.yml" >> $GITHUB_ENV + fi + - name: Cache conda + uses: actions/cache@v2 + with: + path: ~/conda_pkgs_dir + key: + ${{ runner.os }}-conda-${{ matrix.env }}-${{ + hashFiles('ci/requirements/**.yml') }} + + - uses: conda-incubator/setup-miniconda@v2 + with: + channels: conda-forge + channel-priority: strict + mamba-version: "*" + activate-environment: xarray-tests + auto-update-conda: false + python-version: 3.8 + use-only-tar-bz2: true + + - name: Install conda dependencies + run: | + mamba env update -f $CONDA_ENV_FILE + + - name: Install xarray + run: | + python -m pip install --no-deps -e . + + - name: Version info + run: | + conda info -a + conda list + python xarray/util/print_versions.py + - name: Import xarray + run: | + python -c "import xarray" + - name: Run tests + run: | + python -m pytest -n 4 \ + --cov=xarray \ + --cov-report=xml \ + $PYTEST_EXTRA_FLAGS + + - name: Upload code coverage to Codecov + uses: codecov/codecov-action@v1 + with: + file: ./coverage.xml + flags: unittests,${{ matrix.env }} + env_vars: RUNNER_OS + name: codecov-umbrella + fail_ci_if_error: false + doctest: + name: Doctests + runs-on: "ubuntu-latest" + needs: detect-ci-trigger + if: needs.detect-ci-trigger.outputs.triggered == 'false' + defaults: + run: + shell: bash -l {0} + + steps: + - name: Cancel previous runs + uses: styfle/cancel-workflow-action@0.6.0 + with: + access_token: ${{ github.token }} + - uses: actions/checkout@v2 + with: + fetch-depth: 0 # Fetch all history for all branches and tags. + - uses: conda-incubator/setup-miniconda@v2 + with: + channels: conda-forge + channel-priority: strict + mamba-version: "*" + activate-environment: xarray-tests + auto-update-conda: false + python-version: "3.8" + + - name: Install conda dependencies + run: | + mamba env update -f ci/requirements/environment.yml + - name: Install xarray + run: | + python -m pip install --no-deps -e . 
+ - name: Version info + run: | + conda info -a + conda list + python xarray/util/print_versions.py + - name: Run doctests + run: | + python -m pytest --doctest-modules xarray --ignore xarray/tests + + min-version-policy: + name: Minimum Version Policy + runs-on: "ubuntu-latest" + needs: detect-ci-trigger + if: needs.detect-ci-trigger.outputs.triggered == 'false' + defaults: + run: + shell: bash -l {0} + + steps: + - name: Cancel previous runs + uses: styfle/cancel-workflow-action@0.6.0 + with: + access_token: ${{ github.token }} + - uses: actions/checkout@v2 + with: + fetch-depth: 0 # Fetch all history for all branches and tags. + - uses: conda-incubator/setup-miniconda@v2 + with: + channels: conda-forge + channel-priority: strict + mamba-version: "*" + auto-update-conda: false + python-version: "3.8" + + - name: minimum versions policy + run: | + mamba install -y pyyaml conda + python ci/min_deps_check.py ci/requirements/py37-bare-minimum.yml + python ci/min_deps_check.py ci/requirements/py37-min-all-deps.yml diff --git a/.github/workflows/ci-pre-commit.yml b/.github/workflows/ci-pre-commit.yml new file mode 100644 index 00000000000..1ab5642367e --- /dev/null +++ b/.github/workflows/ci-pre-commit.yml @@ -0,0 +1,16 @@ +name: linting + +on: + push: + branches: "*" + pull_request: + branches: "*" + +jobs: + linting: + name: "pre-commit hooks" + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + - uses: pre-commit/action@v2.0.0 diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 00000000000..7d7326eb5c2 --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,104 @@ +name: CI +on: + push: + branches: + - "*" + pull_request: + branches: + - "*" + workflow_dispatch: # allows you to trigger manually + +jobs: + detect-ci-trigger: + name: detect ci trigger + runs-on: ubuntu-latest + if: github.event_name == 'push' || github.event_name == 'pull_request' + outputs: + triggered: ${{ steps.detect-trigger.outputs.trigger-found }} + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 2 + - uses: ./.github/actions/detect-ci-trigger + id: detect-trigger + with: + keyword: "[skip-ci]" + test: + name: ${{ matrix.os }} py${{ matrix.python-version }} + runs-on: ${{ matrix.os }} + needs: detect-ci-trigger + if: needs.detect-ci-trigger.outputs.triggered == 'false' + defaults: + run: + shell: bash -l {0} + strategy: + fail-fast: false + matrix: + os: ["ubuntu-latest", "macos-latest", "windows-latest"] + python-version: ["3.7", "3.8"] + steps: + - name: Cancel previous runs + uses: styfle/cancel-workflow-action@0.6.0 + with: + access_token: ${{ github.token }} + - uses: actions/checkout@v2 + with: + fetch-depth: 0 # Fetch all history for all branches and tags. 
+ - name: Set environment variables + run: | + if [[ ${{ matrix.os }} == windows* ]] ; + then + echo "CONDA_ENV_FILE=ci/requirements/environment-windows.yml" >> $GITHUB_ENV + else + echo "CONDA_ENV_FILE=ci/requirements/environment.yml" >> $GITHUB_ENV + + fi + echo "PYTHON_VERSION=${{ matrix.python-version }}" >> $GITHUB_ENV + + - name: Cache conda + uses: actions/cache@v2 + with: + path: ~/conda_pkgs_dir + key: + ${{ runner.os }}-conda-py${{ matrix.python-version }}-${{ + hashFiles('ci/requirements/**.yml') }} + - uses: conda-incubator/setup-miniconda@v2 + with: + channels: conda-forge + channel-priority: strict + mamba-version: "*" + activate-environment: xarray-tests + auto-update-conda: false + python-version: ${{ matrix.python-version }} + use-only-tar-bz2: true + + - name: Install conda dependencies + run: | + mamba env update -f $CONDA_ENV_FILE + + - name: Install xarray + run: | + python -m pip install --no-deps -e . + + - name: Version info + run: | + conda info -a + conda list + python xarray/util/print_versions.py + - name: Import xarray + run: | + python -c "import xarray" + - name: Run tests + run: | + python -m pytest -n 4 \ + --cov=xarray \ + --cov-report=xml + + - name: Upload code coverage to Codecov + uses: codecov/codecov-action@v1 + with: + file: ./coverage.xml + flags: unittests + env_vars: RUNNER_OS,PYTHON_VERSION + name: codecov-umbrella + fail_ci_if_error: false diff --git a/.github/workflows/parse_logs.py b/.github/workflows/parse_logs.py new file mode 100644 index 00000000000..4d3bea54e50 --- /dev/null +++ b/.github/workflows/parse_logs.py @@ -0,0 +1,57 @@ +# type: ignore +import argparse +import itertools +import pathlib +import textwrap + +parser = argparse.ArgumentParser() +parser.add_argument("filepaths", nargs="+", type=pathlib.Path) +args = parser.parse_args() + +filepaths = sorted(p for p in args.filepaths if p.is_file()) + + +def extract_short_test_summary_info(lines): + up_to_start_of_section = itertools.dropwhile( + lambda l: "=== short test summary info ===" not in l, + lines, + ) + up_to_section_content = itertools.islice(up_to_start_of_section, 1, None) + section_content = itertools.takewhile( + lambda l: l.startswith("FAILED"), up_to_section_content + ) + content = "\n".join(section_content) + + return content + + +def format_log_message(path): + py_version = path.name.split("-")[1] + summary = f"Python {py_version} Test Summary Info" + with open(path) as f: + data = extract_short_test_summary_info(line.rstrip() for line in f) + message = ( + textwrap.dedent( + """\ +
{summary} + + ``` + {data} + ``` + +
+ """ + ) + .rstrip() + .format(summary=summary, data=data) + ) + + return message + + +print("Parsing logs ...") +message = "\n\n".join(format_log_message(path) for path in filepaths) + +output_file = pathlib.Path("pytest-logs.txt") +print(f"Writing output file to: {output_file.absolute()}") +output_file.write_text(message) diff --git a/.github/workflows/upstream-dev-ci.yaml b/.github/workflows/upstream-dev-ci.yaml new file mode 100644 index 00000000000..dda762878c5 --- /dev/null +++ b/.github/workflows/upstream-dev-ci.yaml @@ -0,0 +1,174 @@ +name: CI Upstream +on: + push: + branches: + - master + pull_request: + branches: + - master + schedule: + - cron: "0 0 * * *" # Daily “At 00:00” UTC + workflow_dispatch: # allows you to trigger the workflow run manually + +jobs: + detect-ci-trigger: + name: detect upstream-dev ci trigger + runs-on: ubuntu-latest + if: github.event_name == 'push' || github.event_name == 'pull_request' + outputs: + triggered: ${{ steps.detect-trigger.outputs.trigger-found }} + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 2 + - uses: ./.github/actions/detect-ci-trigger + id: detect-trigger + with: + keyword: "[test-upstream]" + + upstream-dev: + name: upstream-dev + runs-on: ubuntu-latest + needs: detect-ci-trigger + if: | + always() + && github.repository == 'pydata/xarray' + && ( + (github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') + || needs.detect-ci-trigger.outputs.triggered == 'true' + ) + defaults: + run: + shell: bash -l {0} + strategy: + fail-fast: false + matrix: + python-version: ["3.8"] + outputs: + artifacts_availability: ${{ steps.status.outputs.ARTIFACTS_AVAILABLE }} + steps: + - name: Cancel previous runs + uses: styfle/cancel-workflow-action@0.6.0 + with: + access_token: ${{ github.token }} + - uses: actions/checkout@v2 + - uses: conda-incubator/setup-miniconda@v2 + with: + channels: conda-forge + channel-priority: strict + mamba-version: "*" + activate-environment: xarray-tests + auto-update-conda: false + python-version: ${{ matrix.python-version }} + - name: Set up conda environment + run: | + mamba env update -f ci/requirements/environment.yml + bash ci/install-upstream-wheels.sh + - name: Version info + run: | + conda info -a + conda list + python xarray/util/print_versions.py + - name: import xarray + run: | + python -c 'import xarray' + - name: Run Tests + if: success() + id: status + run: | + set -euo pipefail + python -m pytest -rf | tee output-${{ matrix.python-version }}-log || ( + echo '::set-output name=ARTIFACTS_AVAILABLE::true' && false + ) + - name: Upload artifacts + if: | + failure() + && steps.status.outcome == 'failure' + && github.event_name == 'schedule' + && github.repository == 'pydata/xarray' + uses: actions/upload-artifact@v2 + with: + name: output-${{ matrix.python-version }}-log + path: output-${{ matrix.python-version }}-log + retention-days: 5 + + report: + name: report + needs: upstream-dev + if: | + always() + && github.event_name == 'schedule' + && github.repository == 'pydata/xarray' + && needs.upstream-dev.outputs.artifacts_availability == 'true' + runs-on: ubuntu-latest + defaults: + run: + shell: bash + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v2 + with: + python-version: "3.x" + - uses: actions/download-artifact@v2 + with: + path: /tmp/workspace/logs + - name: Move all log files into a single directory + run: | + rsync -a /tmp/workspace/logs/output-*/ ./logs + ls -R ./logs + - name: Parse logs + run: | + shopt -s globstar + python 
.github/workflows/parse_logs.py logs/**/*-log + - name: Report failures + uses: actions/github-script@v3 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const fs = require('fs'); + const pytest_logs = fs.readFileSync('pytest-logs.txt', 'utf8'); + const title = "⚠️ Nightly upstream-dev CI failed ⚠️" + const workflow_url = `https://github.com/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}` + const issue_body = `[Workflow Run URL](${workflow_url})\n${pytest_logs}` + + // Run GraphQL query against GitHub API to find the most recent open issue used for reporting failures + const query = `query($owner:String!, $name:String!, $creator:String!, $label:String!){ + repository(owner: $owner, name: $name) { + issues(first: 1, states: OPEN, filterBy: {createdBy: $creator, labels: [$label]}, orderBy: {field: CREATED_AT, direction: DESC}) { + edges { + node { + body + id + number + } + } + } + } + }`; + + const variables = { + owner: context.repo.owner, + name: context.repo.repo, + label: 'CI', + creator: "github-actions[bot]" + } + const result = await github.graphql(query, variables) + + // If no issue is open, create a new issue, + // else update the body of the existing issue. + if (result.repository.issues.edges.length === 0) { + github.issues.create({ + owner: variables.owner, + repo: variables.name, + body: issue_body, + title: title, + labels: [variables.label] + }) + } else { + github.issues.update({ + owner: variables.owner, + repo: variables.name, + issue_number: result.repository.issues.edges[0].node.number, + body: issue_body + }) + } diff --git a/.landscape.yml b/.landscape.yml deleted file mode 100644 index 754c5715463..00000000000 --- a/.landscape.yml +++ /dev/null @@ -1,14 +0,0 @@ -doc-warnings: yes -test-warnings: yes -strictness: medium -max-line-length: 79 -autodetect: yes -ignore-paths: - - ci - - doc - - examples - - LICENSES - - notebooks -pylint: - disable: - - dangerous-default-value diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 26bf4803ef6..b0fa21a7bf9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,24 +1,34 @@ # https://pre-commit.com/ repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v3.4.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml # isort should run before black as black sometimes tweaks the isort output - - repo: https://github.com/timothycrosley/isort - rev: 4.3.21-2 + - repo: https://github.com/PyCQA/isort + rev: 5.7.0 hooks: - id: isort - files: .+\.py$ # https://github.com/python/black#version-control-integration - - repo: https://github.com/python/black - rev: stable + - repo: https://github.com/psf/black + rev: 20.8b1 hooks: - id: black + - repo: https://github.com/keewis/blackdoc + rev: v0.3.2 + hooks: + - id: blackdoc - repo: https://gitlab.com/pycqa/flake8 - rev: 3.7.9 + rev: 3.8.4 hooks: - id: flake8 - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.761 # Must match ci/requirements/*.yml + rev: v0.790 # Must match ci/requirements/*.yml hooks: - id: mypy + exclude: "properties|asv_bench" # run this occasionally, ref discussion https://github.com/pydata/xarray/pull/3194 # - repo: https://github.com/asottile/pyupgrade # rev: v1.22.1 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000000..7a909aefd08 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1 @@ +Xarray's contributor guidelines [can be found in our online 
documentation](http://xarray.pydata.org/en/stable/contributing.html) diff --git a/HOW_TO_RELEASE.md b/HOW_TO_RELEASE.md index 3fdd1d7236d..5352d427909 100644 --- a/HOW_TO_RELEASE.md +++ b/HOW_TO_RELEASE.md @@ -1,70 +1,105 @@ -How to issue an xarray release in 16 easy steps +# How to issue an xarray release in 20 easy steps Time required: about an hour. +These instructions assume that `upstream` refers to the main repository: + +```sh +$ git remote -v +{...} +upstream https://github.com/pydata/xarray (fetch) +upstream https://github.com/pydata/xarray (push) +``` + + + 1. Ensure your master branch is synced to upstream: - ``` - git pull upstream master - ``` - 2. Look over whats-new.rst and the docs. Make sure "What's New" is complete - (check the date!) and consider adding a brief summary note describing the - release at the top. + ```sh + git switch master + git pull upstream master + ``` + 2. Confirm there are no commits on stable that are not yet merged + ([ref](https://github.com/pydata/xarray/pull/4440)): + ```sh + git merge upstream stable + ``` + 2. Add a list of contributors with: + ```sh + git log "$(git tag --sort="v:refname" | sed -n 'x;$p').." --format=%aN | sort -u | perl -pe 's/\n/$1, /' + ``` + or by substituting the _previous_ release in {0.X.Y-1}: + ```sh + git log v{0.X.Y-1}.. --format=%aN | sort -u | perl -pe 's/\n/$1, /' + ``` + This will return the number of contributors: + ```sh + git log v{0.X.Y-1}.. --format=%aN | sort -u | wc -l + ``` + 3. Write a release summary: ~50 words describing the high level features. This + will be used in the release emails, tweets, GitHub release notes, etc. + 4. Look over whats-new.rst and the docs. Make sure "What's New" is complete + (check the date!) and add the release summary at the top. Things to watch out for: - Important new features should be highlighted towards the top. - Function/method references should include links to the API docs. - Sometimes notes get added in the wrong section of whats-new, typically due to a bad merge. Check for these before a release by using git diff, - e.g., `git diff v0.X.Y whats-new.rst` where 0.X.Y is the previous + e.g., `git diff v{0.X.Y-1} whats-new.rst` where {0.X.Y-1} is the previous release. - 3. If you have any doubts, run the full test suite one final time! - ``` + 5. If possible, open a PR with the release summary and whatsnew changes. + 6. After merging, again ensure your master branch is synced to upstream: + ```sh + git pull upstream master + ``` + 7. If you have any doubts, run the full test suite one final time! + ```sh pytest ``` - 4. Check that the ReadTheDocs build is passing. - 5. On the master branch, commit the release in git: - ``` - git commit -am 'Release v0.X.Y' + 8. Check that the ReadTheDocs build is passing. + 9. On the master branch, commit the release in git: + ```sh + git commit -am 'Release v{0.X.Y}' ``` - 6. Tag the release: +10. Tag the release: + ```sh + git tag -a v{0.X.Y} -m 'v{0.X.Y}' ``` - git tag -a v0.X.Y -m 'v0.X.Y' - ``` - 7. Build source and binary wheels for pypi: - ``` - git clean -xdf # this deletes all uncommited changes! +11. Build source and binary wheels for PyPI: + ```sh + git clean -xdf # this deletes all uncommitted changes! python setup.py bdist_wheel sdist ``` - 8. Use twine to check the package build: +12. Use twine to check the package build: + ```sh + twine check dist/xarray-{0.X.Y}* ``` - twine check dist/xarray-0.X.Y* - ``` - 9. Use twine to register and upload the release on pypi. Be careful, you can't +13. 
Use twine to register and upload the release on PyPI. Be careful, you can't take this back! - ``` - twine upload dist/xarray-0.X.Y* + ```sh + twine upload dist/xarray-{0.X.Y}* ``` You will need to be listed as a package owner at - https://pypi.python.org/pypi/xarray for this to work. -10. Push your changes to master: - ``` + for this to work. +14. Push your changes to master: + ```sh git push upstream master git push upstream --tags ``` -11. Update the stable branch (used by ReadTheDocs) and switch back to master: - ``` - git checkout stable +15. Update the stable branch (used by ReadTheDocs) and switch back to master: + ```sh + git switch stable git rebase master - git push upstream stable - git checkout master + git push --force upstream stable + git switch master ``` - It's OK to force push to 'stable' if necessary. (We also update the stable - branch with `git cherrypick` for documentation only fixes that apply the + It's OK to force push to `stable` if necessary. (We also update the stable + branch with `git cherry-pick` for documentation only fixes that apply the current released version.) -12. Add a section for the next release (v.X.Y+1) to doc/whats-new.rst: - ``` - .. _whats-new.0.X.Y+1: +16. Add a section for the next release {0.X.Y+1} to doc/whats-new.rst: + ```rst + .. _whats-new.{0.X.Y+1}: - v0.X.Y+1 (unreleased) + v{0.X.Y+1} (unreleased) --------------------- Breaking changes @@ -86,20 +121,20 @@ Time required: about an hour. Internal Changes ~~~~~~~~~~~~~~~~ ``` -13. Commit your changes and push to master again: - ``` +17. Commit your changes and push to master again: + ```sh git commit -am 'New whatsnew section' git push upstream master ``` You're done pushing to master! -14. Issue the release on GitHub. Click on "Draft a new release" at - https://github.com/pydata/xarray/releases. Type in the version number, but - don't bother to describe it -- we maintain that on the docs instead. -15. Update the docs. Login to https://readthedocs.org/projects/xray/versions/ +18. Issue the release on GitHub. Click on "Draft a new release" at + . Type in the version number + and paste the release summary in the notes. +19. Update the docs. Login to and switch your new release tag (at the bottom) from "Inactive" to "Active". It should now build automatically. -16. Issue the release announcement! For bug fix releases, I usually only email - xarray@googlegroups.com. For major/feature releases, I will email a broader +20. Issue the release announcement to mailing lists & Twitter. For bug fix releases, I + usually only email xarray@googlegroups.com. For major/feature releases, I will email a broader list (no more than once every 3-6 months): - pydata@googlegroups.com - xarray@googlegroups.com @@ -109,18 +144,10 @@ Time required: about an hour. Google search will turn up examples of prior release announcements (look for "ANN xarray"). - You can get a list of contributors with: - ``` - git log "$(git tag --sort="v:refname" | sed -n 'x;$p').." --format="%aN" | sort -u - ``` - or by substituting the _previous_ release in: - ``` - git log v0.X.Y-1.. --format="%aN" | sort -u - ``` - NB: copying this output into a Google Groups form can cause - [issues](https://groups.google.com/forum/#!topic/xarray/hK158wAviPs) with line breaks, so take care -Note on version numbering: + + +## Note on version numbering We follow a rough approximation of semantic version. Only major releases (0.X.0) should include breaking changes. 
Minor releases (0.X.Y) are for bug fixes and diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index cbfb8c8cdca..00000000000 --- a/MANIFEST.in +++ /dev/null @@ -1,7 +0,0 @@ -include LICENSE -recursive-include licenses * -recursive-include doc * -prune doc/_build -prune doc/generated -global-exclude .DS_Store -recursive-include xarray/static * diff --git a/README.rst b/README.rst index 5ee7234f221..e258a8ccd23 100644 --- a/README.rst +++ b/README.rst @@ -1,8 +1,8 @@ xarray: N-D labeled arrays and datasets ======================================= -.. image:: https://dev.azure.com/xarray/xarray/_apis/build/status/pydata.xarray?branchName=master - :target: https://dev.azure.com/xarray/xarray/_build/latest?definitionId=1&branchName=master +.. image:: https://github.com/pydata/xarray/workflows/CI/badge.svg?branch=master + :target: https://github.com/pydata/xarray/actions?query=workflow%3ACI .. image:: https://codecov.io/gh/pydata/xarray/branch/master/graph/badge.svg :target: https://codecov.io/gh/pydata/xarray .. image:: https://readthedocs.org/projects/xray/badge/?version=latest @@ -13,6 +13,8 @@ xarray: N-D labeled arrays and datasets :target: https://pypi.python.org/pypi/xarray/ .. image:: https://img.shields.io/badge/code%20style-black-000000.svg :target: https://github.com/python/black +.. image:: https://zenodo.org/badge/DOI/10.5281/zenodo.598201.svg + :target: https://doi.org/10.5281/zenodo.598201 **xarray** (formerly **xray**) is an open source project and Python package diff --git a/asv_bench/benchmarks/indexing.py b/asv_bench/benchmarks/indexing.py index c4cfbbbdfdf..859c41c913d 100644 --- a/asv_bench/benchmarks/indexing.py +++ b/asv_bench/benchmarks/indexing.py @@ -1,3 +1,5 @@ +import os + import numpy as np import pandas as pd @@ -138,3 +140,22 @@ def setup(self): def time_indexing(self): self.ds.isel(time=self.time_filter) + + +class HugeAxisSmallSliceIndexing: + # https://github.com/pydata/xarray/pull/4560 + def setup(self): + self.filepath = "test_indexing_huge_axis_small_slice.nc" + if not os.path.isfile(self.filepath): + xr.Dataset( + {"a": ("x", np.arange(10_000_000))}, + coords={"x": np.arange(10_000_000)}, + ).to_netcdf(self.filepath, format="NETCDF4") + + self.ds = xr.open_dataset(self.filepath) + + def time_indexing(self): + self.ds.isel(x=slice(100)) + + def cleanup(self): + self.ds.close() diff --git a/asv_bench/benchmarks/pandas.py b/asv_bench/benchmarks/pandas.py new file mode 100644 index 00000000000..42ef18ac0c2 --- /dev/null +++ b/asv_bench/benchmarks/pandas.py @@ -0,0 +1,24 @@ +import numpy as np +import pandas as pd + +from . 
import parameterized + + +class MultiIndexSeries: + def setup(self, dtype, subset): + data = np.random.rand(100000).astype(dtype) + index = pd.MultiIndex.from_product( + [ + list("abcdefhijk"), + list("abcdefhijk"), + pd.date_range(start="2000-01-01", periods=1000, freq="B"), + ] + ) + series = pd.Series(data, index) + if subset: + series = series[::3] + self.series = series + + @parameterized(["dtype", "subset"], ([int, float], [True, False])) + def time_to_xarray(self, dtype, subset): + self.series.to_xarray() diff --git a/azure-pipelines.yml b/azure-pipelines.yml deleted file mode 100644 index ff85501c555..00000000000 --- a/azure-pipelines.yml +++ /dev/null @@ -1,128 +0,0 @@ -variables: - pytest_extra_flags: '' - allow_failure: false - upstream_dev: false - -jobs: - -- job: Linux - strategy: - matrix: - py36-bare-minimum: - conda_env: py36-bare-minimum - py36-min-all-deps: - conda_env: py36-min-all-deps - py36-min-nep18: - conda_env: py36-min-nep18 - py36: - conda_env: py36 - py37: - conda_env: py37 - py38: - conda_env: py38 - py38-all-but-dask: - conda_env: py38-all-but-dask - py38-upstream-dev: - conda_env: py38 - upstream_dev: true - py38-flaky: - conda_env: py38 - pytest_extra_flags: --run-flaky --run-network-tests - allow_failure: true - pool: - vmImage: 'ubuntu-16.04' - steps: - - template: ci/azure/unit-tests.yml - -- job: MacOSX - strategy: - matrix: - py38: - conda_env: py38 - pool: - vmImage: 'macOS-10.15' - steps: - - template: ci/azure/unit-tests.yml - -- job: Windows - strategy: - matrix: - py37: - conda_env: py37-windows - pool: - vmImage: 'vs2017-win2016' - steps: - - template: ci/azure/unit-tests.yml - -- job: LintFlake8 - pool: - vmImage: 'ubuntu-16.04' - steps: - - task: UsePythonVersion@0 - - bash: python -m pip install flake8 - displayName: Install flake8 - - bash: flake8 - displayName: flake8 lint checks - -- job: FormattingBlack - pool: - vmImage: 'ubuntu-16.04' - steps: - - task: UsePythonVersion@0 - - bash: python -m pip install black - displayName: Install black - - bash: black --check . - displayName: black formatting check - -- job: TypeChecking - variables: - conda_env: py38 - pool: - vmImage: 'ubuntu-16.04' - steps: - - template: ci/azure/install.yml - - bash: | - source activate xarray-tests - mypy . - displayName: mypy type checks - -- job: isort - variables: - conda_env: py38 - pool: - vmImage: 'ubuntu-16.04' - steps: - - template: ci/azure/install.yml - - bash: | - source activate xarray-tests - isort -rc --check . - displayName: isort formatting checks - -- job: MinimumVersionsPolicy - pool: - vmImage: 'ubuntu-16.04' - steps: - - template: ci/azure/add-conda-to-path.yml - - bash: | - conda install -y pyyaml - python ci/min_deps_check.py ci/requirements/py36-bare-minimum.yml - python ci/min_deps_check.py ci/requirements/py36-min-all-deps.yml - displayName: minimum versions policy - -- job: Docs - pool: - vmImage: 'ubuntu-16.04' - steps: - - template: ci/azure/install.yml - parameters: - env_file: ci/requirements/doc.yml - - bash: | - source activate xarray-tests - # Replicate the exact environment created by the readthedocs CI - conda install --yes --quiet -c pkgs/main mock pillow sphinx sphinx_rtd_theme - displayName: Replicate readthedocs CI environment - - bash: | - source activate xarray-tests - cd doc - sphinx-build -W --keep-going -j auto -b html -d _build/doctrees . 
_build/html - displayName: Build HTML docs diff --git a/ci/azure/add-conda-to-path.yml b/ci/azure/add-conda-to-path.yml deleted file mode 100644 index e5173835388..00000000000 --- a/ci/azure/add-conda-to-path.yml +++ /dev/null @@ -1,18 +0,0 @@ -# https://docs.microsoft.com/en-us/azure/devops/pipelines/languages/anaconda -steps: - -- bash: | - echo "##vso[task.prependpath]$CONDA/bin" - displayName: Add conda to PATH (Linux) - condition: eq(variables['Agent.OS'], 'Linux') - -- bash: | - echo "##vso[task.prependpath]$CONDA/bin" - sudo chown -R $USER $CONDA - displayName: Add conda to PATH (OS X) - condition: eq(variables['Agent.OS'], 'Darwin') - -- powershell: | - Write-Host "##vso[task.prependpath]$env:CONDA\Scripts" - displayName: Add conda to PATH (Windows) - condition: eq(variables['Agent.OS'], 'Windows_NT') diff --git a/ci/azure/install.yml b/ci/azure/install.yml deleted file mode 100644 index 60559dd2064..00000000000 --- a/ci/azure/install.yml +++ /dev/null @@ -1,47 +0,0 @@ -parameters: - env_file: ci/requirements/$CONDA_ENV.yml - -steps: - -- template: add-conda-to-path.yml - -- bash: | - conda update -y conda - conda env create -n xarray-tests --file ${{ parameters.env_file }} - displayName: Install conda dependencies - -- bash: | - source activate xarray-tests - python -m pip install \ - -f https://7933911d6844c6c53a7d-47bd50c35cd79bd838daf386af554a83.ssl.cf2.rackcdn.com \ - --no-deps \ - --pre \ - --upgrade \ - matplotlib \ - numpy \ - scipy - python -m pip install \ - --no-deps \ - --upgrade \ - git+https://github.com/dask/dask \ - git+https://github.com/dask/distributed \ - git+https://github.com/zarr-developers/zarr \ - git+https://github.com/Unidata/cftime \ - git+https://github.com/mapbox/rasterio \ - git+https://github.com/hgrecco/pint \ - git+https://github.com/pydata/bottleneck \ - git+https://github.com/pandas-dev/pandas - condition: eq(variables['UPSTREAM_DEV'], 'true') - displayName: Install upstream dev dependencies - -- bash: | - source activate xarray-tests - python -m pip install --no-deps -e . 
- displayName: Install xarray - -- bash: | - source activate xarray-tests - conda info -a - conda list - python xarray/util/print_versions.py - displayName: Version info diff --git a/ci/azure/unit-tests.yml b/ci/azure/unit-tests.yml deleted file mode 100644 index 7ee5132632f..00000000000 --- a/ci/azure/unit-tests.yml +++ /dev/null @@ -1,34 +0,0 @@ -steps: - -- template: install.yml - -- bash: | - source activate xarray-tests - python -OO -c "import xarray" - displayName: Import xarray - -# Work around for allowed test failures: -# https://github.com/microsoft/azure-pipelines-tasks/issues/9302 -- bash: | - source activate xarray-tests - pytest \ - --junitxml=junit/test-results.xml \ - --cov=xarray \ - --cov-report=xml \ - $(pytest_extra_flags) || [ "$ALLOW_FAILURE" = "true" ] - displayName: Run tests - -- bash: | - curl https://codecov.io/bash > codecov.sh - bash codecov.sh -t 688f4d53-31bb-49b5-8370-4ce6f792cf3d - displayName: Upload coverage to codecov.io - -# TODO: publish coverage results to Azure, once we can merge them across -# multiple jobs: https://stackoverflow.com/questions/56776185 - -- task: PublishTestResults@2 - condition: succeededOrFailed() - inputs: - testResultsFiles: '**/test-*.xml' - failTaskOnFailedTests: false - testRunTitle: '$(Agent.JobName)' diff --git a/ci/install-upstream-wheels.sh b/ci/install-upstream-wheels.sh new file mode 100755 index 00000000000..fe3e706f6a6 --- /dev/null +++ b/ci/install-upstream-wheels.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash + +# TODO: add sparse back in, once Numba works with the development version of +# NumPy again: https://github.com/pydata/xarray/issues/4146 + +conda uninstall -y --force \ + numpy \ + scipy \ + pandas \ + matplotlib \ + dask \ + distributed \ + zarr \ + cftime \ + rasterio \ + pint \ + bottleneck \ + sparse +python -m pip install \ + -i https://pypi.anaconda.org/scipy-wheels-nightly/simple \ + --no-deps \ + --pre \ + --upgrade \ + numpy \ + scipy \ + pandas +python -m pip install \ + -f https://7933911d6844c6c53a7d-47bd50c35cd79bd838daf386af554a83.ssl.cf2.rackcdn.com \ + --no-deps \ + --pre \ + --upgrade \ + matplotlib +python -m pip install \ + --no-deps \ + --upgrade \ + git+https://github.com/dask/dask \ + git+https://github.com/dask/distributed \ + git+https://github.com/zarr-developers/zarr \ + git+https://github.com/Unidata/cftime \ + git+https://github.com/mapbox/rasterio \ + git+https://github.com/hgrecco/pint \ + git+https://github.com/pydata/bottleneck # \ + # git+https://github.com/pydata/sparse diff --git a/ci/min_deps_check.py b/ci/min_deps_check.py index 527093cf5bc..3ffab645e8e 100755 --- a/ci/min_deps_check.py +++ b/ci/min_deps_check.py @@ -1,15 +1,16 @@ """Fetch from conda database all available versions of the xarray dependencies and their -publication date. Compare it against requirements/py36-min-all-deps.yml to verify the +publication date. Compare it against requirements/py37-min-all-deps.yml to verify the policy on obsolete dependencies is being followed. 
Print a pretty report :) """ -import subprocess +import itertools import sys -from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timedelta from typing import Dict, Iterator, Optional, Tuple +import conda.api import yaml +CHANNELS = ["conda-forge", "defaults"] IGNORE_DEPS = { "black", "coveralls", @@ -21,11 +22,26 @@ "pytest", "pytest-cov", "pytest-env", + "pytest-xdist", } -POLICY_MONTHS = {"python": 42, "numpy": 24, "pandas": 12, "scipy": 12} -POLICY_MONTHS_DEFAULT = 6 - +POLICY_MONTHS = {"python": 42, "numpy": 24, "setuptools": 42} +POLICY_MONTHS_DEFAULT = 12 +POLICY_OVERRIDE = { + # dask < 2.9 has trouble with nan-reductions + # TODO remove this special case and the matching note in installing.rst + # after January 2021. + "dask": (2, 9), + "distributed": (2, 9), + # setuptools-scm doesn't work with setuptools < 36.7 (Nov 2017). + # The conda metadata is malformed for setuptools < 38.4 (Jan 2018) + # (it's missing a timestamp which prevents this tool from working). + # setuptools < 40.4 (Sep 2018) from conda-forge cannot be installed into a py37 + # environment + # TODO remove this special case and the matching note in installing.rst + # after March 2022. + "setuptools": (40, 4), +} has_errors = False @@ -40,7 +56,7 @@ def warning(msg: str) -> None: def parse_requirements(fname) -> Iterator[Tuple[str, int, int, Optional[int]]]: - """Load requirements/py36-min-all-deps.yml + """Load requirements/py37-min-all-deps.yml Yield (package name, major version, minor version, [patch version]) """ @@ -76,30 +92,23 @@ def query_conda(pkg: str) -> Dict[Tuple[int, int], datetime]: Return map of {(major version, minor version): publication date} """ - stdout = subprocess.check_output( - ["conda", "search", pkg, "--info", "-c", "defaults", "-c", "conda-forge"] - ) - out = {} # type: Dict[Tuple[int, int], datetime] - major = None - minor = None - - for row in stdout.decode("utf-8").splitlines(): - label, _, value = row.partition(":") - label = label.strip() - if label == "file name": - value = value.strip()[len(pkg) :] - smajor, sminor = value.split("-")[1].split(".")[:2] - major = int(smajor) - minor = int(sminor) - if label == "timestamp": - assert major is not None - assert minor is not None - ts = datetime.strptime(value.split()[0].strip(), "%Y-%m-%d") - - if (major, minor) in out: - out[major, minor] = min(out[major, minor], ts) - else: - out[major, minor] = ts + + def metadata(entry): + version = entry.version + + time = datetime.fromtimestamp(entry.timestamp) + major, minor = map(int, version.split(".")[:2]) + + return (major, minor), time + + raw_data = conda.api.SubdirData.query_all(pkg, channels=CHANNELS) + data = sorted(metadata(entry) for entry in raw_data if entry.timestamp != 0) + + release_dates = { + version: [time for _, time in group if time is not None] + for version, group in itertools.groupby(data, key=lambda x: x[0]) + } + out = {version: min(dates) for version, dates in release_dates.items() if dates} # Hardcoded fix to work around incorrect dates in conda if pkg == "python": @@ -151,6 +160,11 @@ def process_pkg( policy_minor = minor policy_published_actual = published + try: + policy_major, policy_minor = POLICY_OVERRIDE[pkg] + except KeyError: + pass + if (req_major, req_minor) < (policy_major, policy_minor): status = "<" elif (req_major, req_minor) > (policy_major, policy_minor): @@ -182,16 +196,14 @@ def fmt_version(major: int, minor: int, patch: int = None) -> str: def main() -> None: fname = sys.argv[1] - with ThreadPoolExecutor(8) as ex: - 
futures = [ - ex.submit(process_pkg, pkg, major, minor, patch) - for pkg, major, minor, patch in parse_requirements(fname) - ] - rows = [f.result() for f in futures] - - print("Package Required Policy Status") - print("------------- -------------------- -------------------- ------") - fmt = "{:13} {:7} ({:10}) {:7} ({:10}) {}" + rows = [ + process_pkg(pkg, major, minor, patch) + for pkg, major, minor, patch in parse_requirements(fname) + ] + + print("Package Required Policy Status") + print("----------------- -------------------- -------------------- ------") + fmt = "{:17} {:7} ({:10}) {:7} ({:10}) {}" for row in rows: print(fmt.format(*row)) diff --git a/ci/requirements/doc.yml b/ci/requirements/doc.yml index 2987303c92a..e092272654b 100644 --- a/ci/requirements/doc.yml +++ b/ci/requirements/doc.yml @@ -2,6 +2,7 @@ name: xarray-docs channels: # Don't change to pkgs/main, as it causes random timeouts in readthedocs - conda-forge + - nodefaults dependencies: - python=3.8 - bottleneck @@ -13,15 +14,21 @@ dependencies: - ipython - iris>=2.3 - jupyter_client + - matplotlib-base - nbsphinx - netcdf4>=1.5 - numba - numpy>=1.17 - - numpydoc - pandas>=1.0 - rasterio>=1.1 - seaborn - setuptools - - sphinx>=2.3 + - sphinx=3.3 - sphinx_rtd_theme>=0.4 - - zarr>=2.4 \ No newline at end of file + - sphinx-autosummary-accessors + - zarr>=2.4 + - pip + - pip: + - scanpydoc + # relative to this file. Needs to be editable to be accepted. + - -e ../.. diff --git a/ci/requirements/py37-windows.yml b/ci/requirements/environment-windows.yml similarity index 73% rename from ci/requirements/py37-windows.yml rename to ci/requirements/environment-windows.yml index e9e5c7a900a..6de2bc8dc64 100644 --- a/ci/requirements/py37-windows.yml +++ b/ci/requirements/environment-windows.yml @@ -2,27 +2,21 @@ name: xarray-tests channels: - conda-forge dependencies: - - python=3.7 - - black - boto3 - bottleneck - cartopy # - cdms2 # Not available on Windows - # - cfgrib # Causes Python interpreter crash on Windows + # - cfgrib # Causes Python interpreter crash on Windows: https://github.com/pydata/xarray/pull/3340 - cftime - - coveralls - dask - distributed - - flake8 - h5netcdf - - h5py + - h5py=2 - hdf5 - hypothesis - iris - - isort - lxml # Optional dep of pydap - - matplotlib - - mypy=0.761 # Must match .pre-commit-config.yaml + - matplotlib-base - nc-time-axis - netcdf4 - numba @@ -30,12 +24,14 @@ dependencies: - pandas - pint - pip + - pre-commit - pseudonetcdf - pydap # - pynio # Not available on Windows - pytest - pytest-cov - pytest-env + - pytest-xdist - rasterio - scipy - seaborn diff --git a/ci/requirements/py37.yml b/ci/requirements/environment.yml similarity index 73% rename from ci/requirements/py37.yml rename to ci/requirements/environment.yml index dba3926596e..0f59d9570c8 100644 --- a/ci/requirements/py37.yml +++ b/ci/requirements/environment.yml @@ -1,41 +1,38 @@ name: xarray-tests channels: - conda-forge + - nodefaults dependencies: - - python=3.7 - - black - boto3 - bottleneck - cartopy - cdms2 - cfgrib - cftime - - coveralls - dask - distributed - - flake8 - h5netcdf - - h5py + - h5py=2 - hdf5 - hypothesis - iris - - isort - lxml # Optional dep of pydap - - matplotlib - - mypy=0.761 # Must match .pre-commit-config.yaml + - matplotlib-base - nc-time-axis - netcdf4 - numba - numpy - pandas - pint - - pip + - pip=20.2 + - pre-commit - pseudonetcdf - pydap - - pynio + # - pynio: not compatible with netCDF4>1.5.3; only tested in py37-bare-minimum - pytest - pytest-cov - pytest-env + - pytest-xdist - rasterio - 
scipy - seaborn diff --git a/ci/requirements/py36.yml b/ci/requirements/py36.yml deleted file mode 100644 index a500173f277..00000000000 --- a/ci/requirements/py36.yml +++ /dev/null @@ -1,47 +0,0 @@ -name: xarray-tests -channels: - - conda-forge -dependencies: - - python=3.6 - - black - - boto3 - - bottleneck - - cartopy - - cdms2 - - cfgrib - - cftime - - coveralls - - dask - - distributed - - flake8 - - h5netcdf - - h5py - - hdf5 - - hypothesis - - iris - - isort - - lxml # Optional dep of pydap - - matplotlib - - mypy=0.761 # Must match .pre-commit-config.yaml - - nc-time-axis - - netcdf4 - - numba - - numpy - - pandas - - pint - - pip - - pseudonetcdf - - pydap - - pynio - - pytest - - pytest-cov - - pytest-env - - rasterio - - scipy - - seaborn - - setuptools - - sparse - - toolz - - zarr - - pip: - - numbagg diff --git a/ci/requirements/py36-bare-minimum.yml b/ci/requirements/py37-bare-minimum.yml similarity index 69% rename from ci/requirements/py36-bare-minimum.yml rename to ci/requirements/py37-bare-minimum.yml index 00fef672855..fbeb87032b7 100644 --- a/ci/requirements/py36-bare-minimum.yml +++ b/ci/requirements/py37-bare-minimum.yml @@ -1,13 +1,15 @@ name: xarray-tests channels: - conda-forge + - nodefaults dependencies: - - python=3.6 + - python=3.7 - coveralls - pip - pytest - pytest-cov - pytest-env + - pytest-xdist - numpy=1.15 - pandas=0.25 - - setuptools=41.2 + - setuptools=40.4 diff --git a/ci/requirements/py36-min-all-deps.yml b/ci/requirements/py37-min-all-deps.yml similarity index 73% rename from ci/requirements/py36-min-all-deps.yml rename to ci/requirements/py37-min-all-deps.yml index 86540197dcc..feef86ddf5c 100644 --- a/ci/requirements/py36-min-all-deps.yml +++ b/ci/requirements/py37-min-all-deps.yml @@ -1,12 +1,13 @@ name: xarray-tests channels: - conda-forge + - nodefaults dependencies: # MINIMUM VERSIONS POLICY: see doc/installing.rst # Run ci/min_deps_check.py to verify that this file respects the policy. # When upgrading python, numpy, or pandas, must also change # doc/installing.rst and setup.py. - - python=3.6 + - python=3.7 - black - boto3=1.9 - bottleneck=1.2 @@ -15,8 +16,8 @@ dependencies: - cfgrib=0.9 - cftime=1.0 - coveralls - - dask=2.2 - - distributed=2.2 + - dask=2.9 + - distributed=2.9 - flake8 - h5netcdf=0.7 - h5py=2.9 # Policy allows for 2.10, but it's a conflict-fest @@ -25,15 +26,14 @@ dependencies: - iris=2.2 - isort - lxml=4.4 # Optional dep of pydap - - matplotlib=3.1 - - msgpack-python=0.6 # remove once distributed is bumped. 
distributed GH3491 - - mypy=0.761 # Must match .pre-commit-config.yaml + - matplotlib-base=3.1 + - mypy=0.782 # Must match .pre-commit-config.yaml - nc-time-axis=1.2 - netcdf4=1.4 - - numba=0.44 + - numba=0.46 - numpy=1.15 - pandas=0.25 - # - pint # See py36-min-nep18.yml + # - pint # See py37-min-nep18.yml - pip - pseudonetcdf=3.0 - pydap=3.2 @@ -41,11 +41,12 @@ dependencies: - pytest - pytest-cov - pytest-env + - pytest-xdist - rasterio=1.0 - scipy=1.3 - seaborn=0.9 - - setuptools=41.2 - # - sparse # See py36-min-nep18.yml + - setuptools=40.4 + # - sparse # See py37-min-nep18.yml - toolz=0.10 - zarr=2.3 - pip: diff --git a/ci/requirements/py36-min-nep18.yml b/ci/requirements/py37-min-nep18.yml similarity index 62% rename from ci/requirements/py36-min-nep18.yml rename to ci/requirements/py37-min-nep18.yml index a5eded49cd4..aea86261a0e 100644 --- a/ci/requirements/py36-min-nep18.yml +++ b/ci/requirements/py37-min-nep18.yml @@ -1,21 +1,22 @@ name: xarray-tests channels: - conda-forge + - nodefaults dependencies: # Optional dependencies that require NEP18, such as sparse and pint, # require drastically newer packages than everything else - - python=3.6 + - python=3.7 - coveralls - - dask=2.4 - - distributed=2.4 - - msgpack-python=0.6 # remove once distributed is bumped. distributed GH3491 + - dask=2.9 + - distributed=2.9 - numpy=1.17 - pandas=0.25 - - pint=0.11 + - pint=0.15 - pip - pytest - pytest-cov - pytest-env - - scipy=1.2 - - setuptools=41.2 + - pytest-xdist + - scipy=1.3 + - setuptools=40.4 - sparse=0.8 diff --git a/ci/requirements/py38-all-but-dask.yml b/ci/requirements/py38-all-but-dask.yml index a375d9e1e5a..14930f5272d 100644 --- a/ci/requirements/py38-all-but-dask.yml +++ b/ci/requirements/py38-all-but-dask.yml @@ -1,6 +1,7 @@ name: xarray-tests channels: - conda-forge + - nodefaults dependencies: - python=3.8 - black @@ -13,13 +14,13 @@ dependencies: - coveralls - flake8 - h5netcdf - - h5py + - h5py=2 - hdf5 - hypothesis - isort - lxml # Optional dep of pydap - - matplotlib - - mypy=0.761 # Must match .pre-commit-config.yaml + - matplotlib-base + - mypy=0.790 # Must match .pre-commit-config.yaml - nc-time-axis - netcdf4 - numba @@ -29,10 +30,11 @@ dependencies: - pip - pseudonetcdf - pydap - - pynio + # - pynio: not compatible with netCDF4>1.5.3; only tested in py37-bare-minimum - pytest - pytest-cov - pytest-env + - pytest-xdist - rasterio - scipy - seaborn diff --git a/ci/requirements/py38.yml b/ci/requirements/py38.yml deleted file mode 100644 index 24602f884e9..00000000000 --- a/ci/requirements/py38.yml +++ /dev/null @@ -1,47 +0,0 @@ -name: xarray-tests -channels: - - conda-forge -dependencies: - - python=3.8 - - black - - boto3 - - bottleneck - - cartopy - - cdms2 - - cfgrib - - cftime - - coveralls - - dask - - distributed - - flake8 - - h5netcdf - - h5py - - hdf5 - - hypothesis - - iris - - isort - - lxml # Optional dep of pydap - - matplotlib - - mypy=0.761 # Must match .pre-commit-config.yaml - - nc-time-axis - - netcdf4 - - numba - - numpy - - pandas - - pint - - pip - - pseudonetcdf - - pydap - - pynio - - pytest - - pytest-cov - - pytest-env - - rasterio - - scipy - - seaborn - - setuptools - - sparse - - toolz - - zarr - - pip: - - numbagg diff --git a/conftest.py b/conftest.py index 712af1d3759..862a1a1d0bc 100644 --- a/conftest.py +++ b/conftest.py @@ -19,16 +19,23 @@ def pytest_runtest_setup(item): pytest.skip("set --run-flaky option to run flaky tests") if "network" in item.keywords and not item.config.getoption("--run-network-tests"): pytest.skip( - "set 
--run-network-tests to run test requiring an " "internet connection" + "set --run-network-tests to run test requiring an internet connection" ) @pytest.fixture(autouse=True) -def add_standard_imports(doctest_namespace): +def add_standard_imports(doctest_namespace, tmpdir): import numpy as np import pandas as pd + import xarray as xr doctest_namespace["np"] = np doctest_namespace["pd"] = pd doctest_namespace["xr"] = xr + + # always seed numpy.random to make the examples deterministic + np.random.seed(0) + + # always switch to the temporary directory, so files get written there + tmpdir.chdir() diff --git a/doc/_templates/autosummary/accessor.rst b/doc/_templates/autosummary/accessor.rst new file mode 100644 index 00000000000..4ba745cd6fd --- /dev/null +++ b/doc/_templates/autosummary/accessor.rst @@ -0,0 +1,6 @@ +{{ fullname }} +{{ underline }} + +.. currentmodule:: {{ module.split('.')[0] }} + +.. autoaccessor:: {{ (module.split('.')[1:] + [objname]) | join('.') }} diff --git a/doc/_templates/autosummary/accessor_attribute.rst b/doc/_templates/autosummary/accessor_attribute.rst new file mode 100644 index 00000000000..b5ad65d6a73 --- /dev/null +++ b/doc/_templates/autosummary/accessor_attribute.rst @@ -0,0 +1,6 @@ +{{ fullname }} +{{ underline }} + +.. currentmodule:: {{ module.split('.')[0] }} + +.. autoaccessorattribute:: {{ (module.split('.')[1:] + [objname]) | join('.') }} diff --git a/doc/_templates/autosummary/accessor_callable.rst b/doc/_templates/autosummary/accessor_callable.rst new file mode 100644 index 00000000000..7a3301814f5 --- /dev/null +++ b/doc/_templates/autosummary/accessor_callable.rst @@ -0,0 +1,6 @@ +{{ fullname }} +{{ underline }} + +.. currentmodule:: {{ module.split('.')[0] }} + +.. autoaccessorcallable:: {{ (module.split('.')[1:] + [objname]) | join('.') }}.__call__ diff --git a/doc/_templates/autosummary/accessor_method.rst b/doc/_templates/autosummary/accessor_method.rst new file mode 100644 index 00000000000..aefbba6ef1b --- /dev/null +++ b/doc/_templates/autosummary/accessor_method.rst @@ -0,0 +1,6 @@ +{{ fullname }} +{{ underline }} + +.. currentmodule:: {{ module.split('.')[0] }} + +.. autoaccessormethod:: {{ (module.split('.')[1:] + [objname]) | join('.') }} diff --git a/doc/_templates/autosummary/base.rst b/doc/_templates/autosummary/base.rst new file mode 100644 index 00000000000..53f2a29c193 --- /dev/null +++ b/doc/_templates/autosummary/base.rst @@ -0,0 +1,3 @@ +:github_url: {{ fullname | github_url | escape_underscores }} + +{% extends "!autosummary/base.rst" %} diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index cc9517a98ba..e5492ec73a4 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -9,8 +9,6 @@ .. 
autosummary:: :toctree: generated/ - auto_combine - Dataset.nbytes Dataset.chunks @@ -18,6 +16,8 @@ Dataset.any Dataset.argmax Dataset.argmin + Dataset.idxmax + Dataset.idxmin Dataset.max Dataset.min Dataset.mean @@ -41,8 +41,6 @@ core.rolling.DatasetCoarsen.all core.rolling.DatasetCoarsen.any - core.rolling.DatasetCoarsen.argmax - core.rolling.DatasetCoarsen.argmin core.rolling.DatasetCoarsen.count core.rolling.DatasetCoarsen.max core.rolling.DatasetCoarsen.mean @@ -54,6 +52,7 @@ core.rolling.DatasetCoarsen.var core.rolling.DatasetCoarsen.boundary core.rolling.DatasetCoarsen.coord_func + core.rolling.DatasetCoarsen.keep_attrs core.rolling.DatasetCoarsen.obj core.rolling.DatasetCoarsen.side core.rolling.DatasetCoarsen.trim_excess @@ -68,8 +67,6 @@ core.groupby.DatasetGroupBy.where core.groupby.DatasetGroupBy.all core.groupby.DatasetGroupBy.any - core.groupby.DatasetGroupBy.argmax - core.groupby.DatasetGroupBy.argmin core.groupby.DatasetGroupBy.count core.groupby.DatasetGroupBy.max core.groupby.DatasetGroupBy.mean @@ -85,8 +82,6 @@ core.resample.DatasetResample.all core.resample.DatasetResample.any core.resample.DatasetResample.apply - core.resample.DatasetResample.argmax - core.resample.DatasetResample.argmin core.resample.DatasetResample.assign core.resample.DatasetResample.assign_coords core.resample.DatasetResample.bfill @@ -123,11 +118,15 @@ core.rolling.DatasetRolling.var core.rolling.DatasetRolling.center core.rolling.DatasetRolling.dim + core.rolling.DatasetRolling.keep_attrs core.rolling.DatasetRolling.min_periods core.rolling.DatasetRolling.obj core.rolling.DatasetRolling.rollings core.rolling.DatasetRolling.window + core.weighted.DatasetWeighted.obj + core.weighted.DatasetWeighted.weights + core.rolling_exp.RollingExp.mean Dataset.argsort @@ -160,6 +159,8 @@ DataArray.any DataArray.argmax DataArray.argmin + DataArray.idxmax + DataArray.idxmin DataArray.max DataArray.min DataArray.mean @@ -183,8 +184,6 @@ core.rolling.DataArrayCoarsen.all core.rolling.DataArrayCoarsen.any - core.rolling.DataArrayCoarsen.argmax - core.rolling.DataArrayCoarsen.argmin core.rolling.DataArrayCoarsen.count core.rolling.DataArrayCoarsen.max core.rolling.DataArrayCoarsen.mean @@ -196,6 +195,7 @@ core.rolling.DataArrayCoarsen.var core.rolling.DataArrayCoarsen.boundary core.rolling.DataArrayCoarsen.coord_func + core.rolling.DataArrayCoarsen.keep_attrs core.rolling.DataArrayCoarsen.obj core.rolling.DataArrayCoarsen.side core.rolling.DataArrayCoarsen.trim_excess @@ -209,8 +209,6 @@ core.groupby.DataArrayGroupBy.where core.groupby.DataArrayGroupBy.all core.groupby.DataArrayGroupBy.any - core.groupby.DataArrayGroupBy.argmax - core.groupby.DataArrayGroupBy.argmin core.groupby.DataArrayGroupBy.count core.groupby.DataArrayGroupBy.max core.groupby.DataArrayGroupBy.mean @@ -226,8 +224,6 @@ core.resample.DataArrayResample.all core.resample.DataArrayResample.any core.resample.DataArrayResample.apply - core.resample.DataArrayResample.argmax - core.resample.DataArrayResample.argmin core.resample.DataArrayResample.assign_coords core.resample.DataArrayResample.bfill core.resample.DataArrayResample.count @@ -263,11 +259,15 @@ core.rolling.DataArrayRolling.var core.rolling.DataArrayRolling.center core.rolling.DataArrayRolling.dim + core.rolling.DataArrayRolling.keep_attrs core.rolling.DataArrayRolling.min_periods core.rolling.DataArrayRolling.obj core.rolling.DataArrayRolling.window core.rolling.DataArrayRolling.window_labels + core.weighted.DataArrayWeighted.obj + core.weighted.DataArrayWeighted.weights + DataArray.argsort 
DataArray.clip DataArray.conj @@ -291,6 +291,14 @@ core.accessor_dt.DatetimeAccessor.days_in_month core.accessor_dt.DatetimeAccessor.daysinmonth core.accessor_dt.DatetimeAccessor.hour + core.accessor_dt.DatetimeAccessor.is_leap_year + core.accessor_dt.DatetimeAccessor.is_month_end + core.accessor_dt.DatetimeAccessor.is_month_start + core.accessor_dt.DatetimeAccessor.is_quarter_end + core.accessor_dt.DatetimeAccessor.is_quarter_start + core.accessor_dt.DatetimeAccessor.is_year_end + core.accessor_dt.DatetimeAccessor.is_year_start + core.accessor_dt.DatetimeAccessor.isocalendar core.accessor_dt.DatetimeAccessor.microsecond core.accessor_dt.DatetimeAccessor.minute core.accessor_dt.DatetimeAccessor.month @@ -305,6 +313,14 @@ core.accessor_dt.DatetimeAccessor.weekofyear core.accessor_dt.DatetimeAccessor.year + core.accessor_dt.TimedeltaAccessor.ceil + core.accessor_dt.TimedeltaAccessor.floor + core.accessor_dt.TimedeltaAccessor.round + core.accessor_dt.TimedeltaAccessor.days + core.accessor_dt.TimedeltaAccessor.microseconds + core.accessor_dt.TimedeltaAccessor.nanoseconds + core.accessor_dt.TimedeltaAccessor.seconds + core.accessor_str.StringAccessor.capitalize core.accessor_str.StringAccessor.center core.accessor_str.StringAccessor.contains @@ -379,6 +395,7 @@ Variable.min Variable.no_conflicts Variable.notnull + Variable.pad Variable.prod Variable.quantile Variable.rank @@ -452,6 +469,7 @@ IndexVariable.min IndexVariable.no_conflicts IndexVariable.notnull + IndexVariable.pad IndexVariable.prod IndexVariable.quantile IndexVariable.rank @@ -554,6 +572,16 @@ ufuncs.tanh ufuncs.trunc + plot.plot + plot.line + plot.step + plot.hist + plot.contour + plot.contourf + plot.imshow + plot.pcolormesh + plot.scatter + plot.FacetGrid.map_dataarray plot.FacetGrid.set_titles plot.FacetGrid.set_ticks @@ -562,14 +590,17 @@ CFTimeIndex.all CFTimeIndex.any CFTimeIndex.append + CFTimeIndex.argsort CFTimeIndex.argmax CFTimeIndex.argmin - CFTimeIndex.argsort CFTimeIndex.asof CFTimeIndex.asof_locs CFTimeIndex.astype + CFTimeIndex.calendar + CFTimeIndex.ceil CFTimeIndex.contains CFTimeIndex.copy + CFTimeIndex.days_in_month CFTimeIndex.delete CFTimeIndex.difference CFTimeIndex.drop @@ -580,6 +611,7 @@ CFTimeIndex.equals CFTimeIndex.factorize CFTimeIndex.fillna + CFTimeIndex.floor CFTimeIndex.format CFTimeIndex.get_indexer CFTimeIndex.get_indexer_for @@ -620,6 +652,7 @@ CFTimeIndex.reindex CFTimeIndex.rename CFTimeIndex.repeat + CFTimeIndex.round CFTimeIndex.searchsorted CFTimeIndex.set_names CFTimeIndex.set_value @@ -656,6 +689,7 @@ CFTimeIndex.dayofyear CFTimeIndex.dtype CFTimeIndex.empty + CFTimeIndex.freq CFTimeIndex.has_duplicates CFTimeIndex.hasnans CFTimeIndex.hour @@ -683,13 +717,10 @@ backends.NetCDF4DataStore.encode backends.NetCDF4DataStore.encode_attribute backends.NetCDF4DataStore.encode_variable - backends.NetCDF4DataStore.get backends.NetCDF4DataStore.get_attrs backends.NetCDF4DataStore.get_dimensions backends.NetCDF4DataStore.get_encoding backends.NetCDF4DataStore.get_variables - backends.NetCDF4DataStore.items - backends.NetCDF4DataStore.keys backends.NetCDF4DataStore.load backends.NetCDF4DataStore.open backends.NetCDF4DataStore.open_store_variable @@ -703,28 +734,26 @@ backends.NetCDF4DataStore.store backends.NetCDF4DataStore.store_dataset backends.NetCDF4DataStore.sync - backends.NetCDF4DataStore.values - backends.NetCDF4DataStore.attrs backends.NetCDF4DataStore.autoclose - backends.NetCDF4DataStore.dimensions backends.NetCDF4DataStore.ds backends.NetCDF4DataStore.format 
backends.NetCDF4DataStore.is_remote backends.NetCDF4DataStore.lock - backends.NetCDF4DataStore.variables + backends.H5NetCDFStore.autoclose backends.H5NetCDFStore.close backends.H5NetCDFStore.encode backends.H5NetCDFStore.encode_attribute backends.H5NetCDFStore.encode_variable - backends.H5NetCDFStore.get + backends.H5NetCDFStore.format backends.H5NetCDFStore.get_attrs backends.H5NetCDFStore.get_dimensions backends.H5NetCDFStore.get_encoding backends.H5NetCDFStore.get_variables - backends.H5NetCDFStore.items - backends.H5NetCDFStore.keys + backends.H5NetCDFStore.is_remote backends.H5NetCDFStore.load + backends.H5NetCDFStore.lock + backends.H5NetCDFStore.open backends.H5NetCDFStore.open_store_variable backends.H5NetCDFStore.prepare_variable backends.H5NetCDFStore.set_attribute @@ -736,39 +765,25 @@ backends.H5NetCDFStore.store backends.H5NetCDFStore.store_dataset backends.H5NetCDFStore.sync - backends.H5NetCDFStore.values - backends.H5NetCDFStore.attrs - backends.H5NetCDFStore.dimensions backends.H5NetCDFStore.ds - backends.H5NetCDFStore.variables backends.PydapDataStore.close - backends.PydapDataStore.get backends.PydapDataStore.get_attrs backends.PydapDataStore.get_dimensions backends.PydapDataStore.get_encoding backends.PydapDataStore.get_variables - backends.PydapDataStore.items - backends.PydapDataStore.keys backends.PydapDataStore.load backends.PydapDataStore.open backends.PydapDataStore.open_store_variable - backends.PydapDataStore.values - backends.PydapDataStore.attrs - backends.PydapDataStore.dimensions - backends.PydapDataStore.variables backends.ScipyDataStore.close backends.ScipyDataStore.encode backends.ScipyDataStore.encode_attribute backends.ScipyDataStore.encode_variable - backends.ScipyDataStore.get backends.ScipyDataStore.get_attrs backends.ScipyDataStore.get_dimensions backends.ScipyDataStore.get_encoding backends.ScipyDataStore.get_variables - backends.ScipyDataStore.items - backends.ScipyDataStore.keys backends.ScipyDataStore.load backends.ScipyDataStore.open_store_variable backends.ScipyDataStore.prepare_variable @@ -781,11 +796,7 @@ backends.ScipyDataStore.store backends.ScipyDataStore.store_dataset backends.ScipyDataStore.sync - backends.ScipyDataStore.values - backends.ScipyDataStore.attrs - backends.ScipyDataStore.dimensions backends.ScipyDataStore.ds - backends.ScipyDataStore.variables backends.FileManager.acquire backends.FileManager.acquire_context diff --git a/doc/api.rst b/doc/api.rst index b37c84e7a81..ceab7dcc976 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -21,14 +21,16 @@ Top-level functions broadcast concat merge - auto_combine combine_by_coords combine_nested where set_options + infer_freq full_like zeros_like ones_like + cov + corr dot polyval map_blocks @@ -173,6 +175,7 @@ Computation Dataset.quantile Dataset.differentiate Dataset.integrate + Dataset.map_blocks Dataset.polyfit **Aggregation**: @@ -229,6 +232,15 @@ Reshaping and reorganizing Dataset.sortby Dataset.broadcast_like +Plotting +-------- + +.. 
autosummary:: + :toctree: generated/ + :template: autosummary/accessor_method.rst + + Dataset.plot.scatter + DataArray ========= @@ -348,7 +360,6 @@ Computation DataArray.rolling_exp DataArray.weighted DataArray.coarsen - DataArray.dt DataArray.resample DataArray.get_axis_num DataArray.diff @@ -357,7 +368,8 @@ Computation DataArray.differentiate DataArray.integrate DataArray.polyfit - DataArray.str + DataArray.map_blocks + **Aggregation**: :py:attr:`~DataArray.all` @@ -397,6 +409,121 @@ Computation :py:attr:`~core.groupby.DataArrayGroupBy.where` :py:attr:`~core.groupby.DataArrayGroupBy.quantile` + +String manipulation +------------------- + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor_method.rst + + DataArray.str.capitalize + DataArray.str.center + DataArray.str.contains + DataArray.str.count + DataArray.str.decode + DataArray.str.encode + DataArray.str.endswith + DataArray.str.find + DataArray.str.get + DataArray.str.index + DataArray.str.isalnum + DataArray.str.isalpha + DataArray.str.isdecimal + DataArray.str.isdigit + DataArray.str.isnumeric + DataArray.str.isspace + DataArray.str.istitle + DataArray.str.isupper + DataArray.str.len + DataArray.str.ljust + DataArray.str.lower + DataArray.str.lstrip + DataArray.str.match + DataArray.str.pad + DataArray.str.repeat + DataArray.str.replace + DataArray.str.rfind + DataArray.str.rindex + DataArray.str.rjust + DataArray.str.rstrip + DataArray.str.slice + DataArray.str.slice_replace + DataArray.str.startswith + DataArray.str.strip + DataArray.str.swapcase + DataArray.str.title + DataArray.str.translate + DataArray.str.upper + DataArray.str.wrap + DataArray.str.zfill + +Datetimelike properties +----------------------- + +**Datetime properties**: + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor_attribute.rst + + DataArray.dt.year + DataArray.dt.month + DataArray.dt.day + DataArray.dt.hour + DataArray.dt.minute + DataArray.dt.second + DataArray.dt.microsecond + DataArray.dt.nanosecond + DataArray.dt.dayofweek + DataArray.dt.weekday + DataArray.dt.weekday_name + DataArray.dt.dayofyear + DataArray.dt.quarter + DataArray.dt.days_in_month + DataArray.dt.daysinmonth + DataArray.dt.season + DataArray.dt.time + DataArray.dt.is_month_start + DataArray.dt.is_month_end + DataArray.dt.is_quarter_end + DataArray.dt.is_year_start + DataArray.dt.is_leap_year + +**Datetime methods**: + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor_method.rst + + DataArray.dt.floor + DataArray.dt.ceil + DataArray.dt.isocalendar + DataArray.dt.round + DataArray.dt.strftime + +**Timedelta properties**: + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor_attribute.rst + + DataArray.dt.days + DataArray.dt.seconds + DataArray.dt.microseconds + DataArray.dt.nanoseconds + +**Timedelta methods**: + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor_method.rst + + DataArray.dt.floor + DataArray.dt.ceil + DataArray.dt.round + + Reshaping and reorganizing -------------------------- @@ -413,6 +540,27 @@ Reshaping and reorganizing DataArray.sortby DataArray.broadcast_like +Plotting +-------- + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor_callable.rst + + DataArray.plot + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor_method.rst + + DataArray.plot.contourf + DataArray.plot.contour + DataArray.plot.hist + DataArray.plot.imshow + DataArray.plot.line + DataArray.plot.pcolormesh + DataArray.plot.step + .. 
_api.ufuncs: Universal functions @@ -423,7 +571,9 @@ Universal functions With recent versions of numpy, dask and xarray, NumPy ufuncs are now supported directly on all xarray and dask objects. This obviates the need for the ``xarray.ufuncs`` module, which should not be used for new code - unless compatibility with versions of NumPy prior to v1.13 is required. + unless compatibility with versions of NumPy prior to v1.13 is + required. They will be removed once support for NumPy prior to + v1.17 is dropped. These functions are copied from NumPy, but extended to work on NumPy arrays, dask arrays and all xarray objects. You can find them in the ``xarray.ufuncs`` @@ -518,7 +668,6 @@ Dataset methods Dataset.load Dataset.chunk Dataset.unify_chunks - Dataset.map_blocks Dataset.filter_by_attrs Dataset.info @@ -550,7 +699,6 @@ DataArray methods DataArray.load DataArray.chunk DataArray.unify_chunks - DataArray.map_blocks Coordinates objects =================== @@ -660,25 +808,6 @@ Creating custom indexes cftime_range -Plotting -======== - -.. autosummary:: - :toctree: generated/ - - Dataset.plot - plot.scatter - DataArray.plot - plot.plot - plot.contourf - plot.contour - plot.hist - plot.imshow - plot.line - plot.pcolormesh - plot.step - plot.FacetGrid - Faceting -------- .. autosummary:: @@ -766,3 +895,10 @@ Deprecated / Pending Deprecation Dataset.apply core.groupby.DataArrayGroupBy.apply core.groupby.DatasetGroupBy.apply + +.. autosummary:: + :toctree: generated/ + :template: autosummary/accessor_attribute.rst + + DataArray.dt.weekofyear + DataArray.dt.week diff --git a/doc/combining.rst b/doc/combining.rst index 05b7f2efc50..edd34826e6d 100644 --- a/doc/combining.rst +++ b/doc/combining.rst @@ -4,11 +4,12 @@ Combining data -------------- .. ipython:: python - :suppress: + :suppress: import numpy as np import pandas as pd import xarray as xr + np.random.seed(123456) * For combining datasets or data arrays along a single dimension, see concatenate_. @@ -28,20 +29,22 @@ that dimension: .. ipython:: python - arr = xr.DataArray(np.random.randn(2, 3), - [('x', ['a', 'b']), ('y', [10, 20, 30])]) - arr[:, :1] - # this resembles how you would use np.concatenate - xr.concat([arr[:, :1], arr[:, 1:]], dim='y') + da = xr.DataArray( + np.arange(6).reshape(2, 3), [("x", ["a", "b"]), ("y", [10, 20, 30])] + ) + da.isel(y=slice(0, 1)) # same as da[:, :1] + # This resembles how you would use np.concatenate: + xr.concat([da[:, :1], da[:, 1:]], dim="y") + # For more friendly pandas-like indexing you can use: + xr.concat([da.isel(y=slice(0, 1)), da.isel(y=slice(1, None))], dim="y") In addition to combining along an existing dimension, ``concat`` can create a new dimension by stacking lower dimensional arrays together: .. ipython:: python - arr[0] - # to combine these 1d arrays into a 2d array in numpy, you would use np.array - xr.concat([arr[0], arr[1]], 'x') + da.sel(x="a") + xr.concat([da.isel(x=0), da.isel(x=1)], "x") If the second argument to ``concat`` is a new dimension name, the arrays will be concatenated along that new dimension, which is always inserted as the first @@ -49,7 +52,7 @@ dimension: .. ipython:: python - xr.concat([arr[0], arr[1]], 'new_dim') + xr.concat([da.isel(x=0), da.isel(x=1)], "new_dim") The second argument to ``concat`` can also be an :py:class:`~pandas.Index` or :py:class:`~xarray.DataArray` object as well as a string, in which case it is @@ -57,14 +60,14 @@ used to label the values along the new dimension: .. 
ipython:: python - xr.concat([arr[0], arr[1]], pd.Index([-90, -100], name='new_dim')) + xr.concat([da.isel(x=0), da.isel(x=1)], pd.Index([-90, -100], name="new_dim")) Of course, ``concat`` also works on ``Dataset`` objects: .. ipython:: python - ds = arr.to_dataset(name='foo') - xr.concat([ds.sel(x='a'), ds.sel(x='b')], 'x') + ds = da.to_dataset(name="foo") + xr.concat([ds.sel(x="a"), ds.sel(x="b")], "x") :py:func:`~xarray.concat` has a number of options which provide deeper control over which variables are concatenated and how it handles conflicting variables @@ -84,8 +87,8 @@ To combine variables and coordinates between multiple ``DataArray`` and/or .. ipython:: python - xr.merge([ds, ds.rename({'foo': 'bar'})]) - xr.merge([xr.DataArray(n, name='var%d' % n) for n in range(5)]) + xr.merge([ds, ds.rename({"foo": "bar"})]) + xr.merge([xr.DataArray(n, name="var%d" % n) for n in range(5)]) If you merge another dataset (or a dictionary including data array objects), by default the resulting dataset will be aligned on the **union** of all index @@ -93,7 +96,7 @@ coordinates: .. ipython:: python - other = xr.Dataset({'bar': ('x', [1, 2, 3, 4]), 'x': list('abcd')}) + other = xr.Dataset({"bar": ("x", [1, 2, 3, 4]), "x": list("abcd")}) xr.merge([ds, other]) This ensures that ``merge`` is non-destructive. ``xarray.MergeError`` is raised @@ -116,7 +119,7 @@ used in the :py:class:`~xarray.Dataset` constructor: .. ipython:: python - xr.Dataset({'a': arr[:-1], 'b': arr[1:]}) + xr.Dataset({"a": da.isel(x=slice(0, 1)), "b": da.isel(x=slice(1, 2))}) .. _combine: @@ -131,8 +134,8 @@ are filled with ``NaN``. For example: .. ipython:: python - ar0 = xr.DataArray([[0, 0], [0, 0]], [('x', ['a', 'b']), ('y', [-1, 0])]) - ar1 = xr.DataArray([[1, 1], [1, 1]], [('x', ['b', 'c']), ('y', [0, 1])]) + ar0 = xr.DataArray([[0, 0], [0, 0]], [("x", ["a", "b"]), ("y", [-1, 0])]) + ar1 = xr.DataArray([[1, 1], [1, 1]], [("x", ["b", "c"]), ("y", [0, 1])]) ar0.combine_first(ar1) ar1.combine_first(ar0) @@ -152,7 +155,7 @@ variables with new values: .. ipython:: python - ds.update({'space': ('space', [10.2, 9.4, 3.9])}) + ds.update({"space": ("space", [10.2, 9.4, 3.9])}) However, dimensions are still required to be consistent between different Dataset variables, so you cannot change the size of a dimension unless you @@ -170,7 +173,7 @@ syntax: .. ipython:: python - ds['baz'] = xr.DataArray([9, 9, 9, 9, 9], coords=[('x', list('abcde'))]) + ds["baz"] = xr.DataArray([9, 9, 9, 9, 9], coords=[("x", list("abcde"))]) ds.baz Equals and identical @@ -186,14 +189,14 @@ values: .. ipython:: python - arr.equals(arr.copy()) + da.equals(da.copy()) :py:attr:`~xarray.Dataset.identical` also checks attributes, and the name of each object: .. ipython:: python - arr.identical(arr.rename('bar')) + da.identical(da.rename("bar")) :py:attr:`~xarray.Dataset.broadcast_equals` does a more relaxed form of equality check that allows variables to have different dimensions, as long as values @@ -201,8 +204,8 @@ are constant along those new dimensions: .. ipython:: python - left = xr.Dataset(coords={'x': 0}) - right = xr.Dataset({'x': [0, 0, 0]}) + left = xr.Dataset(coords={"x": 0}) + right = xr.Dataset({"x": [0, 0, 0]}) left.broadcast_equals(right) Like pandas objects, two xarray objects are still equal or identical if they have @@ -213,7 +216,7 @@ numpy): .. ipython:: python - arr == arr.copy() + da == da.copy() Note that ``NaN`` does not compare equal to ``NaN`` in element-wise comparison; you may need to deal with missing values explicitly. 
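As a quick illustration of the equality semantics documented above (a minimal sketch with made-up values): element-wise ``==`` follows NumPy and treats ``NaN`` as unequal to ``NaN``, whereas ``equals`` considers ``NaN`` in matching positions to be equal.

.. code-block:: python

    import numpy as np
    import xarray as xr

    da = xr.DataArray([1.0, np.nan, 3.0], dims="x")

    # element-wise comparison: the NaN position compares as False, like numpy
    (da == da.copy()).values

    # .equals() treats NaN in the same location as equal, so this returns True
    da.equals(da.copy())

``identical`` applies the same value comparison but additionally checks names and attributes.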
@@ -231,9 +234,9 @@ coordinates as long as any non-missing values agree or are disjoint: .. ipython:: python - ds1 = xr.Dataset({'a': ('x', [10, 20, 30, np.nan])}, {'x': [1, 2, 3, 4]}) - ds2 = xr.Dataset({'a': ('x', [np.nan, 30, 40, 50])}, {'x': [2, 3, 4, 5]}) - xr.merge([ds1, ds2], compat='no_conflicts') + ds1 = xr.Dataset({"a": ("x", [10, 20, 30, np.nan])}, {"x": [1, 2, 3, 4]}) + ds2 = xr.Dataset({"a": ("x", [np.nan, 30, 40, 50])}, {"x": [2, 3, 4, 5]}) + xr.merge([ds1, ds2], compat="no_conflicts") Note that due to the underlying representation of missing values as floating point numbers (``NaN``), variable data type is not always preserved when merging @@ -244,16 +247,6 @@ in this manner. Combining along multiple dimensions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. note:: - - There are currently three combining functions with similar names: - :py:func:`~xarray.auto_combine`, :py:func:`~xarray.combine_by_coords`, and - :py:func:`~xarray.combine_nested`. This is because - ``auto_combine`` is in the process of being deprecated in favour of the other - two functions, which are more general. If your code currently relies on - ``auto_combine``, then you will be able to get similar functionality by using - ``combine_nested``. - For combining many objects along multiple dimensions xarray provides :py:func:`~xarray.combine_nested` and :py:func:`~xarray.combine_by_coords`. These functions use a combination of ``concat`` and ``merge`` across different @@ -273,10 +266,12 @@ datasets into a doubly-nested list, e.g: .. ipython:: python - arr = xr.DataArray(name='temperature', data=np.random.randint(5, size=(2, 2)), dims=['x', 'y']) + arr = xr.DataArray( + name="temperature", data=np.random.randint(5, size=(2, 2)), dims=["x", "y"] + ) arr ds_grid = [[arr, arr], [arr, arr]] - xr.combine_nested(ds_grid, concat_dim=['x', 'y']) + xr.combine_nested(ds_grid, concat_dim=["x", "y"]) :py:func:`~xarray.combine_nested` can also be used to explicitly merge datasets with different variables. For example if we have 4 datasets, which are divided @@ -286,10 +281,10 @@ we wish to use ``merge`` instead of ``concat``: .. ipython:: python - temp = xr.DataArray(name='temperature', data=np.random.randn(2), dims=['t']) - precip = xr.DataArray(name='precipitation', data=np.random.randn(2), dims=['t']) + temp = xr.DataArray(name="temperature", data=np.random.randn(2), dims=["t"]) + precip = xr.DataArray(name="precipitation", data=np.random.randn(2), dims=["t"]) ds_grid = [[temp, precip], [temp, precip]] - xr.combine_nested(ds_grid, concat_dim=['t', None]) + xr.combine_nested(ds_grid, concat_dim=["t", None]) :py:func:`~xarray.combine_by_coords` is for combining objects which have dimension coordinates which specify their relationship to and order relative to one @@ -302,8 +297,8 @@ coordinates, not on their position in the list passed to ``combine_by_coords``. .. 
ipython:: python :okwarning: - x1 = xr.DataArray(name='foo', data=np.random.randn(3), coords=[('x', [0, 1, 2])]) - x2 = xr.DataArray(name='foo', data=np.random.randn(3), coords=[('x', [3, 4, 5])]) + x1 = xr.DataArray(name="foo", data=np.random.randn(3), coords=[("x", [0, 1, 2])]) + x2 = xr.DataArray(name="foo", data=np.random.randn(3), coords=[("x", [3, 4, 5])]) xr.combine_by_coords([x2, x1]) These functions can be used by :py:func:`~xarray.open_mfdataset` to open many diff --git a/doc/computation.rst b/doc/computation.rst index 4b8014c4782..dcfe270a942 100644 --- a/doc/computation.rst +++ b/doc/computation.rst @@ -18,17 +18,19 @@ Arithmetic operations with a single DataArray automatically vectorize (like numpy) over all array values: .. ipython:: python - :suppress: + :suppress: import numpy as np import pandas as pd import xarray as xr + np.random.seed(123456) .. ipython:: python - arr = xr.DataArray(np.random.RandomState(0).randn(2, 3), - [('x', ['a', 'b']), ('y', [10, 20, 30])]) + arr = xr.DataArray( + np.random.RandomState(0).randn(2, 3), [("x", ["a", "b"]), ("y", [10, 20, 30])] + ) arr - 3 abs(arr) @@ -45,7 +47,7 @@ Use :py:func:`~xarray.where` to conditionally switch between values: .. ipython:: python - xr.where(arr > 0, 'positive', 'negative') + xr.where(arr > 0, "positive", "negative") Use `@` to perform matrix multiplication: @@ -73,14 +75,14 @@ methods for working with missing data from pandas: .. ipython:: python - x = xr.DataArray([0, 1, np.nan, np.nan, 2], dims=['x']) + x = xr.DataArray([0, 1, np.nan, np.nan, 2], dims=["x"]) x.isnull() x.notnull() x.count() - x.dropna(dim='x') + x.dropna(dim="x") x.fillna(-1) - x.ffill('x') - x.bfill('x') + x.ffill("x") + x.bfill("x") Like pandas, xarray uses the float value ``np.nan`` (not-a-number) to represent missing values. @@ -90,9 +92,12 @@ for filling missing values via 1D interpolation. .. ipython:: python - x = xr.DataArray([0, 1, np.nan, np.nan, 2], dims=['x'], - coords={'xx': xr.Variable('x', [0, 1, 1.1, 1.9, 3])}) - x.interpolate_na(dim='x', method='linear', use_coordinate='xx') + x = xr.DataArray( + [0, 1, np.nan, np.nan, 2], + dims=["x"], + coords={"xx": xr.Variable("x", [0, 1, 1.1, 1.9, 3])}, + ) + x.interpolate_na(dim="x", method="linear", use_coordinate="xx") Note that xarray slightly diverges from the pandas ``interpolate`` syntax by providing the ``use_coordinate`` keyword which facilitates a clear specification @@ -110,8 +115,8 @@ applied along particular dimension(s): .. ipython:: python - arr.sum(dim='x') - arr.std(['x', 'y']) + arr.sum(dim="x") + arr.std(["x", "y"]) arr.min() @@ -121,7 +126,7 @@ for wrapping code designed to work with numpy arrays), you can use the .. ipython:: python - arr.get_axis_num('y') + arr.get_axis_num("y") These operations automatically skip missing values, like in pandas: @@ -142,8 +147,7 @@ method supports rolling window aggregation: .. ipython:: python - arr = xr.DataArray(np.arange(0, 7.5, 0.5).reshape(3, 5), - dims=('x', 'y')) + arr = xr.DataArray(np.arange(0, 7.5, 0.5).reshape(3, 5), dims=("x", "y")) arr :py:meth:`~xarray.DataArray.rolling` is applied along one dimension using the @@ -184,9 +188,16 @@ a value when aggregating: r = arr.rolling(y=3, center=True, min_periods=2) r.mean() +From version 0.17, xarray supports multidimensional rolling, + +.. ipython:: python + + r = arr.rolling(x=2, y=3, min_periods=2) + r.mean() + .. tip:: - Note that rolling window aggregations are faster and use less memory when bottleneck_ is installed. This only applies to numpy-backed xarray objects. 
+ Note that rolling window aggregations are faster and use less memory when bottleneck_ is installed. This only applies to numpy-backed xarray objects with 1d-rolling. .. _bottleneck: https://github.com/pydata/bottleneck/ @@ -194,8 +205,9 @@ We can also manually iterate through ``Rolling`` objects: .. code:: python - for label, arr_window in r: - # arr_window is a view of x + for label, arr_window in r: + # arr_window is a view of x + ... .. _comput.rolling_exp: @@ -222,9 +234,9 @@ windowed rolling, convolution, short-time FFT etc. .. ipython:: python # rolling with 2-point stride - rolling_da = r.construct('window_dim', stride=2) + rolling_da = r.construct(x="x_win", y="y_win", stride=2) rolling_da - rolling_da.mean('window_dim', skipna=False) + rolling_da.mean(["x_win", "y_win"], skipna=False) Because the ``DataArray`` given by ``r.construct('window_dim')`` is a view of the original array, it is memory efficient. @@ -232,8 +244,8 @@ You can also use ``construct`` to compute a weighted rolling sum: .. ipython:: python - weight = xr.DataArray([0.25, 0.5, 0.25], dims=['window']) - arr.rolling(y=3).construct('window').dot(weight) + weight = xr.DataArray([0.25, 0.5, 0.25], dims=["window"]) + arr.rolling(y=3).construct(y="window").dot(weight) .. note:: numpy's Nan-aggregation functions such as ``nansum`` copy the original array. @@ -254,52 +266,52 @@ support weighted ``sum`` and weighted ``mean``. .. ipython:: python - coords = dict(month=('month', [1, 2, 3])) + coords = dict(month=("month", [1, 2, 3])) - prec = xr.DataArray([1.1, 1.0, 0.9], dims=('month', ), coords=coords) - weights = xr.DataArray([31, 28, 31], dims=('month', ), coords=coords) + prec = xr.DataArray([1.1, 1.0, 0.9], dims=("month",), coords=coords) + weights = xr.DataArray([31, 28, 31], dims=("month",), coords=coords) Create a weighted object: .. ipython:: python - weighted_prec = prec.weighted(weights) - weighted_prec + weighted_prec = prec.weighted(weights) + weighted_prec Calculate the weighted sum: .. ipython:: python - weighted_prec.sum() + weighted_prec.sum() Calculate the weighted mean: .. ipython:: python - weighted_prec.mean(dim="month") + weighted_prec.mean(dim="month") The weighted sum corresponds to: .. ipython:: python - weighted_sum = (prec * weights).sum() - weighted_sum + weighted_sum = (prec * weights).sum() + weighted_sum and the weighted mean to: .. ipython:: python - weighted_mean = weighted_sum / weights.sum() - weighted_mean + weighted_mean = weighted_sum / weights.sum() + weighted_mean However, the functions also take missing values in the data into account: .. ipython:: python - data = xr.DataArray([np.NaN, 2, 4]) - weights = xr.DataArray([8, 1, 1]) + data = xr.DataArray([np.NaN, 2, 4]) + weights = xr.DataArray([8, 1, 1]) - data.weighted(weights).mean() + data.weighted(weights).mean() Using ``(data * weights).sum() / weights.sum()`` would (incorrectly) result in 0.6. @@ -309,16 +321,16 @@ If the weights add up to to 0, ``sum`` returns 0: .. ipython:: python - data = xr.DataArray([1.0, 1.0]) - weights = xr.DataArray([-1.0, 1.0]) + data = xr.DataArray([1.0, 1.0]) + weights = xr.DataArray([-1.0, 1.0]) - data.weighted(weights).sum() + data.weighted(weights).sum() and ``mean`` returns ``NaN``: .. ipython:: python - data.weighted(weights).mean() + data.weighted(weights).mean() .. note:: @@ -336,18 +348,21 @@ methods. This supports the block aggregation along multiple dimensions, .. 
ipython:: python - x = np.linspace(0, 10, 300) - t = pd.date_range('15/12/1999', periods=364) - da = xr.DataArray(np.sin(x) * np.cos(np.linspace(0, 1, 364)[:, np.newaxis]), - dims=['time', 'x'], coords={'time': t, 'x': x}) - da + x = np.linspace(0, 10, 300) + t = pd.date_range("15/12/1999", periods=364) + da = xr.DataArray( + np.sin(x) * np.cos(np.linspace(0, 1, 364)[:, np.newaxis]), + dims=["time", "x"], + coords={"time": t, "x": x}, + ) + da In order to take a block mean for every 7 days along ``time`` dimension and every 2 points along ``x`` dimension, .. ipython:: python - da.coarsen(time=7, x=2).mean() + da.coarsen(time=7, x=2).mean() :py:meth:`~xarray.DataArray.coarsen` raises an ``ValueError`` if the data length is not a multiple of the corresponding window size. @@ -356,14 +371,14 @@ the excess entries or padding ``nan`` to insufficient entries, .. ipython:: python - da.coarsen(time=30, x=2, boundary='trim').mean() + da.coarsen(time=30, x=2, boundary="trim").mean() If you want to apply a specific function to coordinate, you can pass the function or method name to ``coord_func`` option, .. ipython:: python - da.coarsen(time=7, x=2, coord_func={'time': 'min'}).mean() + da.coarsen(time=7, x=2, coord_func={"time": "min"}).mean() .. _compute.using_coordinates: @@ -377,24 +392,25 @@ central finite differences using their coordinates, .. ipython:: python - a = xr.DataArray([0, 1, 2, 3], dims=['x'], coords=[[0.1, 0.11, 0.2, 0.3]]) + a = xr.DataArray([0, 1, 2, 3], dims=["x"], coords=[[0.1, 0.11, 0.2, 0.3]]) a - a.differentiate('x') + a.differentiate("x") This method can be used also for multidimensional arrays, .. ipython:: python - a = xr.DataArray(np.arange(8).reshape(4, 2), dims=['x', 'y'], - coords={'x': [0.1, 0.11, 0.2, 0.3]}) - a.differentiate('x') + a = xr.DataArray( + np.arange(8).reshape(4, 2), dims=["x", "y"], coords={"x": [0.1, 0.11, 0.2, 0.3]} + ) + a.differentiate("x") :py:meth:`~xarray.DataArray.integrate` computes integration based on trapezoidal rule using their coordinates, .. ipython:: python - a.integrate('x') + a.integrate("x") .. note:: These methods are limited to simple cartesian geometry. Differentiation @@ -412,9 +428,9 @@ best fitting coefficients along a given dimension and for a given order, .. ipython:: python - x = xr.DataArray(np.arange(10), dims=['x'], name='x') - a = xr.DataArray(3 + 4 * x, dims=['x'], coords={'x': x}) - out = a.polyfit(dim='x', deg=1, full=True) + x = xr.DataArray(np.arange(10), dims=["x"], name="x") + a = xr.DataArray(3 + 4 * x, dims=["x"], coords={"x": x}) + out = a.polyfit(dim="x", deg=1, full=True) out The method outputs a dataset containing the coefficients (and more if `full=True`). @@ -443,9 +459,9 @@ arrays with different sizes aligned along different dimensions: .. ipython:: python - a = xr.DataArray([1, 2], [('x', ['a', 'b'])]) + a = xr.DataArray([1, 2], [("x", ["a", "b"])]) a - b = xr.DataArray([-1, -2, -3], [('y', [10, 20, 30])]) + b = xr.DataArray([-1, -2, -3], [("y", [10, 20, 30])]) b With xarray, we can apply binary mathematical operations to these arrays, and @@ -460,7 +476,7 @@ appeared: .. ipython:: python - c = xr.DataArray(np.arange(6).reshape(3, 2), [b['y'], a['x']]) + c = xr.DataArray(np.arange(6).reshape(3, 2), [b["y"], a["x"]]) c a + c @@ -494,7 +510,7 @@ operations. The default result of a binary operation is by the *intersection* .. 
ipython:: python - arr = xr.DataArray(np.arange(3), [('x', range(3))]) + arr = xr.DataArray(np.arange(3), [("x", range(3))]) arr + arr[:-1] If coordinate values for a dimension are missing on either argument, all @@ -503,7 +519,7 @@ matching dimensions must have the same size: .. ipython:: :verbatim: - In [1]: arr + xr.DataArray([1, 2], dims='x') + In [1]: arr + xr.DataArray([1, 2], dims="x") ValueError: arguments without labels along dimension 'x' cannot be aligned because they have different dimension size(s) {2} than the size of the aligned dimension labels: 3 @@ -562,16 +578,20 @@ variables: .. ipython:: python - ds = xr.Dataset({'x_and_y': (('x', 'y'), np.random.randn(3, 5)), - 'x_only': ('x', np.random.randn(3))}, - coords=arr.coords) + ds = xr.Dataset( + { + "x_and_y": (("x", "y"), np.random.randn(3, 5)), + "x_only": ("x", np.random.randn(3)), + }, + coords=arr.coords, + ) ds > 0 Datasets support most of the same methods found on data arrays: .. ipython:: python - ds.mean(dim='x') + ds.mean(dim="x") abs(ds) Datasets also support NumPy ufuncs (requires NumPy v1.13 or newer), or @@ -594,7 +614,7 @@ Arithmetic between two datasets matches data variables of the same name: .. ipython:: python - ds2 = xr.Dataset({'x_and_y': 0, 'x_only': 100}) + ds2 = xr.Dataset({"x_and_y": 0, "x_only": 100}) ds - ds2 Similarly to index based alignment, the result has the intersection of all @@ -638,7 +658,7 @@ any additional arguments: .. ipython:: python squared_error = lambda x, y: (x - y) ** 2 - arr1 = xr.DataArray([0, 1, 2, 3], dims='x') + arr1 = xr.DataArray([0, 1, 2, 3], dims="x") xr.apply_ufunc(squared_error, arr1, 1) For using more complex operations that consider some array values collectively, @@ -658,21 +678,21 @@ to set ``axis=-1``. As an example, here is how we would wrap .. code-block:: python def vector_norm(x, dim, ord=None): - return xr.apply_ufunc(np.linalg.norm, x, - input_core_dims=[[dim]], - kwargs={'ord': ord, 'axis': -1}) + return xr.apply_ufunc( + np.linalg.norm, x, input_core_dims=[[dim]], kwargs={"ord": ord, "axis": -1} + ) .. ipython:: python - :suppress: + :suppress: def vector_norm(x, dim, ord=None): - return xr.apply_ufunc(np.linalg.norm, x, - input_core_dims=[[dim]], - kwargs={'ord': ord, 'axis': -1}) + return xr.apply_ufunc( + np.linalg.norm, x, input_core_dims=[[dim]], kwargs={"ord": ord, "axis": -1} + ) .. ipython:: python - vector_norm(arr1, dim='x') + vector_norm(arr1, dim="x") Because ``apply_ufunc`` follows a standard convention for ufuncs, it plays nicely with tools for building vectorized functions, like diff --git a/doc/conf.py b/doc/conf.py index 578f9cf550d..d83e966f3fa 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # xarray documentation build configuration file, created by # sphinx-quickstart on Thu Feb 6 18:57:54 2014. 
@@ -20,12 +19,10 @@ import sys from contextlib import suppress -# make sure the source version is preferred (#3567) -root = pathlib.Path(__file__).absolute().parent.parent -os.environ["PYTHONPATH"] = str(root) -sys.path.insert(0, str(root)) +import sphinx_autosummary_accessors +from jinja2.defaults import DEFAULT_FILTERS -import xarray # isort:skip +import xarray allowed_failures = set() @@ -39,7 +36,7 @@ print("pip environment:") subprocess.run(["pip", "list"]) -print("xarray: %s, %s" % (xarray.__version__, xarray.__file__)) +print(f"xarray: {xarray.__version__}, {xarray.__file__}") with suppress(ImportError): import matplotlib @@ -47,14 +44,14 @@ matplotlib.use("Agg") try: - import rasterio + import rasterio # noqa: F401 except ImportError: allowed_failures.update( ["gallery/plot_rasterio_rgb.py", "gallery/plot_rasterio.py"] ) try: - import cartopy + import cartopy # noqa: F401 except ImportError: allowed_failures.update( [ @@ -79,10 +76,11 @@ "sphinx.ext.extlinks", "sphinx.ext.mathjax", "sphinx.ext.napoleon", - "numpydoc", "IPython.sphinxext.ipython_directive", "IPython.sphinxext.ipython_console_highlighting", "nbsphinx", + "sphinx_autosummary_accessors", + "scanpydoc.rtd_github_links", ] extlinks = { @@ -102,16 +100,78 @@ """ autosummary_generate = True + +# for scanpydoc's jinja filter +project_dir = pathlib.Path(__file__).parent.parent +html_context = { + "github_user": "pydata", + "github_repo": "xarray", + "github_version": "master", +} + autodoc_typehints = "none" -napoleon_use_param = True -napoleon_use_rtype = True +napoleon_google_docstring = False +napoleon_numpy_docstring = True + +napoleon_use_param = False +napoleon_use_rtype = False +napoleon_preprocess_types = True +napoleon_type_aliases = { + # general terms + "sequence": ":term:`sequence`", + "iterable": ":term:`iterable`", + "callable": ":py:func:`callable`", + "dict_like": ":term:`dict-like `", + "dict-like": ":term:`dict-like `", + "mapping": ":term:`mapping`", + "file-like": ":term:`file-like `", + # special terms + # "same type as caller": "*same type as caller*", # does not work, yet + # "same type as values": "*same type as values*", # does not work, yet + # stdlib type aliases + "MutableMapping": "~collections.abc.MutableMapping", + "sys.stdout": ":obj:`sys.stdout`", + "timedelta": "~datetime.timedelta", + "string": ":class:`string `", + # numpy terms + "array_like": ":term:`array_like`", + "array-like": ":term:`array-like `", + "scalar": ":term:`scalar`", + "array": ":term:`array`", + "hashable": ":term:`hashable `", + # matplotlib terms + "color-like": ":py:func:`color-like `", + "matplotlib colormap name": ":doc:matplotlib colormap name ", + "matplotlib axes object": ":py:class:`matplotlib axes object `", + "colormap": ":py:class:`colormap `", + # objects without namespace + "DataArray": "~xarray.DataArray", + "Dataset": "~xarray.Dataset", + "Variable": "~xarray.Variable", + "ndarray": "~numpy.ndarray", + "MaskedArray": "~numpy.ma.MaskedArray", + "dtype": "~numpy.dtype", + "ComplexWarning": "~numpy.ComplexWarning", + "Index": "~pandas.Index", + "MultiIndex": "~pandas.MultiIndex", + "CategoricalIndex": "~pandas.CategoricalIndex", + "TimedeltaIndex": "~pandas.TimedeltaIndex", + "DatetimeIndex": "~pandas.DatetimeIndex", + "Series": "~pandas.Series", + "DataFrame": "~pandas.DataFrame", + "Categorical": "~pandas.Categorical", + "Path": "~~pathlib.Path", + # objects with abbreviated namespace (from pandas) + "pd.Index": "~pandas.Index", + "pd.NaT": "~pandas.NaT", +} numpydoc_class_members_toctree = True 
numpydoc_show_class_members = False # Add any paths that contain templates here, relative to this directory. -templates_path = ["_templates"] +templates_path = ["_templates", sphinx_autosummary_accessors.templates_path] # The suffix of source filenames. source_suffix = ".rst" @@ -270,21 +330,21 @@ # -- Options for LaTeX output --------------------------------------------- -latex_elements = { - # The paper size ('letterpaper' or 'a4paper'). - # 'papersize': 'letterpaper', - # The font size ('10pt', '11pt' or '12pt'). - # 'pointsize': '10pt', - # Additional stuff for the LaTeX preamble. - # 'preamble': '', -} +# latex_elements = { +# # The paper size ('letterpaper' or 'a4paper'). +# # 'papersize': 'letterpaper', +# # The font size ('10pt', '11pt' or '12pt'). +# # 'pointsize': '10pt', +# # Additional stuff for the LaTeX preamble. +# # 'preamble': '', +# } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). -latex_documents = [ - ("index", "xarray.tex", "xarray Documentation", "xarray Developers", "manual") -] +# latex_documents = [ +# ("index", "xarray.tex", "xarray Documentation", "xarray Developers", "manual") +# ] # The name of an image file (relative to this directory) to place at the top of # the title page. @@ -311,7 +371,7 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). -man_pages = [("index", "xarray", "xarray Documentation", ["xarray Developers"], 1)] +# man_pages = [("index", "xarray", "xarray Documentation", ["xarray Developers"], 1)] # If true, show URL addresses after external links. # man_show_urls = False @@ -322,17 +382,17 @@ # Grouping the document tree into Texinfo files. List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) -texinfo_documents = [ - ( - "index", - "xarray", - "xarray Documentation", - "xarray Developers", - "xarray", - "N-D labeled arrays and datasets in Python.", - "Miscellaneous", - ) -] +# texinfo_documents = [ +# ( +# "index", +# "xarray", +# "xarray Documentation", +# "xarray Developers", +# "xarray", +# "N-D labeled arrays and datasets in Python.", +# "Miscellaneous", +# ) +# ] # Documents to append as an appendix to all manuals. # texinfo_appendices = [] @@ -352,10 +412,20 @@ "python": ("https://docs.python.org/3/", None), "pandas": ("https://pandas.pydata.org/pandas-docs/stable", None), "iris": ("https://scitools.org.uk/iris/docs/latest", None), - "numpy": ("https://docs.scipy.org/doc/numpy", None), + "numpy": ("https://numpy.org/doc/stable", None), "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), "numba": ("https://numba.pydata.org/numba-doc/latest", None), "matplotlib": ("https://matplotlib.org", None), "dask": ("https://docs.dask.org/en/latest", None), "cftime": ("https://unidata.github.io/cftime", None), + "rasterio": ("https://rasterio.readthedocs.io/en/latest", None), + "sparse": ("https://sparse.pydata.org/en/latest/", None), } + + +def escape_underscores(string): + return string.replace("_", r"\_") + + +def setup(app): + DEFAULT_FILTERS["escape_underscores"] = escape_underscores diff --git a/doc/contributing.rst b/doc/contributing.rst index f581bcd9741..9c4ce5a0af2 100644 --- a/doc/contributing.rst +++ b/doc/contributing.rst @@ -40,8 +40,8 @@ report will allow others to reproduce the bug and provide insight into fixing. S `this stackoverflow article `_ for tips on writing a good bug report. 
-Trying the bug-producing code out on the *master* branch is often a worthwhile exercise -to confirm the bug still exists. It is also worth searching existing bug reports and +Trying out the bug-producing code on the *master* branch is often a worthwhile exercise +to confirm that the bug still exists. It is also worth searching existing bug reports and pull requests to see if the issue has already been reported and/or fixed. Bug reports must: @@ -51,8 +51,9 @@ Bug reports must: `_:: ```python - >>> import xarray as xr - >>> df = xr.Dataset(...) + import xarray as xr + df = xr.Dataset(...) + ... ``` @@ -148,11 +149,16 @@ We'll now kick off a two-step process: 1. Install the build dependencies 2. Build and install xarray -.. code-block:: none +.. code-block:: sh # Create and activate the build environment - # This is for Linux and MacOS. On Windows, use py37-windows.yml instead. - conda env create -f ci/requirements/py37.yml + conda create -c conda-forge -n xarray-tests python=3.8 + + # This is for Linux and MacOS + conda env update -f ci/requirements/environment.yml + + # On windows, use environment-windows.yml instead + conda env update -f ci/requirements/environment-windows.yml conda activate xarray-tests @@ -162,7 +168,10 @@ We'll now kick off a two-step process: # Build and install xarray pip install -e . -At this point you should be able to import *xarray* from your locally built version:: +At this point you should be able to import *xarray* from your locally +built version: + +.. code-block:: sh $ python # start an interpreter >>> import xarray @@ -186,7 +195,7 @@ Creating a branch ----------------- You want your master branch to reflect only production-ready code, so create a -feature branch for making your changes. For example:: +feature branch before making your changes. For example:: git branch shiny-new-feature git checkout shiny-new-feature @@ -203,12 +212,12 @@ and switch in between them using the ``git checkout`` command. To update this branch, you need to retrieve the changes from the master branch:: git fetch upstream - git rebase upstream/master + git merge upstream/master -This will replay your commits on top of the latest *xarray* git master. If this +This will combine your commits with the latest *xarray* git master. If this leads to merge conflicts, you must resolve these before submitting your pull request. If you have uncommitted changes, you will need to ``git stash`` them -prior to updating. This will effectively store your changes and they can be +prior to updating. This will effectively store your changes, which can be reapplied after updating. .. _contributing.documentation: @@ -249,30 +258,32 @@ Some other important things to know about the docs: - The docstrings follow the **Numpy Docstring Standard**, which is used widely in the Scientific Python community. This standard specifies the format of the different sections of the docstring. See `this document - `_ + `_ for a detailed explanation, or look at some of the existing functions to extend it in a similar manner. - The tutorials make heavy use of the `ipython directive `_ sphinx extension. This directive lets you put code in the documentation which will be run - during the doc build. For example:: + during the doc build. For example: + + .. code:: rst .. ipython:: python x = 2 - x**3 + x ** 3 will be rendered as:: In [1]: x = 2 - In [2]: x**3 + In [2]: x ** 3 Out[2]: 8 Almost all code examples in the docs are run (and the output saved) during the doc build. 
This approach means that code examples will always be up to date, - but it does make the doc building a bit more complex. + but it does make building the docs a bit more complex. - Our API documentation in ``doc/api.rst`` houses the auto-generated documentation from the docstrings. For classes, there are a few subtleties @@ -290,7 +301,7 @@ Requirements Make sure to follow the instructions on :ref:`creating a development environment above `, but to build the docs you need to use the environment file ``ci/requirements/doc.yml``. -.. code-block:: none +.. code-block:: sh # Create and activate the docs environment conda env create -f ci/requirements/doc.yml @@ -313,7 +324,7 @@ Then you can find the HTML output in the folder ``xarray/doc/_build/html/``. The first time you build the docs, it will take quite a while because it has to run all the code examples and build all the generated docstring pages. In subsequent -evocations, sphinx will try to only build the pages that have been modified. +evocations, Sphinx will try to only build the pages that have been modified. If you want to do a full clean build, do:: @@ -347,34 +358,19 @@ Code Formatting xarray uses several tools to ensure a consistent code format throughout the project: -- `Black `_ for standardized code formatting +- `Black `_ for standardized + code formatting +- `blackdoc `_ for + standardized code formatting in documentation - `Flake8 `_ for general code quality - `isort `_ for standardized order in imports. See also `flake8-isort `_. - `mypy `_ for static type checking on `type hints `_ -``pip``:: - - pip install black flake8 isort mypy - -and then run from the root of the Xarray repository:: - - isort -rc . - black -t py36 . - flake8 - mypy . - -to auto-format your code. Additionally, many editors have plugins that will -apply ``black`` as you edit files. - -Optionally, you may wish to setup `pre-commit hooks `_ +We highly recommend that you setup `pre-commit hooks `_ to automatically run all the above tools every time you make a git commit. This -can be done by installing ``pre-commit``:: - - pip install pre-commit - -and then running:: +can be done by running:: pre-commit install @@ -396,12 +392,8 @@ Testing With Continuous Integration ----------------------------------- The *xarray* test suite runs automatically the -`Azure Pipelines `__, -continuous integration service, once your pull request is submitted. However, -if you wish to run the test suite on a branch prior to submitting the pull -request, then Azure Pipelines -`needs to be configured `_ -for your GitHub repository. +`GitHub Actions `__, +continuous integration service, once your pull request is submitted. A pull-request will be considered for merging when you have an all 'green' build. If any tests are failing, then you will get a red 'X', where you can click through to see the @@ -431,7 +423,7 @@ taken from the original GitHub issue. However, it is always worth considering a use cases and writing corresponding tests. Adding tests is one of the most common requests after code is pushed to *xarray*. Therefore, -it is worth getting in the habit of writing tests ahead of time so this is never an issue. +it is worth getting in the habit of writing tests ahead of time so that this is never an issue. Like many packages, *xarray* uses `pytest `_ and the convenient @@ -467,7 +459,7 @@ typically find tests wrapped in a class. .. code-block:: python class TestReallyCoolFeature: - .... + ... 
Going forward, we are moving to a more *functional* style using the `pytest `__ framework, which offers a richer @@ -477,7 +469,7 @@ writing test classes, we will write test functions like this: .. code-block:: python def test_really_cool_feature(): - .... + ... Using ``pytest`` ~~~~~~~~~~~~~~~~ @@ -508,17 +500,23 @@ We would name this file ``test_cool_feature.py`` and put in an appropriate place from xarray.testing import assert_equal - @pytest.mark.parametrize('dtype', ['int8', 'int16', 'int32', 'int64']) + @pytest.mark.parametrize("dtype", ["int8", "int16", "int32", "int64"]) def test_dtypes(dtype): assert str(np.dtype(dtype)) == dtype - @pytest.mark.parametrize('dtype', ['float32', - pytest.param('int16', marks=pytest.mark.skip), - pytest.param('int32', marks=pytest.mark.xfail( - reason='to show how it works'))]) + @pytest.mark.parametrize( + "dtype", + [ + "float32", + pytest.param("int16", marks=pytest.mark.skip), + pytest.param( + "int32", marks=pytest.mark.xfail(reason="to show how it works") + ), + ], + ) def test_mark(dtype): - assert str(np.dtype(dtype)) == 'float32' + assert str(np.dtype(dtype)) == "float32" @pytest.fixture @@ -526,7 +524,7 @@ We would name this file ``test_cool_feature.py`` and put in an appropriate place return xr.DataArray([1, 2, 3]) - @pytest.fixture(params=['int8', 'int16', 'int32', 'int64']) + @pytest.fixture(params=["int8", "int16", "int32", "int64"]) def dtype(request): return request.param @@ -610,7 +608,7 @@ need to install `pytest-xdist` via:: pip install pytest-xdist -Then, run pytest with the optional -n argument: +Then, run pytest with the optional -n argument:: pytest xarray -n 4 @@ -797,7 +795,7 @@ release. To submit a pull request: This request then goes to the repository maintainers, and they will review the code. If you need to make more changes, you can make them in your branch, add them to a new commit, push them to GitHub, and the pull request -will be automatically updated. Pushing them to GitHub again is done by:: +will automatically be updated. Pushing them to GitHub again is done by:: git push origin shiny-new-feature @@ -809,8 +807,7 @@ Delete your merged branch (optional) ------------------------------------ Once your feature branch is accepted into upstream, you'll probably want to get rid of -the branch. First, merge upstream master into your branch so git knows it is safe to -delete your branch:: +the branch. First, update your ``master`` branch to check that the merge was successful:: git fetch upstream git checkout master @@ -818,12 +815,14 @@ delete your branch:: Then you can do:: - git branch -d shiny-new-feature + git branch -D shiny-new-feature -Make sure you use a lower-case ``-d``, or else git won't warn you if your feature -branch has not actually been merged. +You need to use a upper-case ``-D`` because the branch was squashed into a +single commit before merging. Be careful with this because ``git`` won't warn +you if you accidentally delete an unmerged branch. -The branch will still exist on GitHub, so to delete it there do:: +If you didn't delete your branch using GitHub's interface, then it will still exist on +GitHub. To delete it there do:: git push origin --delete shiny-new-feature @@ -840,8 +839,7 @@ PR checklist - **Properly format your code** and verify that it passes the formatting guidelines set by `Black `_ and `Flake8 `_. See `"Code formatting" `_. You can use `pre-commit `_ to run these automatically on each commit. - - Run ``black .`` in the root directory. This may modify some files. 
Confirm and commit any formatting changes. - - Run ``flake8`` in the root directory. If this fails, it will log an error message. + - Run ``pre-commit run --all-files`` in the root directory. This may modify some files. Confirm and commit any formatting changes. - **Push your code and** `create a PR on GitHub `_. -- **Use a helpful title for your pull request** by summarizing the main contributions rather than using the latest commit message. If this addresses an `issue `_, please `reference it `_. +- **Use a helpful title for your pull request** by summarizing the main contributions rather than using the latest commit message. If the PR addresses an `issue `_, please `reference it `_. diff --git a/doc/dask.rst b/doc/dask.rst index 07b3939af6e..4844967350b 100644 --- a/doc/dask.rst +++ b/doc/dask.rst @@ -1,3 +1,5 @@ +.. currentmodule:: xarray + .. _dask: Parallel computing with Dask @@ -56,19 +58,26 @@ argument to :py:func:`~xarray.open_dataset` or using the import numpy as np import pandas as pd import xarray as xr + np.random.seed(123456) np.set_printoptions(precision=3, linewidth=100, threshold=100, edgeitems=3) - ds = xr.Dataset({'temperature': (('time', 'latitude', 'longitude'), - np.random.randn(30, 180, 180)), - 'time': pd.date_range('2015-01-01', periods=30), - 'longitude': np.arange(180), - 'latitude': np.arange(89.5, -90.5, -1)}) - ds.to_netcdf('example-data.nc') + ds = xr.Dataset( + { + "temperature": ( + ("time", "latitude", "longitude"), + np.random.randn(30, 180, 180), + ), + "time": pd.date_range("2015-01-01", periods=30), + "longitude": np.arange(180), + "latitude": np.arange(89.5, -90.5, -1), + } + ) + ds.to_netcdf("example-data.nc") .. ipython:: python - ds = xr.open_dataset('example-data.nc', chunks={'time': 10}) + ds = xr.open_dataset("example-data.nc", chunks={"time": 10}) ds In this example ``latitude`` and ``longitude`` do not appear in the ``chunks`` @@ -83,7 +92,7 @@ use :py:func:`~xarray.open_mfdataset`:: xr.open_mfdataset('my/files/*.nc', parallel=True) This function will automatically concatenate and merge datasets into one in -the simple cases that it understands (see :py:func:`~xarray.auto_combine` +the simple cases that it understands (see :py:func:`~xarray.combine_by_coords` for the full disclaimer). By default, :py:meth:`~xarray.open_mfdataset` will chunk each netCDF file into a single Dask array; again, supply the ``chunks`` argument to control the size of the resulting Dask arrays. In more complex cases, you can @@ -106,7 +115,7 @@ usual way. .. ipython:: python - ds.to_netcdf('manipulated-example-data.nc') + ds.to_netcdf("manipulated-example-data.nc") By setting the ``compute`` argument to ``False``, :py:meth:`~xarray.Dataset.to_netcdf` will return a ``dask.delayed`` object that can be computed later. @@ -114,8 +123,9 @@ will return a ``dask.delayed`` object that can be computed later. .. 
ipython:: python from dask.diagnostics import ProgressBar + # or distributed.progress when using the distributed scheduler - delayed_obj = ds.to_netcdf('manipulated-example-data.nc', compute=False) + delayed_obj = ds.to_netcdf("manipulated-example-data.nc", compute=False) with ProgressBar(): results = delayed_obj.compute() @@ -141,8 +151,9 @@ Dask DataFrames do not support multi-indexes so the coordinate variables from th :suppress: import os - os.remove('example-data.nc') - os.remove('manipulated-example-data.nc') + + os.remove("example-data.nc") + os.remove("manipulated-example-data.nc") Using Dask with xarray ---------------------- @@ -199,7 +210,7 @@ Dask arrays using the :py:meth:`~xarray.Dataset.persist` method: .. ipython:: python - ds = ds.persist() + ds = ds.persist() :py:meth:`~xarray.Dataset.persist` is particularly useful when using a distributed cluster because the data will be loaded into distributed memory @@ -224,11 +235,11 @@ sizes of Dask arrays is done with the :py:meth:`~xarray.Dataset.chunk` method: .. ipython:: python :suppress: - ds = ds.chunk({'time': 10}) + ds = ds.chunk({"time": 10}) .. ipython:: python - rechunked = ds.chunk({'latitude': 100, 'longitude': 100}) + rechunked = ds.chunk({"latitude": 100, "longitude": 100}) You can view the size of existing chunks on an array by viewing the :py:attr:`~xarray.Dataset.chunks` attribute: @@ -256,6 +267,7 @@ lazy Dask arrays, in the :ref:`xarray.ufuncs ` module: .. ipython:: python import xarray.ufuncs as xu + xu.sin(rechunked) To access Dask arrays directly, use the new @@ -274,12 +286,21 @@ loaded into Dask or not: .. _dask.automatic-parallelization: -Automatic parallelization -------------------------- +Automatic parallelization with ``apply_ufunc`` and ``map_blocks`` +----------------------------------------------------------------- Almost all of xarray's built-in operations work on Dask arrays. If you want to -use a function that isn't wrapped by xarray, one option is to extract Dask -arrays from xarray objects (``.data``) and use Dask directly. +use a function that isn't wrapped by xarray, and have it applied in parallel on +each block of your xarray object, you have three options: + +1. Extract Dask arrays from xarray objects (``.data``) and use Dask directly. +2. Use :py:func:`~xarray.apply_ufunc` to apply functions that consume and return NumPy arrays. +3. Use :py:func:`~xarray.map_blocks`, :py:meth:`Dataset.map_blocks` or :py:meth:`DataArray.map_blocks` + to apply functions that consume and return xarray objects. + + +``apply_ufunc`` +~~~~~~~~~~~~~~~ Another option is to use xarray's :py:func:`~xarray.apply_ufunc`, which can automate `embarrassingly parallel @@ -302,24 +323,32 @@ we use to calculate `Spearman's rank-correlation coefficient ` and @@ -453,15 +470,15 @@ dataset variables: .. ipython:: python - ds.rename({'temperature': 'temp', 'precipitation': 'precip'}) + ds.rename({"temperature": "temp", "precipitation": "precip"}) The related :py:meth:`~xarray.Dataset.swap_dims` method allows you do to swap dimension and non-dimension variables: .. ipython:: python - ds.coords['day'] = ('time', [6, 7, 8]) - ds.swap_dims({'time': 'day'}) + ds.coords["day"] = ("time", [6, 7, 8]) + ds.swap_dims({"time": "day"}) .. _coordinates: @@ -519,8 +536,8 @@ To convert back and forth between data and coordinates, you can use the .. 
ipython:: python ds.reset_coords() - ds.set_coords(['temperature', 'precipitation']) - ds['temperature'].reset_coords(drop=True) + ds.set_coords(["temperature", "precipitation"]) + ds["temperature"].reset_coords(drop=True) Notice that these operations skip coordinates with names given by dimensions, as used for indexing. This mostly because we are not entirely sure how to @@ -544,7 +561,7 @@ logic used for merging coordinates in arithmetic operations .. ipython:: python - alt = xr.Dataset(coords={'z': [10], 'lat': 0, 'lon': 0}) + alt = xr.Dataset(coords={"z": [10], "lat": 0, "lon": 0}) ds.coords.merge(alt.coords) The ``coords.merge`` method may be useful if you want to implement your own @@ -560,7 +577,7 @@ To convert a coordinate (or any ``DataArray``) into an actual .. ipython:: python - ds['time'].to_index() + ds["time"].to_index() A useful shortcut is the ``indexes`` property (on both ``DataArray`` and ``Dataset``), which lazily constructs a dictionary whose keys are given by each @@ -577,9 +594,10 @@ Xarray supports labeling coordinate values with a :py:class:`pandas.MultiIndex`: .. ipython:: python - midx = pd.MultiIndex.from_arrays([['R', 'R', 'V', 'V'], [.1, .2, .7, .9]], - names=('band', 'wn')) - mda = xr.DataArray(np.random.rand(4), coords={'spec': midx}, dims='spec') + midx = pd.MultiIndex.from_arrays( + [["R", "R", "V", "V"], [0.1, 0.2, 0.7, 0.9]], names=("band", "wn") + ) + mda = xr.DataArray(np.random.rand(4), coords={"spec": midx}, dims="spec") mda For convenience multi-index levels are directly accessible as "virtual" or @@ -587,8 +605,8 @@ For convenience multi-index levels are directly accessible as "virtual" or .. ipython:: python - mda['band'] - mda.wn + mda["band"] + mda.wn Indexing with multi-index levels is also possible using the ``sel`` method (see :ref:`multi-level indexing`). diff --git a/doc/duckarrays.rst b/doc/duckarrays.rst new file mode 100644 index 00000000000..ba13d5160ae --- /dev/null +++ b/doc/duckarrays.rst @@ -0,0 +1,65 @@ +.. currentmodule:: xarray + +Working with numpy-like arrays +============================== + +.. warning:: + + This feature should be considered experimental. Please report any bug you may find on + xarray’s github repository. + +Numpy-like arrays (:term:`duck array`) extend the :py:class:`numpy.ndarray` with +additional features, like propagating physical units or a different layout in memory. + +:py:class:`DataArray` and :py:class:`Dataset` objects can wrap these duck arrays, as +long as they satisfy certain conditions (see :ref:`internals.duck_arrays`). + +.. note:: + + For ``dask`` support see :ref:`dask`. 
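As a minimal sketch of what wrapping a duck array looks like in practice (assuming the optional ``sparse`` library is installed; the dimension names are only illustrative), a :py:class:`DataArray` can be built directly from a ``sparse.COO`` array, and most operations keep the wrapped type:

.. code:: python

    import numpy as np
    import sparse
    import xarray as xr

    # wrap a sparse duck array in a DataArray
    coo = sparse.COO.from_numpy(np.eye(4))
    da = xr.DataArray(coo, dims=("x", "y"))

    # arithmetic and reductions dispatch to the duck array,
    # so the result is still backed by sparse.COO
    result = (da * 2).sum(dim="x")
    print(type(result.data))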
+ + +Missing features +---------------- +Most of the API does support :term:`duck array` objects, but there are a few areas where +the code will still cast to ``numpy`` arrays: + +- dimension coordinates, and thus all indexing operations: + + * :py:meth:`Dataset.sel` and :py:meth:`DataArray.sel` + * :py:meth:`Dataset.loc` and :py:meth:`DataArray.loc` + * :py:meth:`Dataset.drop_sel` and :py:meth:`DataArray.drop_sel` + * :py:meth:`Dataset.reindex`, :py:meth:`Dataset.reindex_like`, + :py:meth:`DataArray.reindex` and :py:meth:`DataArray.reindex_like`: duck arrays in + data variables and non-dimension coordinates won't be casted + +- functions and methods that depend on external libraries or features of ``numpy`` not + covered by ``__array_function__`` / ``__array_ufunc__``: + + * :py:meth:`Dataset.ffill` and :py:meth:`DataArray.ffill` (uses ``bottleneck``) + * :py:meth:`Dataset.bfill` and :py:meth:`DataArray.bfill` (uses ``bottleneck``) + * :py:meth:`Dataset.interp`, :py:meth:`Dataset.interp_like`, + :py:meth:`DataArray.interp` and :py:meth:`DataArray.interp_like` (uses ``scipy``): + duck arrays in data variables and non-dimension coordinates will be casted in + addition to not supporting duck arrays in dimension coordinates + * :py:meth:`Dataset.rolling_exp` and :py:meth:`DataArray.rolling_exp` (uses + ``numbagg``) + * :py:meth:`Dataset.rolling` and :py:meth:`DataArray.rolling` (uses internal functions + of ``numpy``) + * :py:meth:`Dataset.interpolate_na` and :py:meth:`DataArray.interpolate_na` (uses + :py:class:`numpy.vectorize`) + * :py:func:`apply_ufunc` with ``vectorize=True`` (uses :py:class:`numpy.vectorize`) + +- incompatibilities between different :term:`duck array` libraries: + + * :py:meth:`Dataset.chunk` and :py:meth:`DataArray.chunk`: this fails if the data was + not already chunked and the :term:`duck array` (e.g. a ``pint`` quantity) should + wrap the new ``dask`` array; changing the chunk sizes works. + + +Extensions using duck arrays +---------------------------- +Here's a list of libraries extending ``xarray`` to make working with wrapped duck arrays +easier: + +- `pint-xarray `_ diff --git a/doc/examples.rst b/doc/examples.rst index 1d48d29bcc5..102138b6e4e 100644 --- a/doc/examples.rst +++ b/doc/examples.rst @@ -2,7 +2,7 @@ Examples ======== .. toctree:: - :maxdepth: 2 + :maxdepth: 1 examples/weather-data examples/monthly-means @@ -15,7 +15,7 @@ Examples Using apply_ufunc ------------------ .. toctree:: - :maxdepth: 2 + :maxdepth: 1 examples/apply_ufunc_vectorize_1d diff --git a/doc/examples/apply_ufunc_vectorize_1d.ipynb b/doc/examples/apply_ufunc_vectorize_1d.ipynb index 6d18d48fdb5..a79a4868b63 100644 --- a/doc/examples/apply_ufunc_vectorize_1d.ipynb +++ b/doc/examples/apply_ufunc_vectorize_1d.ipynb @@ -333,7 +333,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Now our function currently only works on one vector of data which is not so useful given our 3D dataset.\n", + "Now our function currently only works on one vector of data which is not so useful given our 3D dataset.\n", "Let's try passing the whole dataset. We add a `print` statement so we can see what our function receives." 
] }, diff --git a/doc/examples/area_weighted_temperature.ipynb b/doc/examples/area_weighted_temperature.ipynb index 72876e3fc29..de705966583 100644 --- a/doc/examples/area_weighted_temperature.ipynb +++ b/doc/examples/area_weighted_temperature.ipynb @@ -106,7 +106,7 @@ "source": [ "### Creating weights\n", "\n", - "For a for a rectangular grid the cosine of the latitude is proportional to the grid cell area." + "For a rectangular grid the cosine of the latitude is proportional to the grid cell area." ] }, { diff --git a/doc/faq.rst b/doc/faq.rst index 576cec5c2b1..a2b8be47e06 100644 --- a/doc/faq.rst +++ b/doc/faq.rst @@ -4,11 +4,12 @@ Frequently Asked Questions ========================== .. ipython:: python - :suppress: + :suppress: import numpy as np import pandas as pd import xarray as xr + np.random.seed(123456) @@ -103,21 +104,21 @@ code fragment .. ipython:: python arr = xr.DataArray([1, 2, 3]) - pd.Series({'x': arr[0], 'mean': arr.mean(), 'std': arr.std()}) + pd.Series({"x": arr[0], "mean": arr.mean(), "std": arr.std()}) does not yield the pandas DataFrame we expected. We need to specify the type conversion ourselves: .. ipython:: python - pd.Series({'x': arr[0], 'mean': arr.mean(), 'std': arr.std()}, dtype=float) + pd.Series({"x": arr[0], "mean": arr.mean(), "std": arr.std()}, dtype=float) Alternatively, we could use the ``item`` method or the ``float`` constructor to convert values one at a time .. ipython:: python - pd.Series({'x': arr[0].item(), 'mean': float(arr.mean())}) + pd.Series({"x": arr[0].item(), "mean": float(arr.mean())}) .. _approach to metadata: diff --git a/doc/gallery/README.txt b/doc/gallery/README.txt index b17f803696b..63f7d477cf4 100644 --- a/doc/gallery/README.txt +++ b/doc/gallery/README.txt @@ -2,4 +2,3 @@ Gallery ======= - diff --git a/doc/groupby.rst b/doc/groupby.rst index 223185bd0d5..d0c0b1849f9 100644 --- a/doc/groupby.rst +++ b/doc/groupby.rst @@ -26,11 +26,12 @@ Split Let's create a simple example dataset: .. ipython:: python - :suppress: + :suppress: import numpy as np import pandas as pd import xarray as xr + np.random.seed(123456) .. ipython:: python @@ -47,20 +48,20 @@ use a DataArray directly), we get back a ``GroupBy`` object: .. ipython:: python - ds.groupby('letters') + ds.groupby("letters") This object works very similarly to a pandas GroupBy object. You can view the group indices with the ``groups`` attribute: .. ipython:: python - ds.groupby('letters').groups + ds.groupby("letters").groups You can also iterate over groups in ``(label, group)`` pairs: .. ipython:: python - list(ds.groupby('letters')) + list(ds.groupby("letters")) Just like in pandas, creating a GroupBy object is cheap: it does not actually split the data until you access particular values. @@ -75,8 +76,8 @@ a customized coordinate, but xarray facilitates this via the .. ipython:: python - x_bins = [0,25,50] - ds.groupby_bins('x', x_bins).groups + x_bins = [0, 25, 50] + ds.groupby_bins("x", x_bins).groups The binning is implemented via :func:`pandas.cut`, whose documentation details how the bins are assigned. As seen in the example above, by default, the bins are @@ -86,8 +87,8 @@ choose `float` labels which identify the bin centers: .. 
ipython:: python - x_bin_labels = [12.5,37.5] - ds.groupby_bins('x', x_bins, labels=x_bin_labels).groups + x_bin_labels = [12.5, 37.5] + ds.groupby_bins("x", x_bins, labels=x_bin_labels).groups Apply @@ -102,7 +103,8 @@ concatenated back together along the group axis: def standardize(x): return (x - x.mean()) / x.std() - arr.groupby('letters').map(standardize) + + arr.groupby("letters").map(standardize) GroupBy objects also have a :py:meth:`~xarray.core.groupby.DatasetGroupBy.reduce` method and methods like :py:meth:`~xarray.core.groupby.DatasetGroupBy.mean` as shortcuts for applying an @@ -110,19 +112,19 @@ aggregation function: .. ipython:: python - arr.groupby('letters').mean(dim='x') + arr.groupby("letters").mean(dim="x") Using a groupby is thus also a convenient shortcut for aggregating over all dimensions *other than* the provided one: .. ipython:: python - ds.groupby('x').std(...) + ds.groupby("x").std(...) .. note:: We use an ellipsis (`...`) here to indicate we want to reduce over all - other dimensions + other dimensions First and last @@ -134,7 +136,7 @@ values for group along the grouped dimension: .. ipython:: python - ds.groupby('letters').first(...) + ds.groupby("letters").first(...) By default, they skip missing values (control this with ``skipna``). @@ -149,9 +151,9 @@ coordinates. For example: .. ipython:: python - alt = arr.groupby('letters').mean(...) + alt = arr.groupby("letters").mean(...) alt - ds.groupby('letters') - alt + ds.groupby("letters") - alt This last line is roughly equivalent to the following:: @@ -169,11 +171,11 @@ the ``squeeze`` parameter: .. ipython:: python - next(iter(arr.groupby('x'))) + next(iter(arr.groupby("x"))) .. ipython:: python - next(iter(arr.groupby('x', squeeze=False))) + next(iter(arr.groupby("x", squeeze=False))) Although xarray will attempt to automatically :py:attr:`~xarray.DataArray.transpose` dimensions back into their original order @@ -197,13 +199,17 @@ __ http://cfconventions.org/cf-conventions/v1.6.0/cf-conventions.html#_two_dimen .. ipython:: python - da = xr.DataArray([[0,1],[2,3]], - coords={'lon': (['ny','nx'], [[30,40],[40,50]] ), - 'lat': (['ny','nx'], [[10,10],[20,20]] ),}, - dims=['ny','nx']) + da = xr.DataArray( + [[0, 1], [2, 3]], + coords={ + "lon": (["ny", "nx"], [[30, 40], [40, 50]]), + "lat": (["ny", "nx"], [[10, 10], [20, 20]]), + }, + dims=["ny", "nx"], + ) da - da.groupby('lon').sum(...) - da.groupby('lon').map(lambda x: x - x.mean(), shortcut=False) + da.groupby("lon").sum(...) + da.groupby("lon").map(lambda x: x - x.mean(), shortcut=False) Because multidimensional groups have the ability to generate a very large number of bins, coarse-binning via :py:meth:`~xarray.Dataset.groupby_bins` @@ -211,13 +217,13 @@ may be desirable: .. ipython:: python - da.groupby_bins('lon', [0,45,50]).sum() + da.groupby_bins("lon", [0, 45, 50]).sum() These methods group by `lon` values. It is also possible to groupby each -cell in a grid, regardless of value, by stacking multiple dimensions, +cell in a grid, regardless of value, by stacking multiple dimensions, applying your function, and then unstacking the result: .. ipython:: python - stacked = da.stack(gridcell=['ny', 'nx']) - stacked.groupby('gridcell').sum(...).unstack('gridcell') + stacked = da.stack(gridcell=["ny", "nx"]) + stacked.groupby("gridcell").sum(...).unstack("gridcell") diff --git a/doc/howdoi.rst b/doc/howdoi.rst index 84c0c786027..3604d66bd0c 100644 --- a/doc/howdoi.rst +++ b/doc/howdoi.rst @@ -59,4 +59,3 @@ How do I ... 
- ``obj.dt.ceil``, ``obj.dt.floor``, ``obj.dt.round``. See :ref:`dt_accessor` for more. * - make a mask that is ``True`` where an object contains any of the values in a array - :py:meth:`Dataset.isin`, :py:meth:`DataArray.isin` - diff --git a/doc/index.rst b/doc/index.rst index 972eb0a732e..ee44d0ad4d9 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -60,6 +60,7 @@ Documentation * :doc:`io` * :doc:`dask` * :doc:`plotting` +* :doc:`duckarrays` .. toctree:: :maxdepth: 1 @@ -80,6 +81,7 @@ Documentation io dask plotting + duckarrays **Help & reference** @@ -107,6 +109,7 @@ Documentation See also -------- +- `Xarray's Tutorial`_ presented at the 2020 SciPy Conference (`video recording`_). - Stephan Hoyer and Joe Hamman's `Journal of Open Research Software paper`_ describing the xarray project. - The `UW eScience Institute's Geohackweek`_ tutorial on xarray for geospatial data scientists. - Stephan Hoyer's `SciPy2015 talk`_ introducing xarray to a general audience. @@ -114,6 +117,8 @@ See also xarray to users familiar with netCDF. - `Nicolas Fauchereau's tutorial`_ on xarray for netCDF users. +.. _Xarray's Tutorial: https://xarray-contrib.github.io/xarray-tutorial/ +.. _video recording: https://youtu.be/mecN-Ph_-78 .. _Journal of Open Research Software paper: http://doi.org/10.5334/jors.148 .. _UW eScience Institute's Geohackweek : https://geohackweek.github.io/nDarrays/ .. _SciPy2015 talk: https://www.youtube.com/watch?v=X0pAhJgySxk diff --git a/doc/indexing.rst b/doc/indexing.rst index cfbb84a8343..78766b8fd81 100644 --- a/doc/indexing.rst +++ b/doc/indexing.rst @@ -4,11 +4,12 @@ Indexing and selecting data =========================== .. ipython:: python - :suppress: + :suppress: import numpy as np import pandas as pd import xarray as xr + np.random.seed(123456) xarray offers extremely flexible indexing routines that combine the best @@ -60,9 +61,13 @@ DataArray: .. ipython:: python - da = xr.DataArray(np.random.rand(4, 3), - [('time', pd.date_range('2000-01-01', periods=4)), - ('space', ['IA', 'IL', 'IN'])]) + da = xr.DataArray( + np.random.rand(4, 3), + [ + ("time", pd.date_range("2000-01-01", periods=4)), + ("space", ["IA", "IL", "IN"]), + ], + ) da[:2] da[0, 0] da[:, [2, 1]] @@ -81,7 +86,7 @@ fast. To do label based indexing, use the :py:attr:`~xarray.DataArray.loc` attri .. ipython:: python - da.loc['2000-01-01':'2000-01-02', 'IA'] + da.loc["2000-01-01":"2000-01-02", "IA"] In this example, the selected is a subpart of the array in the range '2000-01-01':'2000-01-02' along the first coordinate `time` @@ -98,7 +103,7 @@ Setting values with label based indexing is also supported: .. ipython:: python - da.loc['2000-01-01', ['IL', 'IN']] = -10 + da.loc["2000-01-01", ["IL", "IN"]] = -10 da @@ -117,7 +122,7 @@ use them explicitly to slice data. There are two ways to do this: da[dict(space=0, time=slice(None, 2))] # index by dimension coordinate labels - da.loc[dict(time=slice('2000-01-01', '2000-01-02'))] + da.loc[dict(time=slice("2000-01-01", "2000-01-02"))] 2. Use the :py:meth:`~xarray.DataArray.sel` and :py:meth:`~xarray.DataArray.isel` convenience methods: @@ -128,7 +133,7 @@ use them explicitly to slice data. 
There are two ways to do this: da.isel(space=0, time=slice(None, 2)) # index by dimension coordinate labels - da.sel(time=slice('2000-01-01', '2000-01-02')) + da.sel(time=slice("2000-01-01", "2000-01-02")) The arguments to these methods can be any objects that could index the array along the dimension given by the keyword, e.g., labels for an individual value, @@ -156,16 +161,16 @@ enabling nearest neighbor (inexact) lookups by use of the methods ``'pad'``, .. ipython:: python - da = xr.DataArray([1, 2, 3], [('x', [0, 1, 2])]) - da.sel(x=[1.1, 1.9], method='nearest') - da.sel(x=0.1, method='backfill') - da.reindex(x=[0.5, 1, 1.5, 2, 2.5], method='pad') + da = xr.DataArray([1, 2, 3], [("x", [0, 1, 2])]) + da.sel(x=[1.1, 1.9], method="nearest") + da.sel(x=0.1, method="backfill") + da.reindex(x=[0.5, 1, 1.5, 2, 2.5], method="pad") Tolerance limits the maximum distance for valid matches with an inexact lookup: .. ipython:: python - da.reindex(x=[1.1, 1.5], method='nearest', tolerance=0.2) + da.reindex(x=[1.1, 1.5], method="nearest", tolerance=0.2) The method parameter is not yet supported if any of the arguments to ``.sel()`` is a ``slice`` object: @@ -173,7 +178,7 @@ to ``.sel()`` is a ``slice`` object: .. ipython:: :verbatim: - In [1]: da.sel(x=slice(1, 3), method='nearest') + In [1]: da.sel(x=slice(1, 3), method="nearest") NotImplementedError However, you don't need to use ``method`` to do inexact slicing. Slicing @@ -182,15 +187,15 @@ labels are monotonic increasing: .. ipython:: python - da.sel(x=slice(0.9, 3.1)) + da.sel(x=slice(0.9, 3.1)) Indexing axes with monotonic decreasing labels also works, as long as the ``slice`` or ``.loc`` arguments are also decreasing: .. ipython:: python - reversed_da = da[::-1] - reversed_da.loc[3.1:0.9] + reversed_da = da[::-1] + reversed_da.loc[3.1:0.9] .. note:: @@ -227,7 +232,7 @@ arrays). However, you can do normal indexing with dimension names: .. ipython:: python ds[dict(space=[0], time=[0])] - ds.loc[dict(time='2000-01-01')] + ds.loc[dict(time="2000-01-01")] Using indexing to *assign* values to a subset of dataset (e.g., ``ds[dict(space=0)] = 1``) is not yet supported. @@ -240,7 +245,7 @@ index labels along a dimension dropped: .. ipython:: python - ds.drop_sel(space=['IN', 'IL']) + ds.drop_sel(space=["IN", "IL"]) ``drop_sel`` is both a ``Dataset`` and ``DataArray`` method. @@ -249,7 +254,7 @@ Any variables with these dimensions are also dropped: .. ipython:: python - ds.drop_dims('time') + ds.drop_dims("time") .. _masking with where: @@ -263,7 +268,7 @@ xarray, use :py:meth:`~xarray.DataArray.where`: .. ipython:: python - da = xr.DataArray(np.arange(16).reshape(4, 4), dims=['x', 'y']) + da = xr.DataArray(np.arange(16).reshape(4, 4), dims=["x", "y"]) da.where(da.x + da.y < 4) This is particularly useful for ragged indexing of multi-dimensional data, @@ -296,7 +301,7 @@ multiple values, use :py:meth:`~xarray.DataArray.isin`: .. ipython:: python - da = xr.DataArray([1, 2, 3, 4, 5], dims=['x']) + da = xr.DataArray([1, 2, 3, 4, 5], dims=["x"]) da.isin([2, 4]) :py:meth:`~xarray.DataArray.isin` works particularly well with @@ -305,7 +310,7 @@ already labels of an array: .. ipython:: python - lookup = xr.DataArray([-1, -2, -3, -4, -5], dims=['x']) + lookup = xr.DataArray([-1, -2, -3, -4, -5], dims=["x"]) da.where(lookup.isin([-2, -4]), drop=True) However, some caution is in order: when done repeatedly, this type of indexing @@ -328,14 +333,13 @@ MATLAB, or after using the :py:func:`numpy.ix_` helper: .. 
ipython:: python - da = xr.DataArray( np.arange(12).reshape((3, 4)), dims=["x", "y"], coords={"x": [0, 1, 2], "y": ["a", "b", "c", "d"]}, ) da - da[[0, 1], [1, 1]] + da[[0, 2, 2], [1, 3]] For more flexibility, you can supply :py:meth:`~xarray.DataArray` objects as indexers. @@ -344,8 +348,8 @@ dimensions: .. ipython:: python - ind_x = xr.DataArray([0, 1], dims=['x']) - ind_y = xr.DataArray([0, 1], dims=['y']) + ind_x = xr.DataArray([0, 1], dims=["x"]) + ind_y = xr.DataArray([0, 1], dims=["y"]) da[ind_x, ind_y] # orthogonal indexing da[ind_x, ind_x] # vectorized indexing @@ -364,7 +368,7 @@ indexers' dimension: .. ipython:: python - ind = xr.DataArray([[0, 1], [0, 1]], dims=['a', 'b']) + ind = xr.DataArray([[0, 1], [0, 1]], dims=["a", "b"]) da[ind] Similar to how NumPy's `advanced indexing`_ works, vectorized @@ -378,18 +382,18 @@ Vectorized indexing also works with ``isel``, ``loc``, and ``sel``: .. ipython:: python - ind = xr.DataArray([[0, 1], [0, 1]], dims=['a', 'b']) + ind = xr.DataArray([[0, 1], [0, 1]], dims=["a", "b"]) da.isel(y=ind) # same as da[:, ind] - ind = xr.DataArray([['a', 'b'], ['b', 'a']], dims=['a', 'b']) + ind = xr.DataArray([["a", "b"], ["b", "a"]], dims=["a", "b"]) da.loc[:, ind] # same as da.sel(y=ind) These methods may also be applied to ``Dataset`` objects .. ipython:: python - ds = da.to_dataset(name='bar') - ds.isel(x=xr.DataArray([0, 1, 2], dims=['points'])) + ds = da.to_dataset(name="bar") + ds.isel(x=xr.DataArray([0, 1, 2], dims=["points"])) .. tip:: @@ -476,8 +480,8 @@ Like ``numpy.ndarray``, value assignment sometimes works differently from what o .. ipython:: python - da = xr.DataArray([0, 1, 2, 3], dims=['x']) - ind = xr.DataArray([0, 0, 0], dims=['x']) + da = xr.DataArray([0, 1, 2, 3], dims=["x"]) + ind = xr.DataArray([0, 0, 0], dims=["x"]) da[ind] -= 1 da @@ -511,7 +515,7 @@ __ https://docs.scipy.org/doc/numpy/user/basics.indexing.html#assigning-values-t .. ipython:: python - da = xr.DataArray([0, 1, 2, 3], dims=['x']) + da = xr.DataArray([0, 1, 2, 3], dims=["x"]) # DO NOT do this da.isel(x=[0, 1, 2])[1] = -1 da @@ -544,7 +548,7 @@ you can supply a :py:class:`~xarray.DataArray` with a coordinate, x=xr.DataArray([0, 1, 6], dims="z", coords={"z": ["a", "b", "c"]}), y=xr.DataArray([0, 1, 0], dims="z"), ) - + Analogously, label-based pointwise-indexing is also possible by the ``.sel`` method: @@ -581,15 +585,15 @@ To reindex a particular dimension, use :py:meth:`~xarray.DataArray.reindex`: .. ipython:: python - da.reindex(space=['IA', 'CA']) + da.reindex(space=["IA", "CA"]) The :py:meth:`~xarray.DataArray.reindex_like` method is a useful shortcut. To demonstrate, we will make a subset DataArray with new values: .. ipython:: python - foo = da.rename('foo') - baz = (10 * da[:2, :2]).rename('baz') + foo = da.rename("foo") + baz = (10 * da[:2, :2]).rename("baz") baz Reindexing ``foo`` with ``baz`` selects out the first two values along each @@ -611,8 +615,8 @@ The :py:func:`~xarray.align` function lets us perform more flexible database-lik .. 
ipython:: python - xr.align(foo, baz, join='inner') - xr.align(foo, baz, join='outer') + xr.align(foo, baz, join="inner") + xr.align(foo, baz, join="outer") Both ``reindex_like`` and ``align`` work interchangeably between :py:class:`~xarray.DataArray` and :py:class:`~xarray.Dataset` objects, and with any number of matching dimension names: @@ -621,7 +625,7 @@ Both ``reindex_like`` and ``align`` work interchangeably between ds ds.reindex_like(baz) - other = xr.DataArray(['a', 'b', 'c'], dims='other') + other = xr.DataArray(["a", "b", "c"], dims="other") # this is a no-op, because there are no shared dimension names ds.reindex_like(other) @@ -636,7 +640,7 @@ integer-based indexing as a fallback for dimensions without a coordinate label: .. ipython:: python - da = xr.DataArray([1, 2, 3], dims='x') + da = xr.DataArray([1, 2, 3], dims="x") da.sel(x=[0, -1]) Alignment between xarray objects where one or both do not have coordinate labels @@ -675,9 +679,9 @@ labels: .. ipython:: python - da = xr.DataArray([1, 2, 3], dims='x') + da = xr.DataArray([1, 2, 3], dims="x") da - da.get_index('x') + da.get_index("x") .. _copies_vs_views: @@ -721,7 +725,6 @@ pandas: .. ipython:: python - midx = pd.MultiIndex.from_product([list("abc"), [0, 1]], names=("one", "two")) mda = xr.DataArray(np.random.rand(6, 3), [("x", midx), ("y", range(3))]) mda @@ -732,20 +735,20 @@ a slice of tuples: .. ipython:: python - mda.sel(x=[('a', 0), ('b', 1)]) + mda.sel(x=[("a", 0), ("b", 1)]) Additionally, xarray supports dictionaries: .. ipython:: python - mda.sel(x={'one': 'a', 'two': 0}) + mda.sel(x={"one": "a", "two": 0}) For convenience, ``sel`` also accepts multi-index levels directly as keyword arguments: .. ipython:: python - mda.sel(one='a', two=0) + mda.sel(one="a", two=0) Note that using ``sel`` it is not possible to mix a dimension indexer with level indexers for that dimension @@ -757,7 +760,7 @@ multi-index is reduced to a single index. .. ipython:: python - mda.loc[{'one': 'a'}, ...] + mda.loc[{"one": "a"}, ...] Unlike pandas, xarray does not guess whether you provide index levels or dimensions when using ``loc`` in some ambiguous cases. For example, for diff --git a/doc/installing.rst b/doc/installing.rst index a25bf65e342..99b8b621aed 100644 --- a/doc/installing.rst +++ b/doc/installing.rst @@ -6,8 +6,8 @@ Installation Required dependencies --------------------- -- Python (3.6 or later) -- setuptools +- Python (3.7 or later) +- setuptools (40.4 or later) - `numpy `__ (1.15 or later) - `pandas `__ (0.25 or later) @@ -16,6 +16,12 @@ Required dependencies Optional dependencies --------------------- +.. note:: + + If you are using pip to install xarray, optional dependencies can be installed by + specifying *extras*. :ref:`installation-instructions` for both pip and conda + are given below. + For netCDF and IO ~~~~~~~~~~~~~~~~~ @@ -25,8 +31,9 @@ For netCDF and IO - `pydap `__: used as a fallback for accessing OPeNDAP - `h5netcdf `__: an alternative library for reading and writing netCDF4 files that does not use the netCDF-C libraries -- `pynio `__: for reading GRIB and other - geoscience specific file formats. Note that pynio is not available for Windows. +- `PyNIO `__: for reading GRIB and other + geoscience specific file formats. Note that PyNIO is not available for Windows and + that the PyNIO backend may be moved outside of xarray in the future. - `zarr `__: for chunked, compressed, N-dimensional arrays. 
- `cftime `__: recommended if you want to encode/decode datetimes for non-standard calendars or dates before @@ -93,16 +100,16 @@ dependencies: - **Python:** 42 months (`NEP-29 `_) +- **setuptools:** 42 months (but no older than 40.4) - **numpy:** 24 months (`NEP-29 `_) -- **pandas:** 12 months -- **scipy:** 12 months +- **dask and dask.distributed:** 12 months (but no older than 2.9) - **sparse, pint** and other libraries that rely on `NEP-18 `_ for integration: very latest available versions only, until the technology will have matured. This extends to dask when used in conjunction with any of these libraries. numpy >=1.17. -- **all other libraries:** 6 months +- **all other libraries:** 12 months The above should be interpreted as *the minor version (X.Y) initially published no more than N months ago*. Patch versions (x.y.Z) are not pinned, and only the latest available @@ -111,10 +118,11 @@ at the moment of publishing the xarray release is guaranteed to work. You can see the actual minimum tested versions: - `For NEP-18 libraries - `_ + `_ - `For everything else - `_ + `_ +.. _installation-instructions: Instructions ------------ @@ -138,6 +146,26 @@ pandas) installed first. Then, install xarray with pip:: $ pip install xarray +We also maintain other dependency sets for different subsets of functionality:: + + $ pip install "xarray[io]" # Install optional dependencies for handling I/O + $ pip install "xarray[accel]" # Install optional dependencies for accelerating xarray + $ pip install "xarray[parallel]" # Install optional dependencies for dask arrays + $ pip install "xarray[viz]" # Install optional dependencies for visualization + $ pip install "xarray[complete]" # Install all the above + +The above commands should install most of the `optional dependencies`_. However, +some packages which are either not listed on PyPI or require extra +installation steps are excluded. To know which dependencies would be +installed, take a look at the ``[options.extras_require]`` section in +``setup.cfg``: + +.. literalinclude:: ../setup.cfg + :language: ini + :start-at: [options.extras_require] + :end-before: [options.package_data] + + Testing ------- diff --git a/doc/internals.rst b/doc/internals.rst index a4870f2316a..60d32128c60 100644 --- a/doc/internals.rst +++ b/doc/internals.rst @@ -42,15 +42,49 @@ xarray objects via the (readonly) :py:attr:`Dataset.variables ` and :py:attr:`DataArray.variable ` attributes. + +.. _internals.duck_arrays: + +Integrating with duck arrays +---------------------------- + +.. warning:: + + This is a experimental feature. + +xarray can wrap custom :term:`duck array` objects as long as they define numpy's +``shape``, ``dtype`` and ``ndim`` properties and the ``__array__``, +``__array_ufunc__`` and ``__array_function__`` methods. + +In certain situations (e.g. when printing the collapsed preview of +variables of a ``Dataset``), xarray will display the repr of a :term:`duck array` +in a single line, truncating it to a certain number of characters. If that +would drop too much information, the :term:`duck array` may define a +``_repr_inline_`` method that takes ``max_width`` (number of characters) as an +argument: + +.. code:: python + + class MyDuckArray: + ... + + def _repr_inline_(self, max_width): + """ format to a single line with at most max_width characters """ + ... + + ... + + Extending xarray ---------------- .. 
ipython:: python - :suppress: + :suppress: import numpy as np import pandas as pd import xarray as xr + np.random.seed(123456) xarray is designed as a general purpose library, and hence tries to avoid @@ -82,16 +116,21 @@ xarray: .. literalinclude:: examples/_code/accessor_example.py +In general, the only restriction on the accessor class is that the ``__init__`` method +must have a single parameter: the ``Dataset`` or ``DataArray`` object it is supposed +to work on. + This achieves the same result as if the ``Dataset`` class had a cached property defined that returns an instance of your class: .. code-block:: python - class Dataset: - ... - @property - def geo(self) - return GeoAccessor(self) + class Dataset: + ... + + @property + def geo(self): + return GeoAccessor(self) However, using the register accessor decorators is preferable to simply adding your own ad-hoc property (i.e., ``Dataset.geo = property(...)``), for several @@ -116,14 +155,13 @@ reasons: Back in an interactive IPython session, we can use these properties: .. ipython:: python - :suppress: + :suppress: - exec(open("examples/_code/accessor_example.py").read()) + exec(open("examples/_code/accessor_example.py").read()) .. ipython:: python - ds = xr.Dataset({'longitude': np.linspace(0, 10), - 'latitude': np.linspace(0, 20)}) + ds = xr.Dataset({"longitude": np.linspace(0, 10), "latitude": np.linspace(0, 20)}) ds.geo.center ds.geo.plot() @@ -137,3 +175,59 @@ To help users keep things straight, please `let us know `_ if you plan to write a new accessor for an open source library. In the future, we will maintain a list of accessors and the libraries that implement them on this page. + +To make documenting accessors with ``sphinx`` and ``sphinx.ext.autosummary`` +easier, you can use `sphinx-autosummary-accessors`_. + +.. _sphinx-autosummary-accessors: https://sphinx-autosummary-accessors.readthedocs.io/ + +.. _zarr_encoding: + +Zarr Encoding Specification +--------------------------- + +In implementing support for the `Zarr `_ storage +format, Xarray developers made some *ad hoc* choices about how to store +NetCDF data in Zarr. +Future versions of the Zarr spec will likely include a more formal convention +for the storage of the NetCDF data model in Zarr; see +`Zarr spec repo `_ for ongoing +discussion. + +First, Xarray can only read and write Zarr groups. There is currently no support +for reading / writting individual Zarr arrays. Zarr groups are mapped to +Xarray ``Dataset`` objects. + +Second, from Xarray's point of view, the key difference between +NetCDF and Zarr is that all NetCDF arrays have *dimension names* while Zarr +arrays do not. Therefore, in order to store NetCDF data in Zarr, Xarray must +somehow encode and decode the name of each array's dimensions. + +To accomplish this, Xarray developers decided to define a special Zarr array +attribute: ``_ARRAY_DIMENSIONS``. The value of this attribute is a list of +dimension names (strings), for example ``["time", "lon", "lat"]``. When writing +data to Zarr, Xarray sets this attribute on all variables based on the variable +dimensions. When reading a Zarr group, Xarray looks for this attribute on all +arrays, raising an error if it can't be found. The attribute is used to define +the variable dimension names and then removed from the attributes dictionary +returned to the user. + +Because of these choices, Xarray cannot read arbitrary array data, but only +Zarr data with valid ``_ARRAY_DIMENSIONS`` attributes on each array. 
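As an illustrative sketch of the convention just described (the store path, array name and dimension names here are placeholders), a raw Zarr group only becomes readable by xarray once each array carries the ``_ARRAY_DIMENSIONS`` attribute:

.. code:: python

    import numpy as np
    import zarr
    import xarray as xr

    # create a Zarr group and array directly, without going through xarray
    group = zarr.open_group("raw-example.zarr", mode="w")
    tair = group.create_dataset("tair", data=np.random.rand(4, 3), chunks=(2, 3))

    # xarray raises an error on open unless this attribute is present
    tair.attrs["_ARRAY_DIMENSIONS"] = ["time", "space"]

    ds = xr.open_zarr("raw-example.zarr")
    print(ds["tair"].dims)  # ('time', 'space')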
+ +After decoding the ``_ARRAY_DIMENSIONS`` attribute and assigning the variable +dimensions, Xarray proceeds to [optionally] decode each variable using its +standard CF decoding machinery used for NetCDF data (see :py:func:`decode_cf`). + +As a concrete example, here we write a tutorial dataset to Zarr and then +re-open it directly with Zarr: + +.. ipython:: python + + ds = xr.tutorial.load_dataset("rasm") + ds.to_zarr("rasm.zarr", mode="w") + import zarr + + zgroup = zarr.open("rasm.zarr") + print(zgroup.tree()) + dict(zgroup["Tair"].attrs) diff --git a/doc/interpolation.rst b/doc/interpolation.rst index 4cf39807e5a..9a3b7a7ee2d 100644 --- a/doc/interpolation.rst +++ b/doc/interpolation.rst @@ -4,11 +4,12 @@ Interpolating data ================== .. ipython:: python - :suppress: + :suppress: import numpy as np import pandas as pd import xarray as xr + np.random.seed(123456) xarray offers flexible interpolation routines, which have a similar interface @@ -27,9 +28,10 @@ indexing of a :py:class:`~xarray.DataArray`, .. ipython:: python - da = xr.DataArray(np.sin(0.3 * np.arange(12).reshape(4, 3)), - [('time', np.arange(4)), - ('space', [0.1, 0.2, 0.3])]) + da = xr.DataArray( + np.sin(0.3 * np.arange(12).reshape(4, 3)), + [("time", np.arange(4)), ("space", [0.1, 0.2, 0.3])], + ) # label lookup da.sel(time=3) @@ -52,20 +54,21 @@ To interpolate data with a :py:doc:`numpy.datetime64 .. ipython:: python - da_dt64 = xr.DataArray([1, 3], - [('time', pd.date_range('1/1/2000', '1/3/2000', periods=2))]) - da_dt64.interp(time='2000-01-02') + da_dt64 = xr.DataArray( + [1, 3], [("time", pd.date_range("1/1/2000", "1/3/2000", periods=2))] + ) + da_dt64.interp(time="2000-01-02") The interpolated data can be merged into the original :py:class:`~xarray.DataArray` by specifying the time periods required. .. ipython:: python - da_dt64.interp(time=pd.date_range('1/1/2000', '1/3/2000', periods=3)) + da_dt64.interp(time=pd.date_range("1/1/2000", "1/3/2000", periods=3)) Interpolation of data indexed by a :py:class:`~xarray.CFTimeIndex` is also allowed. See :ref:`CFTimeIndex` for examples. - + .. note:: Currently, our interpolation only works for regular grids. @@ -108,9 +111,10 @@ different coordinates, .. ipython:: python - other = xr.DataArray(np.sin(0.4 * np.arange(9).reshape(3, 3)), - [('time', [0.9, 1.9, 2.9]), - ('space', [0.15, 0.25, 0.35])]) + other = xr.DataArray( + np.sin(0.4 * np.arange(9).reshape(3, 3)), + [("time", [0.9, 1.9, 2.9]), ("space", [0.15, 0.25, 0.35])], + ) it might be a good idea to first interpolate ``da`` so that it will stay on the same coordinates of ``other``, and then subtract it. @@ -118,9 +122,9 @@ same coordinates of ``other``, and then subtract it. .. ipython:: python - # interpolate da along other's coordinates - interpolated = da.interp_like(other) - interpolated + # interpolate da along other's coordinates + interpolated = da.interp_like(other) + interpolated It is now possible to safely compute the difference ``other - interpolated``. @@ -135,12 +139,15 @@ The interpolation method can be specified by the optional ``method`` argument. .. 
ipython:: python - da = xr.DataArray(np.sin(np.linspace(0, 2 * np.pi, 10)), dims='x', - coords={'x': np.linspace(0, 1, 10)}) + da = xr.DataArray( + np.sin(np.linspace(0, 2 * np.pi, 10)), + dims="x", + coords={"x": np.linspace(0, 1, 10)}, + ) - da.plot.line('o', label='original') - da.interp(x=np.linspace(0, 1, 100)).plot.line(label='linear (default)') - da.interp(x=np.linspace(0, 1, 100), method='cubic').plot.line(label='cubic') + da.plot.line("o", label="original") + da.interp(x=np.linspace(0, 1, 100)).plot.line(label="linear (default)") + da.interp(x=np.linspace(0, 1, 100), method="cubic").plot.line(label="cubic") @savefig interpolation_sample1.png width=4in plt.legend() @@ -149,15 +156,16 @@ Additional keyword arguments can be passed to scipy's functions. .. ipython:: python # fill 0 for the outside of the original coordinates. - da.interp(x=np.linspace(-0.5, 1.5, 10), kwargs={'fill_value': 0.0}) + da.interp(x=np.linspace(-0.5, 1.5, 10), kwargs={"fill_value": 0.0}) # 1-dimensional extrapolation - da.interp(x=np.linspace(-0.5, 1.5, 10), kwargs={'fill_value': 'extrapolate'}) + da.interp(x=np.linspace(-0.5, 1.5, 10), kwargs={"fill_value": "extrapolate"}) # multi-dimensional extrapolation - da = xr.DataArray(np.sin(0.3 * np.arange(12).reshape(4, 3)), - [('time', np.arange(4)), - ('space', [0.1, 0.2, 0.3])]) + da = xr.DataArray( + np.sin(0.3 * np.arange(12).reshape(4, 3)), + [("time", np.arange(4)), ("space", [0.1, 0.2, 0.3])], + ) - da.interp(time=4, space=np.linspace(-0.1, 0.5, 10), kwargs={'fill_value': None}) + da.interp(time=4, space=np.linspace(-0.1, 0.5, 10), kwargs={"fill_value": None}) Advanced Interpolation @@ -181,17 +189,18 @@ For example: .. ipython:: python - da = xr.DataArray(np.sin(0.3 * np.arange(20).reshape(5, 4)), - [('x', np.arange(5)), - ('y', [0.1, 0.2, 0.3, 0.4])]) + da = xr.DataArray( + np.sin(0.3 * np.arange(20).reshape(5, 4)), + [("x", np.arange(5)), ("y", [0.1, 0.2, 0.3, 0.4])], + ) # advanced indexing - x = xr.DataArray([0, 2, 4], dims='z') - y = xr.DataArray([0.1, 0.2, 0.3], dims='z') + x = xr.DataArray([0, 2, 4], dims="z") + y = xr.DataArray([0.1, 0.2, 0.3], dims="z") da.sel(x=x, y=y) # advanced interpolation - x = xr.DataArray([0.5, 1.5, 2.5], dims='z') - y = xr.DataArray([0.15, 0.25, 0.35], dims='z') + x = xr.DataArray([0.5, 1.5, 2.5], dims="z") + y = xr.DataArray([0.15, 0.25, 0.35], dims="z") da.interp(x=x, y=y) where values on the original coordinates @@ -203,9 +212,8 @@ If you want to add a coordinate to the new dimension ``z``, you can supply .. ipython:: python - x = xr.DataArray([0.5, 1.5, 2.5], dims='z', coords={'z': ['a', 'b','c']}) - y = xr.DataArray([0.15, 0.25, 0.35], dims='z', - coords={'z': ['a', 'b','c']}) + x = xr.DataArray([0.5, 1.5, 2.5], dims="z", coords={"z": ["a", "b", "c"]}) + y = xr.DataArray([0.15, 0.25, 0.35], dims="z", coords={"z": ["a", "b", "c"]}) da.interp(x=x, y=y) For the details of the advanced indexing, @@ -224,19 +232,18 @@ while other methods such as ``cubic`` or ``quadratic`` return all NaN arrays. .. ipython:: python - da = xr.DataArray([0, 2, np.nan, 3, 3.25], dims='x', - coords={'x': range(5)}) + da = xr.DataArray([0, 2, np.nan, 3, 3.25], dims="x", coords={"x": range(5)}) da.interp(x=[0.5, 1.5, 2.5]) - da.interp(x=[0.5, 1.5, 2.5], method='cubic') + da.interp(x=[0.5, 1.5, 2.5], method="cubic") To avoid this, you can drop NaN by :py:meth:`~xarray.DataArray.dropna`, and then make the interpolation .. 
ipython:: python - dropped = da.dropna('x') + dropped = da.dropna("x") dropped - dropped.interp(x=[0.5, 1.5, 2.5], method='cubic') + dropped.interp(x=[0.5, 1.5, 2.5], method="cubic") If NaNs are distributed randomly in your multidimensional array, dropping all the columns containing more than one NaNs by @@ -246,7 +253,7 @@ which is similar to :py:meth:`pandas.Series.interpolate`. .. ipython:: python - filled = da.interpolate_na(dim='x') + filled = da.interpolate_na(dim="x") filled This fills NaN by interpolating along the specified dimension. @@ -254,7 +261,7 @@ After filling NaNs, you can interpolate: .. ipython:: python - filled.interp(x=[0.5, 1.5, 2.5], method='cubic') + filled.interp(x=[0.5, 1.5, 2.5], method="cubic") For the details of :py:meth:`~xarray.DataArray.interpolate_na`, see :ref:`Missing values `. @@ -268,18 +275,18 @@ Let's see how :py:meth:`~xarray.DataArray.interp` works on real data. .. ipython:: python # Raw data - ds = xr.tutorial.open_dataset('air_temperature').isel(time=0) + ds = xr.tutorial.open_dataset("air_temperature").isel(time=0) fig, axes = plt.subplots(ncols=2, figsize=(10, 4)) ds.air.plot(ax=axes[0]) - axes[0].set_title('Raw data') + axes[0].set_title("Raw data") # Interpolated data - new_lon = np.linspace(ds.lon[0], ds.lon[-1], ds.dims['lon'] * 4) - new_lat = np.linspace(ds.lat[0], ds.lat[-1], ds.dims['lat'] * 4) + new_lon = np.linspace(ds.lon[0], ds.lon[-1], ds.dims["lon"] * 4) + new_lat = np.linspace(ds.lat[0], ds.lat[-1], ds.dims["lat"] * 4) dsi = ds.interp(lat=new_lat, lon=new_lon) dsi.air.plot(ax=axes[1]) @savefig interpolation_sample3.png width=8in - axes[1].set_title('Interpolated data') + axes[1].set_title("Interpolated data") Our advanced interpolation can be used to remap the data to the new coordinate. Consider the new coordinates x and z on the two dimensional plane. @@ -291,20 +298,23 @@ The remapping can be done as follows x = np.linspace(240, 300, 100) z = np.linspace(20, 70, 100) # relation between new and original coordinates - lat = xr.DataArray(z, dims=['z'], coords={'z': z}) - lon = xr.DataArray((x[:, np.newaxis]-270)/np.cos(z*np.pi/180)+270, - dims=['x', 'z'], coords={'x': x, 'z': z}) + lat = xr.DataArray(z, dims=["z"], coords={"z": z}) + lon = xr.DataArray( + (x[:, np.newaxis] - 270) / np.cos(z * np.pi / 180) + 270, + dims=["x", "z"], + coords={"x": x, "z": z}, + ) fig, axes = plt.subplots(ncols=2, figsize=(10, 4)) ds.air.plot(ax=axes[0]) # draw the new coordinate on the original coordinates. for idx in [0, 33, 66, 99]: - axes[0].plot(lon.isel(x=idx), lat, '--k') + axes[0].plot(lon.isel(x=idx), lat, "--k") for idx in [0, 33, 66, 99]: - axes[0].plot(*xr.broadcast(lon.isel(z=idx), lat.isel(z=idx)), '--k') - axes[0].set_title('Raw data') + axes[0].plot(*xr.broadcast(lon.isel(z=idx), lat.isel(z=idx)), "--k") + axes[0].set_title("Raw data") dsi = ds.interp(lon=lon, lat=lat) dsi.air.plot(ax=axes[1]) @savefig interpolation_sample4.png width=8in - axes[1].set_title('Remapped data') + axes[1].set_title("Remapped data") diff --git a/doc/io.rst b/doc/io.rst index 0c666099df8..2e46879929b 100644 --- a/doc/io.rst +++ b/doc/io.rst @@ -9,11 +9,12 @@ simple :ref:`io.pickle` files to the more flexible :ref:`io.netcdf` format (recommended). .. ipython:: python - :suppress: + :suppress: import numpy as np import pandas as pd import xarray as xr + np.random.seed(123456) .. 
_io.netcdf: @@ -25,7 +26,7 @@ The recommended way to store xarray data structures is `netCDF`__, which is a binary file format for self-described datasets that originated in the geosciences. xarray is based on the netCDF data model, so netCDF files on disk directly correspond to :py:class:`Dataset` objects (more accurately, -a group in a netCDF file directly corresponds to a to :py:class:`Dataset` object. +a group in a netCDF file directly corresponds to a :py:class:`Dataset` object. See :ref:`io.netcdf_groups` for more.) NetCDF is supported on almost all platforms, and parsers exist @@ -42,7 +43,7 @@ __ http://www.unidata.ucar.edu/software/netcdf/ .. _netCDF FAQ: http://www.unidata.ucar.edu/software/netcdf/docs/faq.html#What-Is-netCDF Reading and writing netCDF files with xarray requires scipy or the -`netCDF4-Python`__ library to be installed (the later is required to +`netCDF4-Python`__ library to be installed (the latter is required to read/write netCDF V4 files and use the compression options described below). __ https://github.com/Unidata/netcdf4-python @@ -52,12 +53,16 @@ We can save a Dataset to disk using the .. ipython:: python - ds = xr.Dataset({'foo': (('x', 'y'), np.random.rand(4, 5))}, - coords={'x': [10, 20, 30, 40], - 'y': pd.date_range('2000-01-01', periods=5), - 'z': ('x', list('abcd'))}) + ds = xr.Dataset( + {"foo": (("x", "y"), np.random.rand(4, 5))}, + coords={ + "x": [10, 20, 30, 40], + "y": pd.date_range("2000-01-01", periods=5), + "z": ("x", list("abcd")), + }, + ) - ds.to_netcdf('saved_on_disk.nc') + ds.to_netcdf("saved_on_disk.nc") By default, the file is saved as netCDF4 (assuming netCDF4-Python is installed). You can control the format and engine used to write the file with @@ -76,7 +81,7 @@ We can load netCDF files to create a new Dataset using .. ipython:: python - ds_disk = xr.open_dataset('saved_on_disk.nc') + ds_disk = xr.open_dataset("saved_on_disk.nc") ds_disk Similarly, a DataArray can be saved to disk using the @@ -100,6 +105,12 @@ Dataset and DataArray objects, and no array values are loaded into memory until you try to perform some sort of actual computation. For an example of how these lazy arrays work, see the OPeNDAP section below. +There may be minor differences in the :py:class:`Dataset` object returned +when reading a NetCDF file with different engines. For example, +single-valued attributes are returned as scalars by the default +``engine=netcdf4``, but as arrays of size ``(1,)`` when reading with +``engine=h5netcdf``. + It is important to note that when you modify values of a Dataset, even one linked to files on disk, only the in-memory copy you are manipulating in xarray is modified: the original file on disk is never touched. @@ -117,7 +128,7 @@ netCDF file. However, it's often cleaner to use a ``with`` statement: .. ipython:: python # this automatically closes the dataset after use - with xr.open_dataset('saved_on_disk.nc') as ds: + with xr.open_dataset("saved_on_disk.nc") as ds: print(ds.keys()) Although xarray provides reasonable support for incremental reads of files on @@ -171,7 +182,7 @@ You can view this encoding information (among others) in the .. ipython:: :verbatim: - In [1]: ds_disk['y'].encoding + In [1]: ds_disk["y"].encoding Out[1]: {'zlib': False, 'shuffle': False, @@ -230,7 +241,7 @@ See its docstring for more details. .. note:: A common use-case involves a dataset distributed across a large number of files with - each file containing a large number of variables. 
Commonly a few of these variables + each file containing a large number of variables. Commonly, a few of these variables need to be concatenated along a dimension (say ``"time"``), while the rest are equal across the datasets (ignoring floating point differences). The following command with suitable modifications (such as ``parallel=True``) works well with such datasets:: @@ -287,8 +298,8 @@ library:: combined = read_netcdfs('/all/my/files/*.nc', dim='time') This function will work in many cases, but it's not very robust. First, it -never closes files, which means it will fail one you need to load more than -a few thousands file. Second, it assumes that you want all the data from each +never closes files, which means it will fail if you need to load more than +a few thousand files. Second, it assumes that you want all the data from each file and that it can all fit into memory. In many situations, you only need a small subset or an aggregated summary of the data from each file. @@ -340,7 +351,7 @@ default encoding, or the options in the ``encoding`` attribute, if set. This works perfectly fine in most cases, but encoding can be useful for additional control, especially for enabling compression. -In the file on disk, these encodings as saved as attributes on each variable, which +In the file on disk, these encodings are saved as attributes on each variable, which allow xarray and other CF-compliant tools for working with netCDF files to correctly read the data. @@ -353,7 +364,7 @@ These encoding options work on any version of the netCDF file format: or ``'float32'``. This controls the type of the data written on disk. - ``_FillValue``: Values of ``NaN`` in xarray variables are remapped to this value when saved on disk. This is important when converting floating point with missing values - to integers on disk, because ``NaN`` is not a valid value for integer dtypes. As a + to integers on disk, because ``NaN`` is not a valid value for integer dtypes. By default, variables with float types are attributed a ``_FillValue`` of ``NaN`` in the output file, unless explicitly disabled with an encoding ``{'_FillValue': None}``. - ``scale_factor`` and ``add_offset``: Used to convert from encoded data on disk to @@ -395,8 +406,8 @@ If character arrays are used: by setting the ``_Encoding`` field in ``encoding``. But `we don't recommend it `_. - The character dimension name can be specifed by the ``char_dim_name`` field of a variable's - ``encoding``. If this is not specified the default name for the character dimension is - ``'string%s' % data.shape[-1]``. When decoding character arrays from existing files, the + ``encoding``. If the name of the character dimension is not specified, the default is + ``f'string{data.shape[-1]}'``. When decoding character arrays from existing files, the ``char_dim_name`` is added to the variables ``encoding`` to preserve if encoding happens, but the field can be edited by the user. @@ -458,7 +469,7 @@ This is not CF-compliant but again facilitates roundtripping of xarray datasets. Invalid netCDF files ~~~~~~~~~~~~~~~~~~~~ -The library ``h5netcdf`` allows writing some dtypes (booleans, complex, ...) that aren't +The library ``h5netcdf`` allows writing some dtypes (booleans, complex, ...) that aren't allowed in netCDF4 (see `h5netcdf documentation `_). 
This feature is availabe through :py:meth:`DataArray.to_netcdf` and @@ -469,7 +480,7 @@ and currently raises a warning unless ``invalid_netcdf=True`` is set: :okwarning: # Writing complex valued data - da = xr.DataArray([1.+1.j, 2.+2.j, 3.+3.j]) + da = xr.DataArray([1.0 + 1.0j, 2.0 + 2.0j, 3.0 + 3.0j]) da.to_netcdf("complex.nc", engine="h5netcdf", invalid_netcdf=True) # Reading it back @@ -479,7 +490,8 @@ and currently raises a warning unless ``invalid_netcdf=True`` is set: :suppress: import os - os.remove('complex.nc') + + os.remove("complex.nc") .. warning:: @@ -494,14 +506,16 @@ Iris The Iris_ tool allows easy reading of common meteorological and climate model formats (including GRIB and UK MetOffice PP files) into ``Cube`` objects which are in many ways very similar to ``DataArray`` objects, while enforcing a CF-compliant data model. If iris is -installed xarray can convert a ``DataArray`` into a ``Cube`` using +installed, xarray can convert a ``DataArray`` into a ``Cube`` using :py:meth:`DataArray.to_iris`: .. ipython:: python - da = xr.DataArray(np.random.rand(4, 5), dims=['x', 'y'], - coords=dict(x=[10, 20, 30, 40], - y=pd.date_range('2000-01-01', periods=5))) + da = xr.DataArray( + np.random.rand(4, 5), + dims=["x", "y"], + coords=dict(x=[10, 20, 30, 40], y=pd.date_range("2000-01-01", periods=5)), + ) cube = da.to_iris() cube @@ -548,8 +562,9 @@ __ http://iri.columbia.edu/ :verbatim: In [3]: remote_data = xr.open_dataset( - ...: 'http://iridl.ldeo.columbia.edu/SOURCES/.OSU/.PRISM/.monthly/dods', - ...: decode_times=False) + ...: "http://iridl.ldeo.columbia.edu/SOURCES/.OSU/.PRISM/.monthly/dods", + ...: decode_times=False, + ...: ) In [4]: remote_data Out[4]: @@ -587,7 +602,7 @@ over the network until we look at particular values: .. ipython:: :verbatim: - In [4]: tmax = remote_data['tmax'][:500, ::3, ::3] + In [4]: tmax = remote_data["tmax"][:500, ::3, ::3] In [5]: tmax Out[5]: @@ -701,7 +716,7 @@ require external libraries and dicts can easily be pickled, or converted to json, or geojson. All the values are converted to lists, so dicts might be quite large. -To export just the dataset schema, without the data itself, use the +To export just the dataset schema without the data itself, use the ``data=False`` option: .. ipython:: python @@ -715,7 +730,8 @@ search indices or other automated data discovery tools. :suppress: import os - os.remove('saved_on_disk.nc') + + os.remove("saved_on_disk.nc") .. _io.rasterio: @@ -729,7 +745,7 @@ rasterio is installed. Here is an example of how to use .. ipython:: :verbatim: - In [7]: rio = xr.open_rasterio('RGB.byte.tif') + In [7]: rio = xr.open_rasterio("RGB.byte.tif") In [8]: rio Out[8]: @@ -756,7 +772,7 @@ for an example of how to convert these to longitudes and latitudes. .. warning:: This feature has been added in xarray v0.9.6 and should still be - considered as being experimental. Please report any bug you may find + considered experimental. Please report any bugs you may find on xarray's github repository. 
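For large rasters it can help to open the file lazily; :py:func:`open_rasterio` also accepts a ``chunks`` argument so the returned ``DataArray`` is backed by dask (a sketch only; the tile size is an arbitrary choice and the file name is the same sample file used above):

.. code:: python

    import xarray as xr

    # open lazily, one 256x256 tile per chunk and one chunk per band
    rio = xr.open_rasterio("RGB.byte.tif", chunks={"band": 1, "x": 256, "y": 256})

    # computations now stream tile by tile instead of loading the whole raster
    band_means = rio.mean(dim=["x", "y"]).compute()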
@@ -769,7 +785,7 @@ GDAL readable raster data using `rasterio`_ as well as for exporting to a geoTIF In [1]: import rioxarray - In [2]: rds = rioxarray.open_rasterio('RGB.byte.tif') + In [2]: rds = rioxarray.open_rasterio("RGB.byte.tif") In [3]: rds Out[3]: @@ -794,12 +810,12 @@ GDAL readable raster data using `rasterio`_ as well as for exporting to a geoTIF In [4]: rds.rio.crs Out[4]: CRS.from_epsg(32618) - In [5]: rds4326 = rio.rio.reproject("epsg:4326") + In [5]: rds4326 = rds.rio.reproject("epsg:4326") In [6]: rds4326.rio.crs Out[6]: CRS.from_epsg(4326) - In [7]: rds4326.rio.to_raster('RGB.byte.4326.tif') + In [7]: rds4326.rio.to_raster("RGB.byte.4326.tif") .. _rasterio: https://rasterio.readthedocs.io/en/latest/ @@ -812,12 +828,14 @@ GDAL readable raster data using `rasterio`_ as well as for exporting to a geoTIF Zarr ---- -`Zarr`_ is a Python package providing an implementation of chunked, compressed, +`Zarr`_ is a Python package that provides an implementation of chunked, compressed, N-dimensional arrays. Zarr has the ability to store arrays in a range of ways, including in memory, in files, and in cloud-based object storage such as `Amazon S3`_ and `Google Cloud Storage`_. -Xarray's Zarr backend allows xarray to leverage these capabilities. +Xarray's Zarr backend allows xarray to leverage these capabilities, including +the ability to store and analyze datasets far too large fit onto disk +(particularly :ref:`in combination with dask `). .. warning:: @@ -827,58 +845,44 @@ Xarray's Zarr backend allows xarray to leverage these capabilities. Xarray can't open just any zarr dataset, because xarray requires special metadata (attributes) describing the dataset dimensions and coordinates. At this time, xarray can only open zarr datasets that have been written by -xarray. To write a dataset with zarr, we use the :py:attr:`Dataset.to_zarr` method. -To write to a local directory, we pass a path to a directory +xarray. For implementation details, see :ref:`zarr_encoding`. + +To write a dataset with zarr, we use the :py:meth:`Dataset.to_zarr` method. + +To write to a local directory, we pass a path to a directory: .. ipython:: python - :suppress: + :suppress: ! rm -rf path/to/directory.zarr .. ipython:: python - ds = xr.Dataset({'foo': (('x', 'y'), np.random.rand(4, 5))}, - coords={'x': [10, 20, 30, 40], - 'y': pd.date_range('2000-01-01', periods=5), - 'z': ('x', list('abcd'))}) - ds.to_zarr('path/to/directory.zarr') + ds = xr.Dataset( + {"foo": (("x", "y"), np.random.rand(4, 5))}, + coords={ + "x": [10, 20, 30, 40], + "y": pd.date_range("2000-01-01", periods=5), + "z": ("x", list("abcd")), + }, + ) + ds.to_zarr("path/to/directory.zarr") (The suffix ``.zarr`` is optional--just a reminder that a zarr store lives there.) If the directory does not exist, it will be created. If a zarr store is already present at that path, an error will be raised, preventing it from being overwritten. To override this behavior and overwrite an existing -store, add ``mode='w'`` when invoking ``to_zarr``. - -It is also possible to append to an existing store. For that, set -``append_dim`` to the name of the dimension along which to append. ``mode`` -can be omitted as it will internally be set to ``'a'``. - -.. ipython:: python - :suppress: - - ! rm -rf path/to/directory.zarr - -.. ipython:: python +store, add ``mode='w'`` when invoking :py:meth:`~Dataset.to_zarr`. 
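As a short sketch (reusing ``ds`` and the store path from the example above), a repeated plain write raises because the store already exists, while passing ``mode="w"`` replaces it:

.. code:: python

    # the store already exists from the example above, so this would raise:
    # ds.to_zarr("path/to/directory.zarr")

    # explicitly allow overwriting the existing store instead
    ds.to_zarr("path/to/directory.zarr", mode="w")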
- ds1 = xr.Dataset({'foo': (('x', 'y', 't'), np.random.rand(4, 5, 2))}, - coords={'x': [10, 20, 30, 40], - 'y': [1,2,3,4,5], - 't': pd.date_range('2001-01-01', periods=2)}) - ds1.to_zarr('path/to/directory.zarr') - ds2 = xr.Dataset({'foo': (('x', 'y', 't'), np.random.rand(4, 5, 2))}, - coords={'x': [10, 20, 30, 40], - 'y': [1,2,3,4,5], - 't': pd.date_range('2001-01-03', periods=2)}) - ds2.to_zarr('path/to/directory.zarr', append_dim='t') - -To store variable length strings use ``dtype=object``. +To store variable length strings, convert them to object arrays first with +``dtype=object``. To read back a zarr dataset that has been created this way, we use the :py:func:`open_zarr` method: .. ipython:: python - ds_zarr = xr.open_zarr('path/to/directory.zarr') + ds_zarr = xr.open_zarr("path/to/directory.zarr") ds_zarr Cloud Storage Buckets @@ -912,15 +916,16 @@ These options can be passed to the ``to_zarr`` method as variable encoding. For example: .. ipython:: python - :suppress: + :suppress: ! rm -rf foo.zarr .. ipython:: python import zarr - compressor = zarr.Blosc(cname='zstd', clevel=3, shuffle=2) - ds.to_zarr('foo.zarr', encoding={'foo': {'compressor': compressor}}) + + compressor = zarr.Blosc(cname="zstd", clevel=3, shuffle=2) + ds.to_zarr("foo.zarr", encoding={"foo": {"compressor": compressor}}) .. note:: @@ -956,34 +961,137 @@ Xarray can't perform consolidation on pre-existing zarr datasets. This should be done directly from zarr, as described in the `zarr docs `_. +.. _io.zarr.appending: + +Appending to existing Zarr stores +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Xarray supports several ways of incrementally writing variables to a Zarr +store. These options are useful for scenarios when it is infeasible or +undesirable to write your entire dataset at once. + +.. tip:: + + If you can load all of your data into a single ``Dataset`` using dask, a + single call to ``to_zarr()`` will write all of your data in parallel. + +.. warning:: + + Alignment of coordinates is currently not checked when modifying an + existing Zarr store. It is up to the user to ensure that coordinates are + consistent. + +To add or overwrite entire variables, simply call :py:meth:`~Dataset.to_zarr` +with ``mode='a'`` on a Dataset containing the new variables, passing in an +existing Zarr store or path to a Zarr store. + +To resize and then append values along an existing dimension in a store, set +``append_dim``. This is a good option if data always arives in a particular +order, e.g., for time-stepping a simulation: + +.. ipython:: python + :suppress: + + ! rm -rf path/to/directory.zarr + +.. ipython:: python + + ds1 = xr.Dataset( + {"foo": (("x", "y", "t"), np.random.rand(4, 5, 2))}, + coords={ + "x": [10, 20, 30, 40], + "y": [1, 2, 3, 4, 5], + "t": pd.date_range("2001-01-01", periods=2), + }, + ) + ds1.to_zarr("path/to/directory.zarr") + ds2 = xr.Dataset( + {"foo": (("x", "y", "t"), np.random.rand(4, 5, 2))}, + coords={ + "x": [10, 20, 30, 40], + "y": [1, 2, 3, 4, 5], + "t": pd.date_range("2001-01-03", periods=2), + }, + ) + ds2.to_zarr("path/to/directory.zarr", append_dim="t") + +Finally, you can use ``region`` to write to limited regions of existing arrays +in an existing Zarr store. This is a good option for writing data in parallel +from independent processes. + +To scale this up to writing large datasets, the first step is creating an +initial Zarr store without writing all of its array data. 
+first creating a ``Dataset`` with dummy values stored in :ref:`dask `,
+and then calling ``to_zarr`` with ``compute=False`` to write only metadata
+(including ``attrs``) to Zarr:
+
+.. ipython:: python
+    :suppress:
+
+    ! rm -rf path/to/directory.zarr
+
+.. ipython:: python
+
+    import dask.array
+
+    # The values of this dask array are entirely irrelevant; only the dtype,
+    # shape and chunks are used
+    dummies = dask.array.zeros(30, chunks=10)
+    ds = xr.Dataset({"foo": ("x", dummies)})
+    path = "path/to/directory.zarr"
+    # Now we write the metadata without computing any array values
+    ds.to_zarr(path, compute=False, consolidated=True)
+
+Now, a Zarr store with the correct variable shapes and attributes exists that
+can be filled out by subsequent calls to ``to_zarr``. The ``region`` keyword
+provides a mapping from dimension names to Python ``slice`` objects indicating
+where the data should be written (in index space, not coordinate space), e.g.,
+
+.. ipython:: python
+
+    # For convenience, we'll slice a single dataset, but in the real use-case
+    # we would create them separately, possibly even from separate processes.
+    ds = xr.Dataset({"foo": ("x", np.arange(30))})
+    ds.isel(x=slice(0, 10)).to_zarr(path, region={"x": slice(0, 10)})
+    ds.isel(x=slice(10, 20)).to_zarr(path, region={"x": slice(10, 20)})
+    ds.isel(x=slice(20, 30)).to_zarr(path, region={"x": slice(20, 30)})
+
+Concurrent writes with ``region`` are safe as long as they modify distinct
+chunks in the underlying Zarr arrays (or use an appropriate ``lock``).
+
+As a safety check to make it harder to inadvertently override existing values,
+if you set ``region`` then *all* variables included in a Dataset must have
+dimensions included in ``region``. Other variables (typically coordinates)
+need to be explicitly dropped and/or written in separate calls to ``to_zarr``
+with ``mode='a'``.
+
 .. _io.cfgrib:

 .. ipython:: python
-  :suppress:
+    :suppress:

     import shutil
-    shutil.rmtree('foo.zarr')
-    shutil.rmtree('path/to/directory.zarr')
+
+    shutil.rmtree("foo.zarr")
+    shutil.rmtree("path/to/directory.zarr")

 GRIB format via cfgrib
 ----------------------

-xarray supports reading GRIB files via ECMWF cfgrib_ python driver and ecCodes_
-C-library, if they are installed. To open a GRIB file supply ``engine='cfgrib'``
+xarray supports reading GRIB files via the ECMWF cfgrib_ python driver,
+if it is installed. To open a GRIB file supply ``engine='cfgrib'``
 to :py:func:`open_dataset`:

 .. ipython::
     :verbatim:

-    In [1]: ds_grib = xr.open_dataset('example.grib', engine='cfgrib')
+    In [1]: ds_grib = xr.open_dataset("example.grib", engine="cfgrib")

-We recommend installing ecCodes via conda::
+We recommend installing cfgrib via conda::

-    conda install -c conda-forge eccodes
-    pip install cfgrib
+    conda install -c conda-forge cfgrib

 .. _cfgrib: https://github.com/ecmwf/cfgrib
-.. _ecCodes: https://confluence.ecmwf.int/display/ECC/ecCodes+Home

 .. _io.pynio:

@@ -998,6 +1106,11 @@ We recommend installing PyNIO via conda::

     conda install -c conda-forge pynio

+.. note::
+
+    PyNIO is no longer actively maintained and conflicts with netcdf4 > 1.5.3.
+    The PyNIO backend may be moved outside of xarray in the future.
+
 .. _PyNIO: https://www.pyngl.ucar.edu/Nio.shtml

 .. _io.PseudoNetCDF:

@@ -1010,7 +1123,7 @@ formats supported by PseudoNetCDF_, if PseudoNetCDF is installed.
 PseudoNetCDF can also provide Climate Forecasting Conventions to CMAQ files.
In addition, PseudoNetCDF can automatically register custom readers that subclass PseudoNetCDF.PseudoNetCDFFile. PseudoNetCDF can -identify readers heuristically, or format can be specified via a key in +identify readers either heuristically, or by a format specified via a key in `backend_kwargs`. To use PseudoNetCDF to read such files, supply @@ -1032,3 +1145,11 @@ For CSV files, one might also consider `xarray_extras`_. .. _xarray_extras: https://xarray-extras.readthedocs.io/en/latest/api/csv.html .. _IO tools: http://pandas.pydata.org/pandas-docs/stable/io.html + + +Third party libraries +--------------------- + +More formats are supported by extension libraries: + +- `xarray-mongodb `_: Store xarray objects on MongoDB diff --git a/doc/pandas.rst b/doc/pandas.rst index b0ec2a117dc..acf1d16b6ee 100644 --- a/doc/pandas.rst +++ b/doc/pandas.rst @@ -20,6 +20,7 @@ __ http://seaborn.pydata.org/ import numpy as np import pandas as pd import xarray as xr + np.random.seed(123456) Hierarchical and tidy data @@ -47,10 +48,15 @@ To convert any dataset to a ``DataFrame`` in tidy form, use the .. ipython:: python - ds = xr.Dataset({'foo': (('x', 'y'), np.random.randn(2, 3))}, - coords={'x': [10, 20], 'y': ['a', 'b', 'c'], - 'along_x': ('x', np.random.randn(2)), - 'scalar': 123}) + ds = xr.Dataset( + {"foo": (("x", "y"), np.random.randn(2, 3))}, + coords={ + "x": [10, 20], + "y": ["a", "b", "c"], + "along_x": ("x", np.random.randn(2)), + "scalar": 123, + }, + ) ds df = ds.to_dataframe() df @@ -91,7 +97,7 @@ DataFrames: .. ipython:: python - s = ds['foo'].to_series() + s = ds["foo"].to_series() s # or equivalently, with Series.to_xarray() xr.DataArray.from_series(s) @@ -117,8 +123,9 @@ available in pandas (i.e., a 1D array is converted to a .. ipython:: python - arr = xr.DataArray(np.random.randn(2, 3), - coords=[('x', [10, 20]), ('y', ['a', 'b', 'c'])]) + arr = xr.DataArray( + np.random.randn(2, 3), coords=[("x", [10, 20]), ("y", ["a", "b", "c"])] + ) df = arr.to_pandas() df @@ -136,9 +143,10 @@ preserve all use of multi-indexes: .. ipython:: python - index = pd.MultiIndex.from_arrays([['a', 'a', 'b'], [0, 1, 2]], - names=['one', 'two']) - df = pd.DataFrame({'x': 1, 'y': 2}, index=index) + index = pd.MultiIndex.from_arrays( + [["a", "a", "b"], [0, 1, 2]], names=["one", "two"] + ) + df = pd.DataFrame({"x": 1, "y": 2}, index=index) ds = xr.Dataset(df) ds @@ -175,9 +183,9 @@ Let's take a look: .. ipython:: python data = np.random.RandomState(0).rand(2, 3, 4) - items = list('ab') - major_axis = list('mno') - minor_axis = pd.date_range(start='2000', periods=4, name='date') + items = list("ab") + major_axis = list("mno") + minor_axis = pd.date_range(start="2000", periods=4, name="date") With old versions of pandas (prior to 0.25), this could stored in a ``Panel``: @@ -207,7 +215,7 @@ You can also easily convert this data into ``Dataset``: .. ipython:: python - array.to_dataset(dim='dim_0') + array.to_dataset(dim="dim_0") Here, there are two data variables, each representing a DataFrame on panel's ``items`` axis, and labeled as such. Each variable is a 2D array of the diff --git a/doc/plotting.rst b/doc/plotting.rst index f3d9c0213de..3699f794ae8 100644 --- a/doc/plotting.rst +++ b/doc/plotting.rst @@ -13,7 +13,7 @@ labels can also be used to easily create informative plots. xarray's plotting capabilities are centered around :py:class:`DataArray` objects. To plot :py:class:`Dataset` objects -simply access the relevant DataArrays, ie ``dset['var1']``. +simply access the relevant DataArrays, i.e. 
``dset['var1']``. Dataset specific plotting routines are also available (see :ref:`plot-dataset`). Here we focus mostly on arrays 2d or larger. If your data fits nicely into a pandas DataFrame then you're better off using one of the more @@ -37,7 +37,7 @@ For more extensive plotting applications consider the following projects: Integrates well with pandas. - `HoloViews `_ - and `GeoViews `_: "Composable, declarative + and `GeoViews `_: "Composable, declarative data structures for building even complex visualizations easily." Includes native support for xarray objects. @@ -56,6 +56,7 @@ Imports # Use defaults so we don't get gridlines in generated docs import matplotlib as mpl + mpl.rcdefaults() The following imports are necessary for all of the examples. @@ -71,7 +72,7 @@ For these examples we'll use the North American air temperature dataset. .. ipython:: python - airtemps = xr.tutorial.open_dataset('air_temperature') + airtemps = xr.tutorial.open_dataset("air_temperature") airtemps # Convert to celsius @@ -79,7 +80,7 @@ For these examples we'll use the North American air temperature dataset. # copy attributes to get nice figure labels and change Kelvin to Celsius air.attrs = airtemps.air.attrs - air.attrs['units'] = 'deg C' + air.attrs["units"] = "deg C" .. note:: Until :issue:`1614` is solved, you might need to copy over the metadata in ``attrs`` to get informative figure labels (as was done above). @@ -98,13 +99,14 @@ One Dimension The simplest way to make a plot is to call the :py:func:`DataArray.plot()` method. .. ipython:: python + :okwarning: air1d = air.isel(lat=10, lon=10) @savefig plotting_1d_simple.png width=4in air1d.plot() -xarray uses the coordinate name along with metadata ``attrs.long_name``, ``attrs.standard_name``, ``DataArray.name`` and ``attrs.units`` (if available) to label the axes. The names ``long_name``, ``standard_name`` and ``units`` are copied from the `CF-conventions spec `_. When choosing names, the order of precedence is ``long_name``, ``standard_name`` and finally ``DataArray.name``. The y-axis label in the above plot was constructed from the ``long_name`` and ``units`` attributes of ``air1d``. +xarray uses the coordinate name along with metadata ``attrs.long_name``, ``attrs.standard_name``, ``DataArray.name`` and ``attrs.units`` (if available) to label the axes. The names ``long_name``, ``standard_name`` and ``units`` are copied from the `CF-conventions spec `_. When choosing names, the order of precedence is ``long_name``, ``standard_name`` and finally ``DataArray.name``. The y-axis label in the above plot was constructed from the ``long_name`` and ``units`` attributes of ``air1d``. .. ipython:: python @@ -124,9 +126,10 @@ can be used: .. _matplotlib.pyplot.plot: http://matplotlib.org/api/pyplot_api.html#matplotlib.pyplot.plot .. ipython:: python + :okwarning: @savefig plotting_1d_additional_args.png width=4in - air1d[:200].plot.line('b-^') + air1d[:200].plot.line("b-^") .. note:: Not all xarray plotting methods support passing positional arguments @@ -136,9 +139,10 @@ can be used: Keyword arguments work the same way, and are more explicit. .. ipython:: python + :okwarning: @savefig plotting_example_sin3.png width=4in - air1d[:200].plot.line(color='purple', marker='o') + air1d[:200].plot.line(color="purple", marker="o") ========================= Adding to Existing Axis @@ -150,6 +154,7 @@ In this example ``axes`` is an array consisting of the left and right axes created by ``plt.subplots``. .. 
ipython:: python + :okwarning: fig, axes = plt.subplots(ncols=2) @@ -177,6 +182,7 @@ support the ``aspect`` and ``size`` arguments which control the size of the resulting image via the formula ``figsize = (aspect * size, size)``: .. ipython:: python + :okwarning: air1d.plot(aspect=2, size=3) @savefig plotting_example_size_and_aspect.png @@ -208,6 +214,48 @@ entire figure (as for matplotlib's ``figsize`` argument). .. _plotting.multiplelines: +========================= + Determine x-axis values +========================= + +Per default dimension coordinates are used for the x-axis (here the time coordinates). +However, you can also use non-dimension coordinates, MultiIndex levels, and dimensions +without coordinates along the x-axis. To illustrate this, let's calculate a 'decimal day' (epoch) +from the time and assign it as a non-dimension coordinate: + +.. ipython:: python + :okwarning: + + decimal_day = (air1d.time - air1d.time[0]) / pd.Timedelta("1d") + air1d_multi = air1d.assign_coords(decimal_day=("time", decimal_day)) + air1d_multi + +To use ``'decimal_day'`` as x coordinate it must be explicitly specified: + +.. ipython:: python + :okwarning: + + air1d_multi.plot(x="decimal_day") + +Creating a new MultiIndex named ``'date'`` from ``'time'`` and ``'decimal_day'``, +it is also possible to use a MultiIndex level as x-axis: + +.. ipython:: python + :okwarning: + + air1d_multi = air1d_multi.set_index(date=("time", "decimal_day")) + air1d_multi.plot(x="decimal_day") + +Finally, if a dataset does not have any coordinates it enumerates all data points: + +.. ipython:: python + :okwarning: + + air1d_multi = air1d_multi.drop("date") + air1d_multi.plot() + +The same applies to 2D plots below. + ==================================================== Multiple lines showing variation along a dimension ==================================================== @@ -217,9 +265,10 @@ with appropriate arguments. Consider the 3D variable ``air`` defined above. We c plots to check the variation of air temperature at three different latitudes along a longitude line: .. ipython:: python + :okwarning: @savefig plotting_example_multiple_lines_x_kwarg.png - air.isel(lon=10, lat=[19,21,22]).plot.line(x='time') + air.isel(lon=10, lat=[19, 21, 22]).plot.line(x="time") It is required to explicitly specify either @@ -238,9 +287,10 @@ If required, the automatic legend can be turned off using ``add_legend=False``. It is also possible to make line plots such that the data are on the x-axis and a dimension is on the y-axis. This can be done by specifying the appropriate ``y`` keyword argument. .. ipython:: python + :okwarning: @savefig plotting_example_xy_kwarg.png - air.isel(time=10, lon=[10, 11]).plot(y='lat', hue='lon') + air.isel(time=10, lon=[10, 11]).plot(y="lat", hue="lon") ============ Step plots @@ -253,23 +303,24 @@ made using 1D data. :okwarning: @savefig plotting_example_step.png width=4in - air1d[:20].plot.step(where='mid') + air1d[:20].plot.step(where="mid") The argument ``where`` defines where the steps should be placed, options are ``'pre'`` (default), ``'post'``, and ``'mid'``. This is particularly handy when plotting data grouped with :py:meth:`Dataset.groupby_bins`. .. 
ipython:: python + :okwarning: - air_grp = air.mean(['time','lon']).groupby_bins('lat',[0,23.5,66.5,90]) + air_grp = air.mean(["time", "lon"]).groupby_bins("lat", [0, 23.5, 66.5, 90]) air_mean = air_grp.mean() air_std = air_grp.std() air_mean.plot.step() - (air_mean + air_std).plot.step(ls=':') - (air_mean - air_std).plot.step(ls=':') - plt.ylim(-20,30) + (air_mean + air_std).plot.step(ls=":") + (air_mean - air_std).plot.step(ls=":") + plt.ylim(-20, 30) @savefig plotting_example_step_groupby.png width=4in - plt.title('Zonal mean temperature') + plt.title("Zonal mean temperature") In this case, the actual boundaries of the bins are used and the ``where`` argument is ignored. @@ -282,9 +333,12 @@ Other axes kwargs The keyword arguments ``xincrease`` and ``yincrease`` let you control the axes direction. .. ipython:: python + :okwarning: @savefig plotting_example_xincrease_yincrease_kwarg.png - air.isel(time=10, lon=[10, 11]).plot.line(y='lat', hue='lon', xincrease=False, yincrease=False) + air.isel(time=10, lon=[10, 11]).plot.line( + y="lat", hue="lon", xincrease=False, yincrease=False + ) In addition, one can use ``xscale, yscale`` to set axes scaling; ``xticks, yticks`` to set axes ticks and ``xlim, ylim`` to set axes limits. These accept the same values as the matplotlib methods ``Axes.set_(x,y)scale()``, ``Axes.set_(x,y)ticks()``, ``Axes.set_(x,y)lim()`` respectively. @@ -299,6 +353,7 @@ Two Dimensions The default method :py:meth:`DataArray.plot` calls :py:func:`xarray.plot.pcolormesh` by default when the data is two-dimensional. .. ipython:: python + :okwarning: air2d = air.isel(time=500) @@ -309,6 +364,7 @@ All 2d plots in xarray allow the use of the keyword arguments ``yincrease`` and ``xincrease``. .. ipython:: python + :okwarning: @savefig 2d_simple_yincrease.png width=4in air2d.plot(yincrease=False) @@ -328,6 +384,7 @@ and ``xincrease``. xarray plots data with :ref:`missing_values`. .. ipython:: python + :okwarning: bad_air2d = air2d.copy() @@ -345,10 +402,11 @@ It's not necessary for the coordinates to be evenly spaced. Both produce plots with nonuniform coordinates. .. ipython:: python + :okwarning: b = air2d.copy() # Apply a nonlinear transformation to one of the coords - b.coords['lat'] = np.log(b.coords['lat']) + b.coords["lat"] = np.log(b.coords["lat"]) @savefig plotting_nonuniform_coords.png width=4in b.plot() @@ -361,11 +419,12 @@ Since this is a thin wrapper around matplotlib, all the functionality of matplotlib is available. .. ipython:: python + :okwarning: air2d.plot(cmap=plt.cm.Blues) - plt.title('These colors prove North America\nhas fallen in the ocean') - plt.ylabel('latitude') - plt.xlabel('longitude') + plt.title("These colors prove North America\nhas fallen in the ocean") + plt.ylabel("latitude") + plt.xlabel("longitude") plt.tight_layout() @savefig plotting_2d_call_matplotlib.png width=4in @@ -380,8 +439,9 @@ matplotlib is available. ``d_ylog.plot()`` updates the xlabel. .. ipython:: python + :okwarning: - plt.xlabel('Never gonna see this.') + plt.xlabel("Never gonna see this.") air2d.plot() @savefig plotting_2d_call_matplotlib2.png width=4in @@ -395,6 +455,7 @@ xarray borrows logic from Seaborn to infer what kind of color map to use. For example, consider the original data in Kelvins rather than Celsius: .. ipython:: python + :okwarning: @savefig plotting_kelvin.png width=4in airtemps.air.isel(time=0).plot() @@ -413,6 +474,7 @@ Here we add two bad data points. This affects the color scale, washing out the plot. .. 
ipython:: python + :okwarning: air_outliers = airtemps.air.isel(time=0).copy() air_outliers[0, 0] = 100 @@ -428,6 +490,7 @@ This will use the 2nd and 98th percentiles of the data to compute the color limits. .. ipython:: python + :okwarning: @savefig plotting_robust2.png width=4in air_outliers.plot(robust=True) @@ -446,6 +509,7 @@ rather than the default continuous colormaps that matplotlib uses. The colormaps. For example, to make a plot with 8 discrete color intervals: .. ipython:: python + :okwarning: @savefig plotting_discrete_levels.png width=4in air2d.plot(levels=8) @@ -454,6 +518,7 @@ It is also possible to use a list of levels to specify the boundaries of the discrete colormap: .. ipython:: python + :okwarning: @savefig plotting_listed_levels.png width=4in air2d.plot(levels=[0, 12, 18, 30]) @@ -461,6 +526,7 @@ discrete colormap: You can also specify a list of discrete colors through the ``colors`` argument: .. ipython:: python + :okwarning: flatui = ["#9b59b6", "#3498db", "#95a5a6", "#e74c3c", "#34495e", "#2ecc71"] @savefig plotting_custom_colors_levels.png width=4in @@ -473,10 +539,10 @@ if using ``imshow`` or ``pcolormesh`` (but not with ``contour`` or ``contourf``, since levels are chosen automatically). .. ipython:: python - :okwarning: + :okwarning: @savefig plotting_seaborn_palette.png width=4in - air2d.plot(levels=10, cmap='husl') + air2d.plot(levels=10, cmap="husl") plt.draw() .. _plotting.faceting: @@ -518,16 +584,20 @@ arguments to the xarray plotting methods/functions. This returns a :py:class:`xarray.plot.FacetGrid` object. .. ipython:: python + :okwarning: @savefig plot_facet_dataarray.png - g_simple = t.plot(x='lon', y='lat', col='time', col_wrap=3) + g_simple = t.plot(x="lon", y="lat", col="time", col_wrap=3) Faceting also works for line plots. .. ipython:: python + :okwarning: @savefig plot_facet_dataarray_line.png - g_simple_line = t.isel(lat=slice(0,None,4)).plot(x='lon', hue='lat', col='time', col_wrap=3) + g_simple_line = t.isel(lat=slice(0, None, 4)).plot( + x="lon", hue="lat", col="time", col_wrap=3 + ) =============== 4 dimensional @@ -539,14 +609,15 @@ a fixed amount. Now we can see how the temperature maps would compare if one were much hotter. .. ipython:: python + :okwarning: t2 = t.isel(time=slice(0, 2)) - t4d = xr.concat([t2, t2 + 40], pd.Index(['normal', 'hot'], name='fourth_dim')) + t4d = xr.concat([t2, t2 + 40], pd.Index(["normal", "hot"], name="fourth_dim")) # This is a 4d array t4d.coords @savefig plot_facet_4d.png - t4d.plot(x='lon', y='lat', col='time', row='fourth_dim') + t4d.plot(x="lon", y="lat", col="time", row="fourth_dim") ================ Other features @@ -555,20 +626,27 @@ one were much hotter. Faceted plotting supports other arguments common to xarray 2d plots. .. ipython:: python - :suppress: + :suppress: - plt.close('all') + plt.close("all") .. ipython:: python + :okwarning: hasoutliers = t.isel(time=slice(0, 5)).copy() hasoutliers[0, 0, 0] = -100 hasoutliers[-1, -1, -1] = 400 @savefig plot_facet_robust.png - g = hasoutliers.plot.pcolormesh('lon', 'lat', col='time', col_wrap=3, - robust=True, cmap='viridis', - cbar_kwargs={'label': 'this has outliers'}) + g = hasoutliers.plot.pcolormesh( + "lon", + "lat", + col="time", + col_wrap=3, + robust=True, + cmap="viridis", + cbar_kwargs={"label": "this has outliers"}, + ) =================== FacetGrid Objects @@ -594,20 +672,21 @@ It's possible to select the :py:class:`xarray.DataArray` or .. 
ipython:: python - g.data.loc[g.name_dicts[0, 0]] + g.data.loc[g.name_dicts[0, 0]] Here is an example of using the lower level API and then modifying the axes after they have been plotted. .. ipython:: python + :okwarning: - g = t.plot.imshow('lon', 'lat', col='time', col_wrap=3, robust=True) + g = t.plot.imshow("lon", "lat", col="time", col_wrap=3, robust=True) for i, ax in enumerate(g.axes.flat): - ax.set_title('Air Temperature %d' % i) + ax.set_title("Air Temperature %d" % i) bottomright = g.axes[-1, -1] - bottomright.annotate('bottom right', (240, 40)) + bottomright.annotate("bottom right", (240, 40)) @savefig plot_facet_iterator.png plt.draw() @@ -632,23 +711,25 @@ Consider this dataset .. ipython:: python - ds = xr.tutorial.scatter_example_dataset() - ds + ds = xr.tutorial.scatter_example_dataset() + ds Suppose we want to scatter ``A`` against ``B`` .. ipython:: python + :okwarning: @savefig ds_simple_scatter.png - ds.plot.scatter(x='A', y='B') + ds.plot.scatter(x="A", y="B") The ``hue`` kwarg lets you vary the color by variable value .. ipython:: python + :okwarning: @savefig ds_hue_scatter.png - ds.plot.scatter(x='A', y='B', hue='w') + ds.plot.scatter(x="A", y="B", hue="w") When ``hue`` is specified, a colorbar is added for numeric ``hue`` DataArrays by default and a legend is added for non-numeric ``hue`` DataArrays (as above). @@ -656,24 +737,27 @@ You can force a legend instead of a colorbar by setting ``hue_style='discrete'`` Additionally, the boolean kwarg ``add_guide`` can be used to prevent the display of a legend or colorbar (as appropriate). .. ipython:: python + :okwarning: ds = ds.assign(w=[1, 2, 3, 5]) @savefig ds_discrete_legend_hue_scatter.png - ds.plot.scatter(x='A', y='B', hue='w', hue_style='discrete') + ds.plot.scatter(x="A", y="B", hue="w", hue_style="discrete") The ``markersize`` kwarg lets you vary the point's size by variable value. You can additionally pass ``size_norm`` to control how the variable's values are mapped to point sizes. .. ipython:: python + :okwarning: @savefig ds_hue_size_scatter.png - ds.plot.scatter(x='A', y='B', hue='z', hue_style='discrete', markersize='z') + ds.plot.scatter(x="A", y="B", hue="z", hue_style="discrete", markersize="z") Faceting is also possible .. ipython:: python + :okwarning: @savefig ds_facet_scatter.png - ds.plot.scatter(x='A', y='B', col='x', row='z', hue='w', hue_style='discrete') + ds.plot.scatter(x="A", y="B", col="x", row="z", hue="w", hue_style="discrete") For more advanced scatter plots, we recommend converting the relevant data variables to a pandas DataFrame and using the extensive plotting capabilities of ``seaborn``. @@ -689,27 +773,38 @@ To follow this section you'll need to have Cartopy installed and working. This script will plot the air temperature on a map. .. ipython:: python + :okwarning: import cartopy.crs as ccrs - air = xr.tutorial.open_dataset('air_temperature').air - ax = plt.axes(projection=ccrs.Orthographic(-80, 35)) - air.isel(time=0).plot.contourf(ax=ax, transform=ccrs.PlateCarree()); + + air = xr.tutorial.open_dataset("air_temperature").air + + p = air.isel(time=0).plot( + subplot_kws=dict(projection=ccrs.Orthographic(-80, 35), facecolor="gray"), + transform=ccrs.PlateCarree(), + ) + p.axes.set_global() + @savefig plotting_maps_cartopy.png width=100% - ax.set_global(); ax.coastlines(); + p.axes.coastlines() When faceting on maps, the projection can be transferred to the ``plot`` function using the ``subplot_kws`` keyword. 
The axes for the subplots created by faceting are accessible in the object returned by ``plot``: .. ipython:: python + :okwarning: - p = air.isel(time=[0, 4]).plot(transform=ccrs.PlateCarree(), col='time', - subplot_kws={'projection': ccrs.Orthographic(-80, 35)}) + p = air.isel(time=[0, 4]).plot( + transform=ccrs.PlateCarree(), + col="time", + subplot_kws={"projection": ccrs.Orthographic(-80, 35)}, + ) for ax in p.axes.flat: ax.coastlines() ax.gridlines() @savefig plotting_maps_cartopy_facetting.png width=100% - plt.draw(); + plt.draw() Details @@ -730,8 +825,10 @@ There are three ways to use the xarray plotting functionality: These are provided for user convenience; they all call the same code. .. ipython:: python + :okwarning: import xarray.plot as xplt + da = xr.DataArray(range(5)) fig, axes = plt.subplots(ncols=2, nrows=2) da.plot(ax=axes[0, 0]) @@ -766,8 +863,7 @@ read on. .. ipython:: python - a0 = xr.DataArray(np.zeros((4, 3, 2)), dims=('y', 'x', 'z'), - name='temperature') + a0 = xr.DataArray(np.zeros((4, 3, 2)), dims=("y", "x", "z"), name="temperature") a0[0, 0, 0] = 1 a = a0.isel(z=0) a @@ -779,6 +875,7 @@ think carefully about what the limits, labels, and orientation for each of the axes should be. .. ipython:: python + :okwarning: @savefig plotting_example_2d_simple.png width=4in a.plot() @@ -799,16 +896,19 @@ xarray, but you'll have to tell the plot function to use these coordinates instead of the default ones: .. ipython:: python + :okwarning: lon, lat = np.meshgrid(np.linspace(-20, 20, 5), np.linspace(0, 30, 4)) - lon += lat/10 - lat += lon/10 - da = xr.DataArray(np.arange(20).reshape(4, 5), dims=['y', 'x'], - coords = {'lat': (('y', 'x'), lat), - 'lon': (('y', 'x'), lon)}) + lon += lat / 10 + lat += lon / 10 + da = xr.DataArray( + np.arange(20).reshape(4, 5), + dims=["y", "x"], + coords={"lat": (("y", "x"), lat), "lon": (("y", "x"), lon)}, + ) @savefig plotting_example_2d_irreg.png width=4in - da.plot.pcolormesh('lon', 'lat'); + da.plot.pcolormesh("lon", "lat") Note that in this case, xarray still follows the pixel centered convention. This might be undesirable in some cases, for example when your data is defined @@ -816,24 +916,29 @@ on a polar projection (:issue:`781`). This is why the default is to not follow this convention when plotting on a map: .. ipython:: python + :okwarning: import cartopy.crs as ccrs - ax = plt.subplot(projection=ccrs.PlateCarree()); - da.plot.pcolormesh('lon', 'lat', ax=ax); - ax.scatter(lon, lat, transform=ccrs.PlateCarree()); + + ax = plt.subplot(projection=ccrs.PlateCarree()) + da.plot.pcolormesh("lon", "lat", ax=ax) + ax.scatter(lon, lat, transform=ccrs.PlateCarree()) + ax.coastlines() @savefig plotting_example_2d_irreg_map.png width=4in - ax.coastlines(); ax.gridlines(draw_labels=True); + ax.gridlines(draw_labels=True) You can however decide to infer the cell boundaries and use the ``infer_intervals`` keyword: .. ipython:: python + :okwarning: - ax = plt.subplot(projection=ccrs.PlateCarree()); - da.plot.pcolormesh('lon', 'lat', ax=ax, infer_intervals=True); - ax.scatter(lon, lat, transform=ccrs.PlateCarree()); + ax = plt.subplot(projection=ccrs.PlateCarree()) + da.plot.pcolormesh("lon", "lat", ax=ax, infer_intervals=True) + ax.scatter(lon, lat, transform=ccrs.PlateCarree()) + ax.coastlines() @savefig plotting_example_2d_irreg_map_infer.png width=4in - ax.coastlines(); ax.gridlines(draw_labels=True); + ax.gridlines(draw_labels=True) .. 
note:: The data model of xarray does not support datasets with `cell boundaries`_ @@ -845,8 +950,9 @@ You can however decide to infer the cell boundaries and use the One can also make line plots with multidimensional coordinates. In this case, ``hue`` must be a dimension name, not a coordinate name. .. ipython:: python + :okwarning: f, ax = plt.subplots(2, 1) - da.plot.line(x='lon', hue='y', ax=ax[0]); + da.plot.line(x="lon", hue="y", ax=ax[0]) @savefig plotting_example_2d_hue_xy.png - da.plot.line(x='lon', hue='x', ax=ax[1]); + da.plot.line(x="lon", hue="x", ax=ax[1]) diff --git a/doc/quick-overview.rst b/doc/quick-overview.rst index 741b3d1a5fe..1a2bc809550 100644 --- a/doc/quick-overview.rst +++ b/doc/quick-overview.rst @@ -22,16 +22,14 @@ array or list, with optional *dimensions* and *coordinates*: .. ipython:: python - data = xr.DataArray(np.random.randn(2, 3), - dims=('x', 'y'), - coords={'x': [10, 20]}) + data = xr.DataArray(np.random.randn(2, 3), dims=("x", "y"), coords={"x": [10, 20]}) data In this case, we have generated a 2D array, assigned the names *x* and *y* to the two dimensions respectively and associated two *coordinate labels* '10' and '20' with the two locations along the x dimension. If you supply a pandas :py:class:`~pandas.Series` or :py:class:`~pandas.DataFrame`, metadata is copied directly: .. ipython:: python - xr.DataArray(pd.Series(range(3), index=list('abc'), name='foo')) + xr.DataArray(pd.Series(range(3), index=list("abc"), name="foo")) Here are the key properties for a ``DataArray``: @@ -48,7 +46,7 @@ Here are the key properties for a ``DataArray``: Indexing -------- -xarray supports four kind of indexing. Since we have assigned coordinate labels to the x dimension we can use label-based indexing along that dimension just like pandas. The four examples below all yield the same result (the value at `x=10`) but at varying levels of convenience and intuitiveness. +xarray supports four kinds of indexing. Since we have assigned coordinate labels to the x dimension we can use label-based indexing along that dimension just like pandas. The four examples below all yield the same result (the value at `x=10`) but at varying levels of convenience and intuitiveness. .. ipython:: python @@ -75,13 +73,13 @@ While you're setting up your DataArray, it's often a good idea to set metadata a .. ipython:: python - data.attrs['long_name'] = 'random velocity' - data.attrs['units'] = 'metres/sec' - data.attrs['description'] = 'A random variable created as an example.' - data.attrs['random_attribute'] = 123 + data.attrs["long_name"] = "random velocity" + data.attrs["units"] = "metres/sec" + data.attrs["description"] = "A random variable created as an example." + data.attrs["random_attribute"] = 123 data.attrs # you can add metadata to coordinates too - data.x.attrs['units'] = 'x units' + data.x.attrs["units"] = "x units" Computation @@ -102,15 +100,15 @@ numbers: .. ipython:: python - data.mean(dim='x') + data.mean(dim="x") Arithmetic operations broadcast based on dimension name. This means you don't need to insert dummy dimensions for alignment: .. ipython:: python - a = xr.DataArray(np.random.randn(3), [data.coords['y']]) - b = xr.DataArray(np.random.randn(4), dims='z') + a = xr.DataArray(np.random.randn(3), [data.coords["y"]]) + b = xr.DataArray(np.random.randn(4), dims="z") a b @@ -139,9 +137,9 @@ xarray supports grouped operations using a very similar API to pandas (see :ref: .. 
ipython:: python - labels = xr.DataArray(['E', 'F', 'E'], [data.coords['y']], name='labels') + labels = xr.DataArray(["E", "F", "E"], [data.coords["y"]], name="labels") labels - data.groupby(labels).mean('y') + data.groupby(labels).mean("y") data.groupby(labels).map(lambda x: x - x.min()) Plotting @@ -155,7 +153,7 @@ Visualizing your datasets is quick and convenient: data.plot() Note the automatic labeling with names and units. Our effort in adding metadata attributes has paid off! Many aspects of these figures are customizable: see :ref:`plotting`. - + pandas ------ @@ -178,7 +176,7 @@ objects. You can think of it as a multi-dimensional generalization of the .. ipython:: python - ds = xr.Dataset({'foo': data, 'bar': ('x', [1, 2]), 'baz': np.pi}) + ds = xr.Dataset({"foo": data, "bar": ("x", [1, 2]), "baz": np.pi}) ds @@ -186,7 +184,7 @@ This creates a dataset with three DataArrays named ``foo``, ``bar`` and ``baz``. .. ipython:: python - ds['foo'] + ds["foo"] ds.foo @@ -216,14 +214,15 @@ You can directly read and write xarray objects to disk using :py:meth:`~xarray.D .. ipython:: python - ds.to_netcdf('example.nc') - xr.open_dataset('example.nc') + ds.to_netcdf("example.nc") + xr.open_dataset("example.nc") .. ipython:: python - :suppress: + :suppress: import os - os.remove('example.nc') + + os.remove("example.nc") It is common for datasets to be distributed across multiple files (commonly one file per timestep). xarray supports this use-case by providing the :py:meth:`~xarray.open_mfdataset` and the :py:meth:`~xarray.save_mfdataset` methods. For more, see :ref:`io`. diff --git a/doc/related-projects.rst b/doc/related-projects.rst index 57b8da0c447..456cb64197f 100644 --- a/doc/related-projects.rst +++ b/doc/related-projects.rst @@ -3,9 +3,11 @@ Xarray related projects ----------------------- -Here below is a list of existing open source projects that build +Below is a list of existing open source projects that build functionality upon xarray. See also section :ref:`internals` for more -details on how to build xarray extensions. +details on how to build xarray extensions. We also maintain the +`xarray-contrib `_ GitHub organization +as a place to curate projects that build upon xarray. Geosciences ~~~~~~~~~~~ @@ -36,10 +38,11 @@ Geosciences harmonic wind analysis in Python. - `wrf-python `_: A collection of diagnostic and interpolation routines for use with output of the Weather Research and Forecasting (WRF-ARW) Model. - `xarray-simlab `_: xarray extension for computer model simulations. +- `xarray-spatial `_: Numba-accelerated raster-based spatial processing tools (NDVI, curvature, zonal-statistics, proximity, hillshading, viewshed, etc.) - `xarray-topo `_: xarray extension for topographic analysis and modelling. - `xbpch `_: xarray interface for bpch files. - `xclim `_: A library for calculating climate science indices with unit handling built from xarray and dask. -- `xESMF `_: Universal Regridder for Geospatial Data. +- `xESMF `_: Universal regridder for geospatial data. - `xgcm `_: Extends the xarray data model to understand finite volume grid cells (common in General Circulation Models) and provides interpolation and difference operations for such grids. - `xmitgcm `_: a python package for reading `MITgcm `_ binary MDS files into xarray data structures. - `xshape `_: Tools for working with shapefiles, topographies, and polygons in xarray. 
@@ -55,6 +58,7 @@ Other domains ~~~~~~~~~~~~~ - `ptsa `_: EEG Time Series Analysis - `pycalphad `_: Computational Thermodynamics in Python +- `pyomeca `_: Python framework for biomechanical analysis Extend xarray capabilities ~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -62,19 +66,23 @@ Extend xarray capabilities - `eofs `_: EOF analysis in Python. - `hypothesis-gufunc `_: Extension to hypothesis. Makes it easy to write unit tests with xarray objects as input. - `nxarray `_: NeXus input/output capability for xarray. +- `xarray-compare `_: xarray extension for data comparison. +- `xarray-custom `_: Data classes for custom xarray creation. - `xarray_extras `_: Advanced algorithms for xarray objects (e.g. integrations/interpolations). - `xpublish `_: Publish Xarray Datasets via a Zarr compatible REST API. - `xrft `_: Fourier transforms for xarray data. - `xr-scipy `_: A lightweight scipy wrapper for xarray. - `X-regression `_: Multiple linear regression from Statsmodels library coupled with Xarray library. -- `xskillscore `_: Metrics for verifying forecasts. +- `xskillscore `_: Metrics for verifying forecasts. - `xyzpy `_: Easily generate high dimensional data, including parallelization. Visualization ~~~~~~~~~~~~~ -- `Datashader `_, `geoviews `_, `holoviews `_, : visualization packages for large data. +- `datashader `_, `geoviews `_, `holoviews `_, : visualization packages for large data. - `hvplot `_ : A high-level plotting API for the PyData ecosystem built on HoloViews. - `psyplot `_: Interactive data visualization with python. +- `xarray-leaflet `_: An xarray extension for tiled map plotting based on ipyleaflet. +- `xtrude `_: An xarray extension for 3D terrain visualization based on pydeck. Non-Python projects ~~~~~~~~~~~~~~~~~~~ diff --git a/doc/reshaping.rst b/doc/reshaping.rst index 465ca14dfc2..81fd4a6d35e 100644 --- a/doc/reshaping.rst +++ b/doc/reshaping.rst @@ -7,25 +7,26 @@ Reshaping and reorganizing data These methods allow you to reorganize .. ipython:: python - :suppress: + :suppress: import numpy as np import pandas as pd import xarray as xr + np.random.seed(123456) Reordering dimensions --------------------- To reorder dimensions on a :py:class:`~xarray.DataArray` or across all variables -on a :py:class:`~xarray.Dataset`, use :py:meth:`~xarray.DataArray.transpose`. An +on a :py:class:`~xarray.Dataset`, use :py:meth:`~xarray.DataArray.transpose`. An ellipsis (`...`) can be use to represent all other dimensions: .. ipython:: python - ds = xr.Dataset({'foo': (('x', 'y', 'z'), [[[42]]]), 'bar': (('y', 'z'), [[24]])}) - ds.transpose('y', 'z', 'x') - ds.transpose(..., 'x') # equivalent + ds = xr.Dataset({"foo": (("x", "y", "z"), [[[42]]]), "bar": (("y", "z"), [[24]])}) + ds.transpose("y", "z", "x") + ds.transpose(..., "x") # equivalent ds.transpose() # reverses all dimensions Expand and squeeze dimensions @@ -37,7 +38,7 @@ use :py:meth:`~xarray.DataArray.expand_dims` .. ipython:: python - expanded = ds.expand_dims('w') + expanded = ds.expand_dims("w") expanded This method attaches a new dimension with size 1 to all data variables. @@ -48,7 +49,7 @@ use :py:meth:`~xarray.DataArray.squeeze` .. ipython:: python - expanded.squeeze('w') + expanded.squeeze("w") Converting between datasets and arrays -------------------------------------- @@ -69,14 +70,14 @@ To convert back from a DataArray to a Dataset, use .. 
ipython:: python - arr.to_dataset(dim='variable') + arr.to_dataset(dim="variable") The broadcasting behavior of ``to_array`` means that the resulting array includes the union of data variable dimensions: .. ipython:: python - ds2 = xr.Dataset({'a': 0, 'b': ('x', [3, 4, 5])}) + ds2 = xr.Dataset({"a": 0, "b": ("x", [3, 4, 5])}) # the input dataset has 4 elements ds2 @@ -90,7 +91,7 @@ If you use ``to_dataset`` without supplying the ``dim`` argument, the DataArray .. ipython:: python - arr.to_dataset(name='combined') + arr.to_dataset(name="combined") .. _reshape.stack: @@ -103,11 +104,12 @@ implemented :py:meth:`~xarray.DataArray.stack` and .. ipython:: python - array = xr.DataArray(np.random.randn(2, 3), - coords=[('x', ['a', 'b']), ('y', [0, 1, 2])]) - stacked = array.stack(z=('x', 'y')) + array = xr.DataArray( + np.random.randn(2, 3), coords=[("x", ["a", "b"]), ("y", [0, 1, 2])] + ) + stacked = array.stack(z=("x", "y")) stacked - stacked.unstack('z') + stacked.unstack("z") As elsewhere in xarray, an ellipsis (`...`) can be used to represent all unlisted dimensions: @@ -128,15 +130,15 @@ possible levels. Missing levels are filled in with ``NaN`` in the resulting obje stacked2 = stacked[::2] stacked2 - stacked2.unstack('z') + stacked2.unstack("z") However, xarray's ``stack`` has an important difference from pandas: unlike pandas, it does not automatically drop missing values. Compare: .. ipython:: python - array = xr.DataArray([[np.nan, 1], [2, 3]], dims=['x', 'y']) - array.stack(z=('x', 'y')) + array = xr.DataArray([[np.nan, 1], [2, 3]], dims=["x", "y"]) + array.stack(z=("x", "y")) array.to_pandas().stack() We departed from pandas's behavior here because predictable shapes for new @@ -166,16 +168,15 @@ like this: .. ipython:: python - data = xr.Dataset( - data_vars={'a': (('x', 'y'), [[0, 1, 2], [3, 4, 5]]), - 'b': ('x', [6, 7])}, - coords={'y': ['u', 'v', 'w']} - ) - data - stacked = data.to_stacked_array("z", sample_dims=['x']) - stacked - unstacked = stacked.to_unstacked_dataset("z") - unstacked + data = xr.Dataset( + data_vars={"a": (("x", "y"), [[0, 1, 2], [3, 4, 5]]), "b": ("x", [6, 7])}, + coords={"y": ["u", "v", "w"]}, + ) + data + stacked = data.to_stacked_array("z", sample_dims=["x"]) + stacked + unstacked = stacked.to_unstacked_dataset("z") + unstacked In this example, ``stacked`` is a two dimensional array that we can easily pass to a scikit-learn or another generic numerical method. @@ -202,19 +203,23 @@ coordinates using :py:meth:`~xarray.DataArray.set_index`: .. ipython:: python - da = xr.DataArray(np.random.rand(4), - coords={'band': ('x', ['a', 'a', 'b', 'b']), - 'wavenumber': ('x', np.linspace(200, 400, 4))}, - dims='x') - da - mda = da.set_index(x=['band', 'wavenumber']) - mda + da = xr.DataArray( + np.random.rand(4), + coords={ + "band": ("x", ["a", "a", "b", "b"]), + "wavenumber": ("x", np.linspace(200, 400, 4)), + }, + dims="x", + ) + da + mda = da.set_index(x=["band", "wavenumber"]) + mda These coordinates can now be used for indexing, e.g., .. ipython:: python - mda.sel(band='a') + mda.sel(band="a") Conversely, you can use :py:meth:`~xarray.DataArray.reset_index` to extract multi-index levels as coordinates (this is mainly useful @@ -222,27 +227,27 @@ for serialization): .. ipython:: python - mda.reset_index('x') + mda.reset_index("x") :py:meth:`~xarray.DataArray.reorder_levels` allows changing the order of multi-index levels: .. 
ipython:: python - mda.reorder_levels(x=['wavenumber', 'band']) + mda.reorder_levels(x=["wavenumber", "band"]) As of xarray v0.9 coordinate labels for each dimension are optional. -You can also use ``.set_index`` / ``.reset_index`` to add / remove +You can also use ``.set_index`` / ``.reset_index`` to add / remove labels for one or several dimensions: .. ipython:: python - array = xr.DataArray([1, 2, 3], dims='x') + array = xr.DataArray([1, 2, 3], dims="x") array - array['c'] = ('x', ['a', 'b', 'c']) - array.set_index(x='c') - array = array.set_index(x='c') - array = array.reset_index('x', drop=True) + array["c"] = ("x", ["a", "b", "c"]) + array.set_index(x="c") + array = array.set_index(x="c") + array = array.reset_index("x", drop=True) .. _reshape.shift_and_roll: @@ -254,9 +259,9 @@ To adjust coordinate labels, you can use the :py:meth:`~xarray.Dataset.shift` an .. ipython:: python - array = xr.DataArray([1, 2, 3, 4], dims='x') - array.shift(x=2) - array.roll(x=2, roll_coords=True) + array = xr.DataArray([1, 2, 3, 4], dims="x") + array.shift(x=2) + array.roll(x=2, roll_coords=True) .. _reshape.sort: @@ -269,17 +274,18 @@ One may sort a DataArray/Dataset via :py:meth:`~xarray.DataArray.sortby` and .. ipython:: python - ds = xr.Dataset({'A': (('x', 'y'), [[1, 2], [3, 4]]), - 'B': (('x', 'y'), [[5, 6], [7, 8]])}, - coords={'x': ['b', 'a'], 'y': [1, 0]}) - dax = xr.DataArray([100, 99], [('x', [0, 1])]) - day = xr.DataArray([90, 80], [('y', [0, 1])]) - ds.sortby([day, dax]) + ds = xr.Dataset( + {"A": (("x", "y"), [[1, 2], [3, 4]]), "B": (("x", "y"), [[5, 6], [7, 8]])}, + coords={"x": ["b", "a"], "y": [1, 0]}, + ) + dax = xr.DataArray([100, 99], [("x", [0, 1])]) + day = xr.DataArray([90, 80], [("y", [0, 1])]) + ds.sortby([day, dax]) As a shortcut, you can refer to existing coordinates by name: .. ipython:: python - ds.sortby('x') - ds.sortby(['y', 'x']) - ds.sortby(['y', 'x'], ascending=False) + ds.sortby("x") + ds.sortby(["y", "x"]) + ds.sortby(["y", "x"], ascending=False) diff --git a/doc/roadmap.rst b/doc/roadmap.rst index 401dac779ad..1cbbaf8ef42 100644 --- a/doc/roadmap.rst +++ b/doc/roadmap.rst @@ -224,6 +224,8 @@ Current core developers - Tom Nicholas - Guido Imperiale - Justus Magin +- Mathias Hauser +- Anderson Banihirwe NumFOCUS ~~~~~~~~ diff --git a/doc/terminology.rst b/doc/terminology.rst index ab6d856920a..3cfc211593f 100644 --- a/doc/terminology.rst +++ b/doc/terminology.rst @@ -4,40 +4,111 @@ Terminology =========== -*Xarray terminology differs slightly from CF, mathematical conventions, and pandas; and therefore using xarray, understanding the documentation, and parsing error messages is easier once key terminology is defined. This glossary was designed so that more fundamental concepts come first. Thus for new users, this page is best read top-to-bottom. Throughout the glossary,* ``arr`` *will refer to an xarray* :py:class:`DataArray` *in any small examples. For more complete examples, please consult the relevant documentation.* - ----- - -**DataArray:** A multi-dimensional array with labeled or named dimensions. ``DataArray`` objects add metadata such as dimension names, coordinates, and attributes (defined below) to underlying "unlabeled" data structures such as numpy and Dask arrays. If its optional ``name`` property is set, it is a *named DataArray*. - ----- - -**Dataset:** A dict-like collection of ``DataArray`` objects with aligned dimensions. Thus, most operations that can be performed on the dimensions of a single ``DataArray`` can be performed on a dataset. 
Datasets have data variables (see **Variable** below), dimensions, coordinates, and attributes. - ----- - -**Variable:** A `NetCDF-like variable `_ consisting of dimensions, data, and attributes which describe a single array. The main functional difference between variables and numpy arrays is that numerical operations on variables implement array broadcasting by dimension name. Each ``DataArray`` has an underlying variable that can be accessed via ``arr.variable``. However, a variable is not fully described outside of either a ``Dataset`` or a ``DataArray``. - -.. note:: - - The :py:class:`Variable` class is low-level interface and can typically be ignored. However, the word "variable" appears often enough in the code and documentation that is useful to understand. - ----- - -**Dimension:** In mathematics, the *dimension* of data is loosely the number of degrees of freedom for it. A *dimension axis* is a set of all points in which all but one of these degrees of freedom is fixed. We can think of each dimension axis as having a name, for example the "x dimension". In xarray, a ``DataArray`` object's *dimensions* are its named dimension axes, and the name of the ``i``-th dimension is ``arr.dims[i]``. If an array is created without dimensions, the default dimension names are ``dim_0``, ``dim_1``, and so forth. - ----- - -**Coordinate:** An array that labels a dimension or set of dimensions of another ``DataArray``. In the usual one-dimensional case, the coordinate array's values can loosely be thought of as tick labels along a dimension. There are two types of coordinate arrays: *dimension coordinates* and *non-dimension coordinates* (see below). A coordinate named ``x`` can be retrieved from ``arr.coords[x]``. A ``DataArray`` can have more coordinates than dimensions because a single dimension can be labeled by multiple coordinate arrays. However, only one coordinate array can be a assigned as a particular dimension's dimension coordinate array. As a consequence, ``len(arr.dims) <= len(arr.coords)`` in general. - ----- - -**Dimension coordinate:** A one-dimensional coordinate array assigned to ``arr`` with both a name and dimension name in ``arr.dims``. Dimension coordinates are used for label-based indexing and alignment, like the index found on a :py:class:`pandas.DataFrame` or :py:class:`pandas.Series`. In fact, dimension coordinates use :py:class:`pandas.Index` objects under the hood for efficient computation. Dimension coordinates are marked by ``*`` when printing a ``DataArray`` or ``Dataset``. - ----- - -**Non-dimension coordinate:** A coordinate array assigned to ``arr`` with a name in ``arr.coords`` but *not* in ``arr.dims``. These coordinates arrays can be one-dimensional or multidimensional, and they are useful for auxiliary labeling. As an example, multidimensional coordinates are often used in geoscience datasets when :doc:`the data's physical coordinates (such as latitude and longitude) differ from their logical coordinates `. However, non-dimension coordinates are not indexed, and any operation on non-dimension coordinates that leverages indexing will fail. Printing ``arr.coords`` will print all of ``arr``'s coordinate names, with the corresponding dimension(s) in parentheses. For example, ``coord_name (dim_name) 1 2 3 ...``. - ----- - -**Index:** An *index* is a data structure optimized for efficient selecting and slicing of an associated array. 
Xarray creates indexes for dimension coordinates so that operations along dimensions are fast, while non-dimension coordinates are not indexed. Under the hood, indexes are implemented as :py:class:`pandas.Index` objects. The index associated with dimension name ``x`` can be retrieved by ``arr.indexes[x]``. By construction, ``len(arr.dims) == len(arr.indexes)`` +*Xarray terminology differs slightly from CF, mathematical conventions, and +pandas; so we've put together a glossary of its terms. Here,* ``arr`` * +refers to an xarray* :py:class:`DataArray` *in the examples. For more +complete examples, please consult the relevant documentation.* + +.. glossary:: + + DataArray + A multi-dimensional array with labeled or named + dimensions. ``DataArray`` objects add metadata such as dimension names, + coordinates, and attributes (defined below) to underlying "unlabeled" + data structures such as numpy and Dask arrays. If its optional ``name`` + property is set, it is a *named DataArray*. + + Dataset + A dict-like collection of ``DataArray`` objects with aligned + dimensions. Thus, most operations that can be performed on the + dimensions of a single ``DataArray`` can be performed on a + dataset. Datasets have data variables (see **Variable** below), + dimensions, coordinates, and attributes. + + Variable + A `NetCDF-like variable + `_ + consisting of dimensions, data, and attributes which describe a single + array. The main functional difference between variables and numpy arrays + is that numerical operations on variables implement array broadcasting + by dimension name. Each ``DataArray`` has an underlying variable that + can be accessed via ``arr.variable``. However, a variable is not fully + described outside of either a ``Dataset`` or a ``DataArray``. + + .. note:: + + The :py:class:`Variable` class is low-level interface and can + typically be ignored. However, the word "variable" appears often + enough in the code and documentation that is useful to understand. + + Dimension + In mathematics, the *dimension* of data is loosely the number of degrees + of freedom for it. A *dimension axis* is a set of all points in which + all but one of these degrees of freedom is fixed. We can think of each + dimension axis as having a name, for example the "x dimension". In + xarray, a ``DataArray`` object's *dimensions* are its named dimension + axes, and the name of the ``i``-th dimension is ``arr.dims[i]``. If an + array is created without dimension names, the default dimension names are + ``dim_0``, ``dim_1``, and so forth. + + Coordinate + An array that labels a dimension or set of dimensions of another + ``DataArray``. In the usual one-dimensional case, the coordinate array's + values can loosely be thought of as tick labels along a dimension. There + are two types of coordinate arrays: *dimension coordinates* and + *non-dimension coordinates* (see below). A coordinate named ``x`` can be + retrieved from ``arr.coords[x]``. A ``DataArray`` can have more + coordinates than dimensions because a single dimension can be labeled by + multiple coordinate arrays. However, only one coordinate array can be a + assigned as a particular dimension's dimension coordinate array. As a + consequence, ``len(arr.dims) <= len(arr.coords)`` in general. + + Dimension coordinate + A one-dimensional coordinate array assigned to ``arr`` with both a name + and dimension name in ``arr.dims``. 
Dimension coordinates are used for + label-based indexing and alignment, like the index found on a + :py:class:`pandas.DataFrame` or :py:class:`pandas.Series`. In fact, + dimension coordinates use :py:class:`pandas.Index` objects under the + hood for efficient computation. Dimension coordinates are marked by + ``*`` when printing a ``DataArray`` or ``Dataset``. + + Non-dimension coordinate + A coordinate array assigned to ``arr`` with a name in ``arr.coords`` but + *not* in ``arr.dims``. These coordinates arrays can be one-dimensional + or multidimensional, and they are useful for auxiliary labeling. As an + example, multidimensional coordinates are often used in geoscience + datasets when :doc:`the data's physical coordinates (such as latitude + and longitude) differ from their logical coordinates + `. However, non-dimension coordinates + are not indexed, and any operation on non-dimension coordinates that + leverages indexing will fail. Printing ``arr.coords`` will print all of + ``arr``'s coordinate names, with the corresponding dimension(s) in + parentheses. For example, ``coord_name (dim_name) 1 2 3 ...``. + + Index + An *index* is a data structure optimized for efficient selecting and + slicing of an associated array. Xarray creates indexes for dimension + coordinates so that operations along dimensions are fast, while + non-dimension coordinates are not indexed. Under the hood, indexes are + implemented as :py:class:`pandas.Index` objects. The index associated + with dimension name ``x`` can be retrieved by ``arr.indexes[x]``. By + construction, ``len(arr.dims) == len(arr.indexes)`` + + name + The names of dimensions, coordinates, DataArray objects and data + variables can be anything as long as they are :term:`hashable`. However, + it is preferred to use :py:class:`str` typed names. + + scalar + By definition, a scalar is not an :term:`array` and when converted to + one, it has 0 dimensions. That means that, e.g., :py:class:`int`, + :py:class:`float`, and :py:class:`str` objects are "scalar" while + :py:class:`list` or :py:class:`tuple` are not. + + duck array + `Duck arrays`__ are array implementations that behave + like numpy arrays. They have to define the ``shape``, ``dtype`` and + ``ndim`` properties. For integration with ``xarray``, the ``__array__``, + ``__array_ufunc__`` and ``__array_function__`` protocols are also required. + + __ https://numpy.org/neps/nep-0022-ndarray-duck-typing-overview.html diff --git a/doc/time-series.rst b/doc/time-series.rst index d838dbbd4cd..96a2edc0ea5 100644 --- a/doc/time-series.rst +++ b/doc/time-series.rst @@ -10,11 +10,12 @@ data in pandas such a joy to xarray. In most cases, we rely on pandas for the core functionality. .. ipython:: python - :suppress: + :suppress: import numpy as np import pandas as pd import xarray as xr + np.random.seed(123456) Creating datetime64 data @@ -29,8 +30,8 @@ using :py:func:`pandas.to_datetime` and :py:func:`pandas.date_range`: .. ipython:: python - pd.to_datetime(['2000-01-01', '2000-02-02']) - pd.date_range('2000-01-01', periods=365) + pd.to_datetime(["2000-01-01", "2000-02-02"]) + pd.date_range("2000-01-01", periods=365) Alternatively, you can supply arrays of Python ``datetime`` objects. These get converted automatically when used as arguments in xarray objects: @@ -38,7 +39,8 @@ converted automatically when used as arguments in xarray objects: .. 
ipython:: python import datetime - xr.Dataset({'time': datetime.datetime(2000, 1, 1)}) + + xr.Dataset({"time": datetime.datetime(2000, 1, 1)}) When reading or writing netCDF files, xarray automatically decodes datetime and timedelta arrays using `CF conventions`_ (that is, by using a ``units`` @@ -62,8 +64,8 @@ You can manual decode arrays in this form by passing a dataset to .. ipython:: python - attrs = {'units': 'hours since 2000-01-01'} - ds = xr.Dataset({'time': ('time', [0, 1, 2, 3], attrs)}) + attrs = {"units": "hours since 2000-01-01"} + ds = xr.Dataset({"time": ("time", [0, 1, 2, 3], attrs)}) xr.decode_cf(ds) One unfortunate limitation of using ``datetime64[ns]`` is that it limits the @@ -87,10 +89,10 @@ items and with the `slice` object: .. ipython:: python - time = pd.date_range('2000-01-01', freq='H', periods=365 * 24) - ds = xr.Dataset({'foo': ('time', np.arange(365 * 24)), 'time': time}) - ds.sel(time='2000-01') - ds.sel(time=slice('2000-06-01', '2000-06-10')) + time = pd.date_range("2000-01-01", freq="H", periods=365 * 24) + ds = xr.Dataset({"foo": ("time", np.arange(365 * 24)), "time": time}) + ds.sel(time="2000-01") + ds.sel(time=slice("2000-06-01", "2000-06-10")) You can also select a particular time by indexing with a :py:class:`datetime.time` object: @@ -113,8 +115,8 @@ given ``DataArray`` can be quickly computed using a special ``.dt`` accessor. .. ipython:: python - time = pd.date_range('2000-01-01', freq='6H', periods=365 * 4) - ds = xr.Dataset({'foo': ('time', np.arange(365 * 4)), 'time': time}) + time = pd.date_range("2000-01-01", freq="6H", periods=365 * 4) + ds = xr.Dataset({"foo": ("time", np.arange(365 * 4)), "time": time}) ds.time.dt.hour ds.time.dt.dayofweek @@ -130,16 +132,16 @@ __ http://pandas.pydata.org/pandas-docs/stable/api.html#time-date-components .. ipython:: python - ds['time.month'] - ds['time.dayofyear'] + ds["time.month"] + ds["time.dayofyear"] For use as a derived coordinate, xarray adds ``'season'`` to the list of datetime components supported by pandas: .. ipython:: python - ds['time.season'] - ds['time'].dt.season + ds["time.season"] + ds["time"].dt.season The set of valid seasons consists of 'DJF', 'MAM', 'JJA' and 'SON', labeled by the first letters of the corresponding months. @@ -152,7 +154,7 @@ __ http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases .. ipython:: python - ds['time'].dt.floor('D') + ds["time"].dt.floor("D") The ``.dt`` accessor can also be used to generate formatted datetime strings for arrays utilising the same formatting as the standard `datetime.strftime`_. @@ -161,7 +163,7 @@ for arrays utilising the same formatting as the standard `datetime.strftime`_. .. ipython:: python - ds['time'].dt.strftime('%a, %b %d %H:%M') + ds["time"].dt.strftime("%a, %b %d %H:%M") .. _resampling: @@ -173,9 +175,9 @@ Datetime components couple particularly well with grouped operations (see calculate the mean by time of day: .. ipython:: python - :okwarning: + :okwarning: - ds.groupby('time.hour').mean() + ds.groupby("time.hour").mean() For upsampling or downsampling temporal resolutions, xarray offers a :py:meth:`~xarray.Dataset.resample` method building on the core functionality @@ -187,25 +189,25 @@ same api as ``resample`` `in pandas`_. For example, we can downsample our dataset from hourly to 6-hourly: .. ipython:: python - :okwarning: + :okwarning: - ds.resample(time='6H') + ds.resample(time="6H") This will create a specialized ``Resample`` object which saves information necessary for resampling. 
All of the reduction methods which work with ``Resample`` objects can also be used for resampling: .. ipython:: python - :okwarning: + :okwarning: - ds.resample(time='6H').mean() + ds.resample(time="6H").mean() You can also supply an arbitrary reduction function to aggregate over each resampling group: .. ipython:: python - ds.resample(time='6H').reduce(np.mean) + ds.resample(time="6H").reduce(np.mean) For upsampling, xarray provides six methods: ``asfreq``, ``ffill``, ``bfill``, ``pad``, ``nearest`` and ``interpolate``. ``interpolate`` extends ``scipy.interpolate.interp1d`` @@ -218,7 +220,7 @@ Data that has indices outside of the given ``tolerance`` are set to ``NaN``. .. ipython:: python - ds.resample(time='1H').nearest(tolerance='1H') + ds.resample(time="1H").nearest(tolerance="1H") For more examples of using grouped operations on a time dimension, see diff --git a/doc/weather-climate.rst b/doc/weather-climate.rst index 768cf6556f9..db612d74859 100644 --- a/doc/weather-climate.rst +++ b/doc/weather-climate.rst @@ -4,7 +4,7 @@ Weather and climate data ======================== .. ipython:: python - :suppress: + :suppress: import xarray as xr @@ -56,11 +56,14 @@ coordinate with dates from a no-leap calendar and a .. ipython:: python - from itertools import product - from cftime import DatetimeNoLeap - dates = [DatetimeNoLeap(year, month, 1) for year, month in - product(range(1, 3), range(1, 13))] - da = xr.DataArray(np.arange(24), coords=[dates], dims=['time'], name='foo') + from itertools import product + from cftime import DatetimeNoLeap + + dates = [ + DatetimeNoLeap(year, month, 1) + for year, month in product(range(1, 3), range(1, 13)) + ] + da = xr.DataArray(np.arange(24), coords=[dates], dims=["time"], name="foo") xarray also includes a :py:func:`~xarray.cftime_range` function, which enables creating a :py:class:`~xarray.CFTimeIndex` with regularly-spaced dates. For @@ -68,30 +71,50 @@ instance, we can create the same dates and DataArray we created above using: .. ipython:: python - dates = xr.cftime_range(start='0001', periods=24, freq='MS', calendar='noleap') - da = xr.DataArray(np.arange(24), coords=[dates], dims=['time'], name='foo') + dates = xr.cftime_range(start="0001", periods=24, freq="MS", calendar="noleap") + da = xr.DataArray(np.arange(24), coords=[dates], dims=["time"], name="foo") + +Mirroring pandas' method with the same name, :py:meth:`~xarray.infer_freq` allows one to +infer the sampling frequency of a :py:class:`~xarray.CFTimeIndex` or a 1-D +:py:class:`~xarray.DataArray` containing cftime objects. It also works transparently with +``np.datetime64[ns]`` and ``np.timedelta64[ns]`` data. + +.. ipython:: python + + xr.infer_freq(dates) With :py:meth:`~xarray.CFTimeIndex.strftime` we can also easily generate formatted strings from the datetime values of a :py:class:`~xarray.CFTimeIndex` directly or through the -:py:meth:`~xarray.DataArray.dt` accessor for a :py:class:`~xarray.DataArray` +``dt`` accessor for a :py:class:`~xarray.DataArray` using the same formatting as the standard `datetime.strftime`_ convention . .. _datetime.strftime: https://docs.python.org/3/library/datetime.html#strftime-strptime-behavior .. ipython:: python - dates.strftime('%c') - da['time'].dt.strftime('%Y%m%d') + dates.strftime("%c") + da["time"].dt.strftime("%Y%m%d") For data indexed by a :py:class:`~xarray.CFTimeIndex` xarray currently supports: -- `Partial datetime string indexing`_ using strictly `ISO 8601-format`_ partial - datetime strings: +- `Partial datetime string indexing`_: .. 
ipython:: python - da.sel(time='0001') - da.sel(time=slice('0001-05', '0002-02')) + da.sel(time="0001") + da.sel(time=slice("0001-05", "0002-02")) + +.. note:: + + + For specifying full or partial datetime strings in cftime + indexing, xarray supports two versions of the `ISO 8601 standard`_, the + basic pattern (YYYYMMDDhhmmss) or the extended pattern + (YYYY-MM-DDThh:mm:ss), as well as the default cftime string format + (YYYY-MM-DD hh:mm:ss). This is somewhat more restrictive than pandas; + in other words, some datetime strings that would be valid for a + :py:class:`pandas.DatetimeIndex` are not valid for an + :py:class:`~xarray.CFTimeIndex`. - Access of basic datetime components via the ``dt`` accessor (in this case just "year", "month", "day", "hour", "minute", "second", "microsecond", @@ -99,64 +122,65 @@ For data indexed by a :py:class:`~xarray.CFTimeIndex` xarray currently supports: .. ipython:: python - da.time.dt.year - da.time.dt.month - da.time.dt.season - da.time.dt.dayofyear - da.time.dt.dayofweek - da.time.dt.days_in_month + da.time.dt.year + da.time.dt.month + da.time.dt.season + da.time.dt.dayofyear + da.time.dt.dayofweek + da.time.dt.days_in_month - Rounding of datetimes to fixed frequencies via the ``dt`` accessor: .. ipython:: python - da.time.dt.ceil('3D') - da.time.dt.floor('5D') - da.time.dt.round('2D') - + da.time.dt.ceil("3D") + da.time.dt.floor("5D") + da.time.dt.round("2D") + - Group-by operations based on datetime accessor attributes (e.g. by month of the year): .. ipython:: python - da.groupby('time.month').sum() + da.groupby("time.month").sum() - Interpolation using :py:class:`cftime.datetime` objects: .. ipython:: python - da.interp(time=[DatetimeNoLeap(1, 1, 15), DatetimeNoLeap(1, 2, 15)]) + da.interp(time=[DatetimeNoLeap(1, 1, 15), DatetimeNoLeap(1, 2, 15)]) - Interpolation using datetime strings: .. ipython:: python - da.interp(time=['0001-01-15', '0001-02-15']) + da.interp(time=["0001-01-15", "0001-02-15"]) - Differentiation: .. ipython:: python - da.differentiate('time') + da.differentiate("time") - Serialization: .. ipython:: python - da.to_netcdf('example-no-leap.nc') - xr.open_dataset('example-no-leap.nc') + da.to_netcdf("example-no-leap.nc") + xr.open_dataset("example-no-leap.nc") .. ipython:: python :suppress: import os - os.remove('example-no-leap.nc') + + os.remove("example-no-leap.nc") - And resampling along the time dimension for data indexed by a :py:class:`~xarray.CFTimeIndex`: .. ipython:: python - da.resample(time='81T', closed='right', label='right', base=3).mean() + da.resample(time="81T", closed="right", label="right", base=3).mean() .. note:: @@ -168,13 +192,13 @@ For data indexed by a :py:class:`~xarray.CFTimeIndex` xarray currently supports: method: .. ipython:: python - :okwarning: + :okwarning: - modern_times = xr.cftime_range('2000', periods=24, freq='MS', calendar='noleap') - da = xr.DataArray(range(24), [('time', modern_times)]) + modern_times = xr.cftime_range("2000", periods=24, freq="MS", calendar="noleap") + da = xr.DataArray(range(24), [("time", modern_times)]) da - datetimeindex = da.indexes['time'].to_datetimeindex() - da['time'] = datetimeindex + datetimeindex = da.indexes["time"].to_datetimeindex() + da["time"] = datetimeindex However in this case one should use caution to only perform operations which do not depend on differences between dates (e.g. 
differentiation, @@ -182,6 +206,6 @@ For data indexed by a :py:class:`~xarray.CFTimeIndex` xarray currently supports: and silent errors due to the difference in calendar types between the dates encoded in your data and the dates stored in memory. -.. _Timestamp-valid range: https://pandas.pydata.org/pandas-docs/stable/timeseries.html#timestamp-limitations -.. _ISO 8601-format: https://en.wikipedia.org/wiki/ISO_8601 -.. _partial datetime string indexing: https://pandas.pydata.org/pandas-docs/stable/timeseries.html#partial-string-indexing +.. _Timestamp-valid range: https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#timestamp-limitations +.. _ISO 8601 standard: https://en.wikipedia.org/wiki/ISO_8601 +.. _partial datetime string indexing: https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#partial-string-indexing diff --git a/doc/whats-new.rst b/doc/whats-new.rst index c9a2ca2e41c..063e7cd9b64 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -4,28 +4,436 @@ What's New ========== .. ipython:: python - :suppress: + :suppress: import numpy as np import pandas as pd import xarray as xray import xarray import xarray as xr + np.random.seed(123456) + +.. _whats-new.0.16.3: + +v0.16.3 (unreleased) +-------------------- + +Breaking changes +~~~~~~~~~~~~~~~~ +- xarray no longer supports python 3.6 + + The minimum versions of some other dependencies were changed: + ============ ====== ==== + Package Old New + ============ ====== ==== + Python 3.6 3.7 + setuptools 38.4 40.4 + ============ ====== ==== + + (:issue:`4688`, :pull:`4720`) + By `Justus Magin `_. +- As a result of :pull:`4684` the default units encoding for + datetime-like values (``np.datetime64[ns]`` or ``cftime.datetime``) will now + always be set such that ``int64`` values can be used. In the past, no units + finer than "seconds" were chosen, which would sometimes mean that ``float64`` + values were required, which would lead to inaccurate I/O round-trips. +- remove deprecated ``autoclose`` kwargs from :py:func:`open_dataset` (:pull: `4725`). + By `Aureliana Barghini `_ + + +New Features +~~~~~~~~~~~~ +- Performance improvement when constructing DataArrays. Significantly speeds up repr for Datasets with large number of variables. + By `Deepak Cherian `_ + +Bug fixes +~~~~~~~~~ +- :py:meth:`DataArray.resample` and :py:meth:`Dataset.resample` do not trigger computations anymore if :py:meth:`Dataset.weighted` or :py:meth:`DataArray.weighted` are applied (:issue:`4625`, :pull:`4668`). By `Julius Busecke `_. +- :py:func:`merge` with ``combine_attrs='override'`` makes a copy of the attrs (:issue:`4627`). +- By default, when possible, xarray will now always use values of type ``int64`` when encoding + and decoding ``numpy.datetime64[ns]`` datetimes. This ensures that maximum + precision and accuracy are maintained in the round-tripping process + (:issue:`4045`, :pull:`4684`). It also enables encoding and decoding standard calendar + dates with time units of nanoseconds (:pull:`4400`). By `Spencer Clark + `_ and `Mark Harfouche `_. +- :py:meth:`DataArray.astype`, :py:meth:`Dataset.astype` and :py:meth:`Variable.astype` support + the ``order`` and ``subok`` parameters again. This fixes a regression introduced in version 0.16.1 + (:issue:`4644`, :pull:`4683`). + By `Richard Kleijn `_ . +- Remove dictionary unpacking when using ``.loc`` to avoid collision with ``.sel`` parameters (:pull:`4695`). 
+ By `Anderson Banihirwe `_ +- Fix the legend created by :py:meth:`Dataset.plot.scatter` (:issue:`4641`, :pull:`4723`). + By `Justus Magin `_. +- Fix a crash in orthogonal indexing on geographic coordinates with ``engine='cfgrib'`` (:issue:`4733` :pull:`4737`). + By `Alessandro Amici `_ +- Coordinates with dtype ``str`` or ``bytes`` now retain their dtype on many operations, + e.g. ``reindex``, ``align``, ``concat``, ``assign``, previously they were cast to an object dtype + (:issue:`2658` and :issue:`4543`) by `Mathias Hauser `_. +- Limit number of data rows when printing large datasets. (:issue:`4736`, :pull:`4750`). By `Jimmy Westling `_. +- Add ``missing_dims`` parameter to transpose (:issue:`4647`, :pull:`4767`). By `Daniel Mesejo `_. +- Resolve intervals before appending other metadata to labels when plotting (:issue:`4322`, :pull:`4794`). + By `Justus Magin `_. +- Fix regression when decoding a variable with a ``scale_factor`` and ``add_offset`` given + as a list of length one (:issue:`4631`) by `Mathias Hauser `_. +- Expand user directory paths (e.g. ``~/``) in :py:func:`open_mfdataset` and + :py:meth:`Dataset.to_zarr` (:issue:`4783`, :pull:`4795`). + By `Julien Seguinot `_. + +Documentation +~~~~~~~~~~~~~ +- add information about requirements for accessor classes (:issue:`2788`, :pull:`4657`). + By `Justus Magin `_. +- start a list of external I/O integrating with ``xarray`` (:issue:`683`, :pull:`4566`). + By `Justus Magin `_. +- add concat examples and improve combining documentation (:issue:`4620`, :pull:`4645`). + By `Ray Bell `_ and + `Justus Magin `_. + +Internal Changes +~~~~~~~~~~~~~~~~ +- Speed up of the continuous integration tests on azure. + + - Switched to mamba and use matplotlib-base for a faster installation of all dependencies (:pull:`4672`). + - Use ``pytest.mark.skip`` instead of ``pytest.mark.xfail`` for some tests that can currently not + succeed (:pull:`4685`). + - Run the tests in parallel using pytest-xdist (:pull:`4694`). + + By `Justus Magin `_ and `Mathias Hauser `_. + +- Replace all usages of ``assert x.identical(y)`` with ``assert_identical(x, y)`` + for clearer error messages. + (:pull:`4752`); + By `Maximilian Roos `_. +- Speed up attribute style access (e.g. ``ds.somevar`` instead of ``ds["somevar"]``) and tab completion + in ipython (:issue:`4741`, :pull:`4742`). By `Richard Kleijn `_. + +.. _whats-new.0.16.2: + +v0.16.2 (30 Nov 2020) +--------------------- + +This release brings the ability to write to limited regions of ``zarr`` files, open zarr files with :py:func:`open_dataset` and :py:func:`open_mfdataset`, increased support for propagating ``attrs`` using the ``keep_attrs`` flag, as well as numerous bugfixes and documentation improvements. + +Many thanks to the 31 contributors who contributed to this release: +Aaron Spring, Akio Taniguchi, Aleksandar Jelenak, alexamici, Alexandre Poux, Anderson Banihirwe, Andrew Pauling, Ashwin Vishnu, aurghs, Brian Ward, Caleb, crusaderky, Dan Nowacki, darikg, David Brochart, David Huard, Deepak Cherian, Dion Häfner, Gerardo Rivera, Gerrit Holl, Illviljan, inakleinbottle, Jacob Tomlinson, James A. Bednar, jenssss, Joe Hamman, johnomotani, Joris Van den Bossche, Julia Kent, Julius Busecke, Kai Mühlbauer, keewis, Keisuke Fujii, Kyle Cranmer, Luke Volpatti, Mathias Hauser, Maximilian Roos, Michaël Defferrard, Michal Baumgartner, Nick R. 
Papior, Pascal Bourgault, Peter Hausamann, PGijsbers, Ray Bell, Romain Martinez, rpgoldman, Russell Manser, Sahid Velji, Samnan Rahee, Sander, Spencer Clark, Stephan Hoyer, Thomas Zilio, Tobias Kölling, Tom Augspurger, Wei Ji, Yash Saboo, Zeb Nicholls, + +Deprecations +~~~~~~~~~~~~ + +- :py:attr:`~core.accessor_dt.DatetimeAccessor.weekofyear` and :py:attr:`~core.accessor_dt.DatetimeAccessor.week` + have been deprecated. Use ``DataArray.dt.isocalendar().week`` + instead (:pull:`4534`). By `Mathias Hauser `_, + `Maximilian Roos `_, and `Spencer Clark `_. +- :py:attr:`DataArray.rolling` and :py:attr:`Dataset.rolling` no longer support passing ``keep_attrs`` + via its constructor. Pass ``keep_attrs`` via the applied function, i.e. use + ``ds.rolling(...).mean(keep_attrs=False)`` instead of ``ds.rolling(..., keep_attrs=False).mean()`` + Rolling operations now keep their attributes per default (:pull:`4510`). + By `Mathias Hauser `_. + +New Features +~~~~~~~~~~~~ + +- :py:func:`open_dataset` and :py:func:`open_mfdataset` + now works with ``engine="zarr"`` (:issue:`3668`, :pull:`4003`, :pull:`4187`). + By `Miguel Jimenez `_ and `Wei Ji Leong `_. +- Unary & binary operations follow the ``keep_attrs`` flag (:issue:`3490`, :issue:`4065`, :issue:`3433`, :issue:`3595`, :pull:`4195`). + By `Deepak Cherian `_. +- Added :py:meth:`~core.accessor_dt.DatetimeAccessor.isocalendar()` that returns a Dataset + with year, week, and weekday calculated according to the ISO 8601 calendar. Requires + pandas version 1.1.0 or greater (:pull:`4534`). By `Mathias Hauser `_, + `Maximilian Roos `_, and `Spencer Clark `_. +- :py:meth:`Dataset.to_zarr` now supports a ``region`` keyword for writing to + limited regions of existing Zarr stores (:pull:`4035`). + See :ref:`io.zarr.appending` for full details. + By `Stephan Hoyer `_. +- Added typehints in :py:func:`align` to reflect that the same type received in ``objects`` arg will be returned (:pull:`4522`). + By `Michal Baumgartner `_. +- :py:meth:`Dataset.weighted` and :py:meth:`DataArray.weighted` are now executing value checks lazily if weights are provided as dask arrays (:issue:`4541`, :pull:`4559`). + By `Julius Busecke `_. +- Added the ``keep_attrs`` keyword to ``rolling_exp.mean()``; it now keeps attributes + per default. By `Mathias Hauser `_ (:pull:`4592`). +- Added ``freq`` as property to :py:class:`CFTimeIndex` and into the + ``CFTimeIndex.repr``. (:issue:`2416`, :pull:`4597`) + By `Aaron Spring `_. + +Bug fixes +~~~~~~~~~ + +- Fix bug where reference times without padded years (e.g. ``since 1-1-1``) would lose their units when + being passed by ``encode_cf_datetime`` (:issue:`4422`, :pull:`4506`). Such units are ambiguous + about which digit represents the years (is it YMD or DMY?). Now, if such formatting is encountered, + it is assumed that the first digit is the years, they are padded appropriately (to e.g. ``since 0001-1-1``) + and a warning that this assumption is being made is issued. Previously, without ``cftime``, such times + would be silently parsed incorrectly (at least based on the CF conventions) e.g. "since 1-1-1" would + be parsed (via ``pandas`` and ``dateutil``) to ``since 2001-1-1``. + By `Zeb Nicholls `_. +- Fix :py:meth:`DataArray.plot.step`. By `Deepak Cherian `_. +- Fix bug where reading a scalar value from a NetCDF file opened with the ``h5netcdf`` backend would raise a ``ValueError`` when ``decode_cf=True`` (:issue:`4471`, :pull:`4485`). + By `Gerrit Holl `_. 
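As a minimal sketch of the migration implied by the ``weekofyear``/``week`` deprecation noted earlier in these release notes (assuming pandas >= 1.1, as required by the new ``isocalendar`` accessor):

.. code-block:: python

    import pandas as pd
    import xarray as xr

    times = xr.DataArray(
        pd.date_range("2000-01-01", periods=10), dims="time", name="time"
    )

    # deprecated: times.dt.weekofyear / times.dt.week
    # replacement: the ISO calendar week from the isocalendar() Dataset
    iso = times.dt.isocalendar()  # Dataset with "year", "week" and "weekday"
    weeks = iso["week"]

``isocalendar()`` returns a Dataset, so the week number is selected by name rather than read off the accessor directly.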
+- Fix bug where datetime64 times are silently changed to incorrect values if they are outside the valid date range for ns precision when provided in some other units (:issue:`4427`, :pull:`4454`). + By `Andrew Pauling `_ +- Fix silently overwriting the ``engine`` key when passing :py:func:`open_dataset` a file object + to an incompatible netCDF (:issue:`4457`). Now incompatible combinations of files and engines raise + an exception instead. By `Alessandro Amici `_. +- The ``min_count`` argument to :py:meth:`DataArray.sum()` and :py:meth:`DataArray.prod()` + is now ignored when not applicable, i.e. when ``skipna=False`` or when ``skipna=None`` + and the dtype does not have a missing value (:issue:`4352`). + By `Mathias Hauser `_. +- :py:func:`combine_by_coords` now raises an informative error when passing coordinates + with differing calendars (:issue:`4495`). By `Mathias Hauser `_. +- :py:attr:`DataArray.rolling` and :py:attr:`Dataset.rolling` now also keep the attributes and names of of (wrapped) + ``DataArray`` objects, previously only the global attributes were retained (:issue:`4497`, :pull:`4510`). + By `Mathias Hauser `_. +- Improve performance where reading small slices from huge dimensions was slower than necessary (:pull:`4560`). By `Dion Häfner `_. +- Fix bug where ``dask_gufunc_kwargs`` was silently changed in :py:func:`apply_ufunc` (:pull:`4576`). By `Kai Mühlbauer `_. + +Documentation +~~~~~~~~~~~~~ +- document the API not supported with duck arrays (:pull:`4530`). + By `Justus Magin `_. +- Mention the possibility to pass functions to :py:meth:`Dataset.where` or + :py:meth:`DataArray.where` in the parameter documentation (:issue:`4223`, :pull:`4613`). + By `Justus Magin `_. +- Update the docstring of :py:class:`DataArray` and :py:class:`Dataset`. + (:pull:`4532`); + By `Jimmy Westling `_. +- Raise a more informative error when :py:meth:`DataArray.to_dataframe` is + is called on a scalar, (:issue:`4228`); + By `Pieter Gijsbers `_. +- Fix grammar and typos in the :doc:`contributing` guide (:pull:`4545`). + By `Sahid Velji `_. +- Fix grammar and typos in the :doc:`io` guide (:pull:`4553`). + By `Sahid Velji `_. +- Update link to NumPy docstring standard in the :doc:`contributing` guide (:pull:`4558`). + By `Sahid Velji `_. +- Add docstrings to ``isnull`` and ``notnull``, and fix the displayed signature + (:issue:`2760`, :pull:`4618`). + By `Justus Magin `_. + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Optional dependencies can be installed along with xarray by specifying + extras as ``pip install "xarray[extra]"`` where ``extra`` can be one of ``io``, + ``accel``, ``parallel``, ``viz`` and ``complete``. See docs for updated + :ref:`installation instructions `. + (:issue:`2888`, :pull:`4480`). + By `Ashwin Vishnu `_, `Justus Magin + `_ and `Mathias Hauser + `_. +- Removed stray spaces that stem from black removing new lines (:pull:`4504`). + By `Mathias Hauser `_. +- Ensure tests are not skipped in the ``py38-all-but-dask`` test environment + (:issue:`4509`). By `Mathias Hauser `_. +- Ignore select numpy warnings around missing values, where xarray handles + the values appropriately, (:pull:`4536`); + By `Maximilian Roos `_. +- Replace the internal use of ``pd.Index.__or__`` and ``pd.Index.__and__`` with ``pd.Index.union`` + and ``pd.Index.intersection`` as they will stop working as set operations in the future + (:issue:`4565`). By `Mathias Hauser `_. +- Add GitHub action for running nightly tests against upstream dependencies (:pull:`4583`). + By `Anderson Banihirwe `_. 
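The ``pd.Index.__or__``/``pd.Index.__and__`` change mentioned above boils down to the following substitution (a small illustrative sketch in plain pandas, not xarray internals):

.. code-block:: python

    import pandas as pd

    a = pd.Index([1, 2, 3])
    b = pd.Index([2, 3, 4])

    # previously: a | b and a & b, which pandas will stop treating
    # as set operations in a future release
    union = a.union(b)  # Index containing 1, 2, 3, 4
    intersection = a.intersection(b)  # Index containing 2, 3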
+- Ensure all figures are closed properly in plot tests (:pull:`4600`). + By `Yash Saboo `_, `Nirupam K N + `_ and `Mathias Hauser + `_. + +.. _whats-new.0.16.1: + +v0.16.1 (2020-09-20) +--------------------- + +This patch release fixes an incompatibility with a recent pandas change, which +was causing an issue indexing with a ``datetime64``. It also includes +improvements to ``rolling``, ``to_dataframe``, ``cov`` & ``corr`` methods and +bug fixes. Our documentation has a number of improvements, including fixing all +doctests and confirming their accuracy on every commit. + +Many thanks to the 36 contributors who contributed to this release: + +Aaron Spring, Akio Taniguchi, Aleksandar Jelenak, Alexandre Poux, +Caleb, Dan Nowacki, Deepak Cherian, Gerardo Rivera, Jacob Tomlinson, James A. +Bednar, Joe Hamman, Julia Kent, Kai Mühlbauer, Keisuke Fujii, Mathias Hauser, +Maximilian Roos, Nick R. Papior, Pascal Bourgault, Peter Hausamann, Romain +Martinez, Russell Manser, Samnan Rahee, Sander, Spencer Clark, Stephan Hoyer, +Thomas Zilio, Tobias Kölling, Tom Augspurger, alexamici, crusaderky, darikg, +inakleinbottle, jenssss, johnomotani, keewis, and rpgoldman. + +Breaking changes +~~~~~~~~~~~~~~~~ + +- :py:meth:`DataArray.astype` and :py:meth:`Dataset.astype` now preserve attributes. Keep the + old behavior by passing `keep_attrs=False` (:issue:`2049`, :pull:`4314`). + By `Dan Nowacki `_ and `Gabriel Joel Mitchell `_. + +New Features +~~~~~~~~~~~~ + +- :py:meth:`~xarray.DataArray.rolling` and :py:meth:`~xarray.Dataset.rolling` + now accept more than 1 dimension. (:pull:`4219`) + By `Keisuke Fujii `_. +- :py:meth:`~xarray.DataArray.to_dataframe` and :py:meth:`~xarray.Dataset.to_dataframe` + now accept a ``dim_order`` parameter allowing to specify the resulting dataframe's + dimensions order (:issue:`4331`, :pull:`4333`). + By `Thomas Zilio `_. +- Support multiple outputs in :py:func:`xarray.apply_ufunc` when using + ``dask='parallelized'``. (:issue:`1815`, :pull:`4060`). + By `Kai Mühlbauer `_. +- ``min_count`` can be supplied to reductions such as ``.sum`` when specifying + multiple dimension to reduce over; (:pull:`4356`). + By `Maximilian Roos `_. +- :py:func:`xarray.cov` and :py:func:`xarray.corr` now handle missing values; (:pull:`4351`). + By `Maximilian Roos `_. +- Add support for parsing datetime strings formatted following the default + string representation of cftime objects, i.e. YYYY-MM-DD hh:mm:ss, in + partial datetime string indexing, as well as :py:meth:`~xarray.cftime_range` + (:issue:`4337`). By `Spencer Clark `_. +- Build ``CFTimeIndex.__repr__`` explicitly as :py:class:`pandas.Index`. Add ``calendar`` as a new + property for :py:class:`CFTimeIndex` and show ``calendar`` and ``length`` in + ``CFTimeIndex.__repr__`` (:issue:`2416`, :pull:`4092`) + By `Aaron Spring `_. +- Use a wrapped array's ``_repr_inline_`` method to construct the collapsed ``repr`` + of :py:class:`DataArray` and :py:class:`Dataset` objects and + document the new method in :doc:`internals`. (:pull:`4248`). + By `Justus Magin `_. +- Allow per-variable fill values in most functions. (:pull:`4237`). + By `Justus Magin `_. +- Expose ``use_cftime`` option in :py:func:`~xarray.open_zarr` (:issue:`2886`, :pull:`3229`) + By `Samnan Rahee `_ and `Anderson Banihirwe `_. + + +Bug fixes +~~~~~~~~~ + +- Fix indexing with datetime64 scalars with pandas 1.1 (:issue:`4283`). + By `Stephan Hoyer `_ and + `Justus Magin `_. 
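As a brief sketch of the :py:func:`xarray.cov` and :py:func:`xarray.corr` entry above (the random data and dimension names here are made up for illustration):

.. code-block:: python

    import numpy as np
    import xarray as xr

    rng = np.random.default_rng(0)
    a = xr.DataArray(rng.normal(size=(4, 100)), dims=("space", "time"))
    b = a + 0.5 * xr.DataArray(rng.normal(size=(4, 100)), dims=("space", "time"))

    # covariance and Pearson correlation reduced along a named dimension;
    # missing values (NaN) in either input are handled as described above
    cov = xr.cov(a, b, dim="time")
    corr = xr.corr(a, b, dim="time")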
+- Variables which are chunked using dask only along some dimensions can be chunked while storing with zarr along previously + unchunked dimensions (:pull:`4312`) By `Tobias Kölling `_. +- Fixed a bug in backend caused by basic installation of Dask (:issue:`4164`, :pull:`4318`) + `Sam Morley `_. +- Fixed a few bugs with :py:meth:`Dataset.polyfit` when encountering deficient matrix ranks (:issue:`4190`, :pull:`4193`). By `Pascal Bourgault `_. +- Fixed inconsistencies between docstring and functionality for :py:meth:`DataArray.str.get` + and :py:meth:`DataArray.str.wrap` (:issue:`4334`). By `Mathias Hauser `_. +- Fixed overflow issue causing incorrect results in computing means of :py:class:`cftime.datetime` + arrays (:issue:`4341`). By `Spencer Clark `_. +- Fixed :py:meth:`Dataset.coarsen`, :py:meth:`DataArray.coarsen` dropping attributes on original object (:issue:`4120`, :pull:`4360`). By `Julia Kent `_. +- fix the signature of the plot methods. (:pull:`4359`) By `Justus Magin `_. +- Fix :py:func:`xarray.apply_ufunc` with ``vectorize=True`` and ``exclude_dims`` (:issue:`3890`). + By `Mathias Hauser `_. +- Fix `KeyError` when doing linear interpolation to an nd `DataArray` + that contains NaNs (:pull:`4233`). + By `Jens Svensmark `_ +- Fix incorrect legend labels for :py:meth:`Dataset.plot.scatter` (:issue:`4126`). + By `Peter Hausamann `_. +- Fix ``dask.optimize`` on ``DataArray`` producing an invalid Dask task graph (:issue:`3698`) + By `Tom Augspurger `_ +- Fix ``pip install .`` when no ``.git`` directory exists; namely when the xarray source + directory has been rsync'ed by PyCharm Professional for a remote deployment over SSH. + By `Guido Imperiale `_ +- Preserve dimension and coordinate order during :py:func:`xarray.concat` (:issue:`2811`, :issue:`4072`, :pull:`4419`). + By `Kai Mühlbauer `_. +- Avoid relying on :py:class:`set` objects for the ordering of the coordinates (:pull:`4409`) + By `Justus Magin `_. + +Documentation +~~~~~~~~~~~~~ + +- Update the docstring of :py:meth:`DataArray.copy` to remove incorrect mention of 'dataset' (:issue:`3606`) + By `Sander van Rijn `_. +- Removed skipna argument from :py:meth:`DataArray.count`, :py:meth:`DataArray.any`, :py:meth:`DataArray.all`. (:issue:`755`) + By `Sander van Rijn `_ +- Update the contributing guide to use merges instead of rebasing and state + that we squash-merge. (:pull:`4355`). By `Justus Magin `_. +- Make sure the examples from the docstrings actually work (:pull:`4408`). + By `Justus Magin `_. +- Updated Vectorized Indexing to a clearer example. + By `Maximilian Roos `_ + +Internal Changes +~~~~~~~~~~~~~~~~ + +- Fixed all doctests and enabled their running in CI. + By `Justus Magin `_. +- Relaxed the :ref:`mindeps_policy` to support: + + - all versions of setuptools released in the last 42 months (but no older than 38.4) + - all versions of dask and dask.distributed released in the last 12 months (but no + older than 2.9) + - all versions of other packages released in the last 12 months + + All are up from 6 months (:issue:`4295`) + `Guido Imperiale `_. +- Use :py:func:`dask.array.apply_gufunc ` instead of + :py:func:`dask.array.blockwise` in :py:func:`xarray.apply_ufunc` when using + ``dask='parallelized'``. (:pull:`4060`, :pull:`4391`, :pull:`4392`) + By `Kai Mühlbauer `_. +- Align ``mypy`` versions to ``0.782`` across ``requirements`` and + ``.pre-commit-config.yml`` files. 
(:pull:`4390`) + By `Maximilian Roos `_ +- Only load resource files when running inside a Jupyter Notebook + (:issue:`4294`) By `Guido Imperiale `_ +- Silenced most ``numpy`` warnings such as ``Mean of empty slice``. (:pull:`4369`) + By `Maximilian Roos `_ +- Enable type checking for :py:func:`concat` (:issue:`4238`) + By `Mathias Hauser `_. +- Updated plot functions for matplotlib version 3.3 and silenced warnings in the + plot tests (:pull:`4365`). By `Mathias Hauser `_. +- Versions in ``pre-commit.yaml`` are now pinned, to reduce the chances of + conflicting versions. (:pull:`4388`) + By `Maximilian Roos `_ + + + .. _whats-new.0.16.0: -v0.16.0 (unreleased) +v0.16.0 (2020-07-11) --------------------- +This release adds `xarray.cov` & `xarray.corr` for covariance & correlation +respectively; the `idxmax` & `idxmin` methods, the `polyfit` method & +`xarray.polyval` for fitting polynomials, as well as a number of documentation +improvements, other features, and bug fixes. Many thanks to all 44 contributors +who contributed to this release: + +Akio Taniguchi, Andrew Williams, Aurélien Ponte, Benoit Bovy, Dave Cole, David +Brochart, Deepak Cherian, Elliott Sales de Andrade, Etienne Combrisson, Hossein +Madadi, Huite, Joe Hamman, Kai Mühlbauer, Keisuke Fujii, Maik Riechert, Marek +Jacob, Mathias Hauser, Matthieu Ancellin, Maximilian Roos, Noah D Brenowitz, +Oriol Abril, Pascal Bourgault, Phillip Butcher, Prajjwal Nijhara, Ray Bell, Ryan +Abernathey, Ryan May, Spencer Clark, Spencer Hill, Srijan Saurav, Stephan Hoyer, +Taher Chegini, Todd, Tom Nicholas, Yohai Bar Sinai, Yunus Sevinchan, +arabidopsis, aurghs, clausmichele, dmey, johnomotani, keewis, raphael dussin, +risebell + Breaking changes ~~~~~~~~~~~~~~~~ + +- Minimum supported versions for the following packages have changed: ``dask >=2.9``, + ``distributed>=2.9``. + By `Deepak Cherian `_ +- ``groupby`` operations will restore coord dimension order. Pass ``restore_coord_dims=False`` + to revert to previous behavior. +- :meth:`DataArray.transpose` will now transpose coordinates by default. + Pass ``transpose_coords=False`` to revert to previous behaviour. + By `Maximilian Roos `_ - Alternate draw styles for :py:meth:`plot.step` must be passed using the ``drawstyle`` (or ``ds``) keyword argument, instead of the ``linestyle`` (or ``ls``) keyword argument, in line with the `upstream change in Matplotlib `_. (:pull:`3274`) By `Elliott Sales de Andrade `_ +- The old ``auto_combine`` function has now been removed in + favour of the :py:func:`combine_by_coords` and + :py:func:`combine_nested` functions. This also means that + the default behaviour of :py:func:`open_mfdataset` has changed to use + ``combine='by_coords'`` as the default argument value. (:issue:`2616`, :pull:`3926`) + By `Tom Nicholas `_. +- The ``DataArray`` and ``Variable`` HTML reprs now expand the data section by + default (:issue:`4176`) + By `Stephan Hoyer `_. - New deprecations (behavior will be changed in xarray 0.17): - ``dim`` argument to :py:meth:`DataArray.integrate` is being deprecated in @@ -35,25 +443,47 @@ Breaking changes New Features ~~~~~~~~~~~~ -- Added :py:meth:`DataArray.polyfit` and :py:func:`xarray.polyval` for fitting polynomials. (:issue:`3349`) +- :py:meth:`DataArray.argmin` and :py:meth:`DataArray.argmax` now support + sequences of 'dim' arguments, and if a sequence is passed return a dict + (which can be passed to :py:meth:`DataArray.isel` to get the value of the minimum) of + the indices for each dimension of the minimum or maximum of a DataArray. 
+ (:pull:`3936`) + By `John Omotani `_, thanks to `Keisuke Fujii + `_ for work in :pull:`1469`. +- Added :py:func:`xarray.cov` and :py:func:`xarray.corr` (:issue:`3784`, :pull:`3550`, :pull:`4089`). + By `Andrew Williams `_ and `Robin Beer `_. +- Implement :py:meth:`DataArray.idxmax`, :py:meth:`DataArray.idxmin`, + :py:meth:`Dataset.idxmax`, :py:meth:`Dataset.idxmin`. (:issue:`60`, :pull:`3871`) + By `Todd Jennings `_ +- Added :py:meth:`DataArray.polyfit` and :py:func:`xarray.polyval` for fitting + polynomials. (:issue:`3349`, :pull:`3733`, :pull:`4099`) + By `Pascal Bourgault `_. +- Added :py:meth:`xarray.infer_freq` for extending frequency inferring to CFTime indexes and data (:pull:`4033`). By `Pascal Bourgault `_. +- ``chunks='auto'`` is now supported in the ``chunks`` argument of + :py:meth:`Dataset.chunk`. (:issue:`4055`) + By `Andrew Williams `_ - Control over attributes of result in :py:func:`merge`, :py:func:`concat`, :py:func:`combine_by_coords` and :py:func:`combine_nested` using combine_attrs keyword argument. (:issue:`3865`, :pull:`3877`) By `John Omotani `_ -- 'missing_dims' argument to :py:meth:`Dataset.isel`, - `:py:meth:`DataArray.isel` and :py:meth:`Variable.isel` to allow replacing +- `missing_dims` argument to :py:meth:`Dataset.isel`, + :py:meth:`DataArray.isel` and :py:meth:`Variable.isel` to allow replacing the exception when a dimension passed to ``isel`` is not present with a warning, or just ignore the dimension. (:issue:`3866`, :pull:`3923`) By `John Omotani `_ -- Limited the length of array items with long string reprs to a - reasonable width (:pull:`3900`) - By `Maximilian Roos `_ -- Implement :py:meth:`DataArray.idxmax`, :py:meth:`DataArray.idxmin`, - :py:meth:`Dataset.idxmax`, :py:meth:`Dataset.idxmin`. (:issue:`60`, :pull:`3871`) - By `Todd Jennings `_ +- Support dask handling for :py:meth:`DataArray.idxmax`, :py:meth:`DataArray.idxmin`, + :py:meth:`Dataset.idxmax`, :py:meth:`Dataset.idxmin`. (:pull:`3922`, :pull:`4135`) + By `Kai Mühlbauer `_ and `Pascal Bourgault `_. +- More support for unit aware arrays with pint (:pull:`3643`, :pull:`3975`, :pull:`4163`) + By `Justus Magin `_. +- Support overriding existing variables in ``to_zarr()`` with ``mode='a'`` even + without ``append_dim``, as long as dimension sizes do not change. + By `Stephan Hoyer `_. - Allow plotting of boolean arrays. (:pull:`3766`) By `Marek Jacob `_ +- Enable using MultiIndex levels as coordinates in 1D and 2D plots (:issue:`3927`). + By `Mathias Hauser `_. - A ``days_in_month`` accessor for :py:class:`xarray.CFTimeIndex`, analogous to the ``days_in_month`` accessor for a :py:class:`pandas.DatetimeIndex`, which returns the days in the month each datetime in the index. Now days in month @@ -61,16 +491,63 @@ New Features the :py:class:`~core.accessor_dt.DatetimeAccessor` (:pull:`3935`). This feature requires cftime version 1.1.0 or greater. By `Spencer Clark `_. +- For the netCDF3 backend, added dtype coercions for unsigned integer types. + (:issue:`4014`, :pull:`4018`) + By `Yunus Sevinchan `_ +- :py:meth:`map_blocks` now accepts a ``template`` kwarg. This allows use cases + where the result of a computation could not be inferred automatically. + By `Deepak Cherian `_ +- :py:meth:`map_blocks` can now handle dask-backed xarray objects in ``args``. 
(:pull:`3818`) + By `Deepak Cherian `_ +- Add keyword ``decode_timedelta`` to :py:func:`xarray.open_dataset`, + (:py:func:`xarray.open_dataarray`, + :py:func:`xarray.decode_cf`) that allows one to disable/enable the decoding of timedeltas + independently of time decoding (:issue:`1621`) + By `Aureliana Barghini `_ + +Enhancements +~~~~~~~~~~~~ +- Performance improvement of :py:meth:`DataArray.interp` and :py:func:`Dataset.interp`. + We perform independent interpolation sequentially rather than interpolating in + one large multidimensional space. (:issue:`2223`) + By `Keisuke Fujii `_. +- :py:meth:`DataArray.interp` now supports interpolations over chunked dimensions (:pull:`4155`). By `Alexandre Poux `_. +- Major performance improvement for :py:meth:`Dataset.from_dataframe` when the + dataframe has a MultiIndex (:pull:`4184`). + By `Stephan Hoyer `_. + - :py:meth:`DataArray.reset_index` and :py:meth:`Dataset.reset_index` now keep + coordinate attributes (:pull:`4103`). By `Oriol Abril `_. +- Axes kwargs such as ``facecolor`` can now be passed to :py:meth:`DataArray.plot` in ``subplot_kws``. + This works for both single axes plots and FacetGrid plots. + By `Raphael Dussin `_. +- Array items with long string reprs are now limited to a + reasonable width (:pull:`3900`) + By `Maximilian Roos `_ +- Large arrays whose numpy reprs would have greater than 40 lines are now + limited to a reasonable length. + (:pull:`3905`) + By `Maximilian Roos `_ Bug fixes ~~~~~~~~~ -- Fix wrong order in converting a ``pd.Series`` with a MultiIndex to ``DataArray``. (:issue:`3951`) +- Fix errors combining attrs in :py:func:`open_mfdataset` (:issue:`4009`, :pull:`4173`) + By `John Omotani `_ +- If groupby receives a ``DataArray`` with name=None, assign a default name (:issue:`158`) + By `Phil Butcher `_. +- Support dark mode in VS code (:issue:`4024`) By `Keisuke Fujii `_. +- Fix bug when converting multiindexed Pandas objects to sparse xarray objects. (:issue:`4019`) + By `Deepak Cherian `_. +- ``ValueError`` is raised when ``fill_value`` is not a scalar in :py:meth:`full_like`. (:issue:`3977`) + By `Huite Bootsma `_. +- Fix wrong order in converting a ``pd.Series`` with a MultiIndex to ``DataArray``. + (:issue:`3951`, :issue:`4186`) + By `Keisuke Fujii `_ and `Stephan Hoyer `_. - Fix renaming of coords when one or more stacked coords is not in sorted order during stack+groupby+apply operations. (:issue:`3287`, :pull:`3906`) By `Spencer Hill `_ - Fix a regression where deleting a coordinate from a copied :py:class:`DataArray` - can affect the original :py:class:`Dataarray`. (:issue:`3899`, :pull:`3871`) + can affect the original :py:class:`DataArray`. (:issue:`3899`, :pull:`3871`) By `Todd Jennings `_ - Fix :py:class:`~xarray.plot.FacetGrid` plots with a single contour. (:issue:`3569`, :pull:`3915`). By `Deepak Cherian `_ @@ -78,15 +555,30 @@ Bug fixes By `Deepak Cherian `_ - Fix :py:class:`~xarray.plot.FacetGrid` when ``vmin == vmax``. (:issue:`3734`) By `Deepak Cherian `_ +- Fix plotting when ``levels`` is a scalar and ``norm`` is provided. (:issue:`3735`) + By `Deepak Cherian `_ - Fix bug where plotting line plots with 2D coordinates depended on dimension order. (:issue:`3933`) By `Tom Nicholas `_. - Fix ``RasterioDeprecationWarning`` when using a ``vrt`` in ``open_rasterio``. (:issue:`3964`) By `Taher Chegini `_. +- Fix ``AttributeError`` on displaying a :py:class:`Variable` + in a notebook context. (:issue:`3972`, :pull:`3973`) + By `Ian Castleden `_. 
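To make the ``decode_timedelta`` keyword described above concrete, a minimal usage sketch (the file name is a placeholder for any netCDF file containing a variable with timedelta-like units such as ``"hours"``):

.. code-block:: python

    import xarray as xr

    # default behaviour: variables with timedelta-like units are decoded
    # to np.timedelta64 values together with the datetime decoding
    ds_decoded = xr.open_dataset("example.nc")

    # opt out of timedelta decoding while keeping datetime decoding
    ds_raw = xr.open_dataset("example.nc", decode_timedelta=False)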
- Fix bug causing :py:meth:`DataArray.interpolate_na` to always drop attributes, and added `keep_attrs` argument. (:issue:`3968`) By `Tom Nicholas `_. - +- Fix bug in time parsing failing to fall back to cftime. This was causing time + variables with a time unit of `'msecs'` to fail to parse. (:pull:`3998`) + By `Ryan May `_. +- Fix weighted mean when passing boolean weights (:issue:`4074`). + By `Mathias Hauser `_. +- Fix html repr in untrusted notebooks: fallback to plain text repr. (:pull:`4053`) + By `Benoit Bovy `_. +- Fix :py:meth:`DataArray.to_unstacked_dataset` for single-dimension variables. (:issue:`4049`) + By `Deepak Cherian `_ +- Fix :py:func:`open_rasterio` for ``WarpedVRT`` with specified ``src_crs``. (:pull:`4104`) + By `Dave Cole `_. Documentation ~~~~~~~~~~~~~ @@ -108,18 +600,33 @@ Documentation of ``kwargs`` in :py:meth:`Dataset.interp` and :py:meth:`DataArray.interp` for 1-d and n-d interpolation (:pull:`3956`). By `Matthias Riße `_. +- Apply ``black`` to all the code in the documentation (:pull:`4012`) + By `Justus Magin `_. +- Narrative documentation now describes :py:meth:`map_blocks`: :ref:`dask.automatic-parallelization`. + By `Deepak Cherian `_. +- Document ``.plot``, ``.dt``, ``.str`` accessors the way they are called. (:issue:`3625`, :pull:`3988`) + By `Justus Magin `_. +- Add documentation for the parameters and return values of :py:meth:`DataArray.sel`. + By `Justus Magin `_. Internal Changes ~~~~~~~~~~~~~~~~ +- Raise more informative error messages for chunk size conflicts when writing to zarr files. + By `Deepak Cherian `_. - Run the ``isort`` pre-commit hook only on python source files and update the ``flake8`` version. (:issue:`3750`, :pull:`3711`) By `Justus Magin `_. +- Add `blackdoc `_ to the list of + checkers for development. (:pull:`4177`) + By `Justus Magin `_. - Add a CI job that runs the tests with every optional dependency except ``dask``. (:issue:`3794`, :pull:`3919`) By `Justus Magin `_. - Use ``async`` / ``await`` for the asynchronous distributed tests. (:issue:`3987`, :pull:`3989`) By `Justus Magin `_. +- Various internal code clean-ups (:pull:`4026`, :pull:`4038`). + By `Prajjwal Nijhara `_. .. _whats-new.0.15.1: @@ -144,7 +651,7 @@ New Features - Weighted array reductions are now supported via the new :py:meth:`DataArray.weighted` and :py:meth:`Dataset.weighted` methods. See :ref:`comput.weighted`. (:issue:`422`, :pull:`2922`). - By `Mathias Hauser `_ + By `Mathias Hauser `_. - The new jupyter notebook repr (``Dataset._repr_html_`` and ``DataArray._repr_html_``) (introduced in 0.14.1) is now on by default. To disable, use ``xarray.set_options(display_style="text")``. @@ -178,6 +685,8 @@ New Features :py:meth:`core.groupby.DatasetGroupBy.quantile`, :py:meth:`core.groupby.DataArrayGroupBy.quantile` (:issue:`3843`, :pull:`3844`) By `Aaron Spring `_. +- Add a diff summary for `testing.assert_allclose`. (:issue:`3617`, :pull:`3847`) + By `Justus Magin `_. Bug fixes ~~~~~~~~~ @@ -206,13 +715,13 @@ Bug fixes - xarray now respects the over, under and bad colors if set on a provided colormap. (:issue:`3590`, :pull:`3601`) By `johnomotani `_. -- :py:func:`coarsen` now respects ``xr.set_options(keep_attrs=True)`` +- ``coarsen`` and ``rolling`` now respect ``xr.set_options(keep_attrs=True)`` to preserve attributes. :py:meth:`Dataset.coarsen` accepts a keyword argument ``keep_attrs`` to change this setting. (:issue:`3376`, :pull:`3801`) By `Andrew Thomas `_. - Delete associated indexes when deleting coordinate variables. (:issue:`3746`). 
By `Deepak Cherian `_. -- Fix :py:meth:`xarray.core.dataset.Dataset.to_zarr` when using `append_dim` and `group` +- Fix :py:meth:`Dataset.to_zarr` when using ``append_dim`` and ``group`` simultaneously. (:issue:`3170`). By `Matthias Meyer `_. - Fix html repr on :py:class:`Dataset` with non-string keys (:pull:`3807`). By `Maximilian Roos `_. @@ -250,7 +759,7 @@ Internal Changes By `Maximilian Roos `_ - Remove xfails for scipy 1.0.1 for tests that append to netCDF files (:pull:`3805`). By `Mathias Hauser `_. -- Remove conversion to :py:class:`pandas.Panel`, given its removal in pandas +- Remove conversion to ``pandas.Panel``, given its removal in pandas in favor of xarray's objects. By `Maximilian Roos `_ @@ -942,7 +1451,7 @@ New functions/methods ``combine_by_coords`` to combine datasets along multiple dimensions, by specifying the argument ``combine='nested'`` or ``combine='by_coords'``. - The older function :py:func:`~xarray.auto_combine` has been deprecated, + The older function ``auto_combine`` has been deprecated, because its functionality has been subsumed by the new functions. To avoid FutureWarnings switch to using ``combine_nested`` or ``combine_by_coords``, (or set the ``combine`` argument in @@ -1962,8 +2471,8 @@ Enhancements .. ipython:: python - ds = xr.Dataset({'a': 1}) - np.sin(ds) + ds = xr.Dataset({"a": 1}) + np.sin(ds) This obliviates the need for the ``xarray.ufuncs`` module, which will be deprecated in the future when xarray drops support for older versions of @@ -2054,8 +2563,8 @@ Enhancements .. ipython:: python - da = xr.DataArray(np.array([True, False, np.nan], dtype=object), dims='x') - da.sum() + da = xr.DataArray(np.array([True, False, np.nan], dtype=object), dims="x") + da.sum() (:issue:`1866`) By `Keisuke Fujii `_. @@ -2201,7 +2710,7 @@ Breaking changes - A new resampling interface to match pandas' groupby-like API was added to :py:meth:`Dataset.resample` and :py:meth:`DataArray.resample` (:issue:`1272`). :ref:`Timeseries resampling ` is - fully supported for data with arbitrary dimensions as is both downsampling + fully supported for data with arbitrary dimensions as is both downsampling and upsampling (including linear, quadratic, cubic, and spline interpolation). Old syntax: @@ -2209,7 +2718,7 @@ Breaking changes .. ipython:: :verbatim: - In [1]: ds.resample('24H', dim='time', how='max') + In [1]: ds.resample("24H", dim="time", how="max") Out[1]: [...] @@ -2219,7 +2728,7 @@ Breaking changes .. ipython:: :verbatim: - In [1]: ds.resample(time='24H').max() + In [1]: ds.resample(time="24H").max() Out[1]: [...] @@ -2289,9 +2798,9 @@ Enhancements In [1]: import xarray as xr - In [2]: arr = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=('x', 'y')) + In [2]: arr = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=("x", "y")) - In [3]: xr.where(arr % 2, 'even', 'odd') + In [3]: xr.where(arr % 2, "even", "odd") Out[3]: array([['even', 'odd', 'even'], @@ -2812,7 +3321,7 @@ Breaking changes .. ipython:: :verbatim: - In [1]: xr.Dataset({'foo': (('x', 'y'), [[1, 2]])}) + In [1]: xr.Dataset({"foo": (("x", "y"), [[1, 2]])}) Out[1]: Dimensions: (x: 1, y: 2) @@ -3269,10 +3778,10 @@ Enhancements .. 
ipython:: :verbatim: - In [1]: import xarray as xr; import numpy as np + In [1]: import xarray as xr + ...: import numpy as np - In [2]: arr = xr.DataArray(np.arange(0, 7.5, 0.5).reshape(3, 5), - dims=('x', 'y')) + In [2]: arr = xr.DataArray(np.arange(0, 7.5, 0.5).reshape(3, 5), dims=("x", "y")) In [3]: arr Out[3]: @@ -3332,7 +3841,7 @@ Bug fixes - Restore checks for shape consistency between data and coordinates in the DataArray constructor (:issue:`758`). - Single dimension variables no longer transpose as part of a broader - ``.transpose``. This behavior was causing ``pandas.PeriodIndex`` dimensions + ``.transpose``. This behavior was causing ``pandas.PeriodIndex`` dimensions to lose their type (:issue:`749`) - :py:class:`~xarray.Dataset` labels remain as their native type on ``.to_dataset``. Previously they were coerced to strings (:issue:`745`) @@ -3411,7 +3920,7 @@ Breaking changes .. ipython:: :verbatim: - In [2]: xray.DataArray([4, 5, 6], dims='x', name='x') + In [2]: xray.DataArray([4, 5, 6], dims="x", name="x") Out[2]: array([4, 5, 6]) @@ -3423,7 +3932,7 @@ Breaking changes .. ipython:: :verbatim: - In [2]: xray.DataArray([4, 5, 6], dims='x', name='x') + In [2]: xray.DataArray([4, 5, 6], dims="x", name="x") Out[2]: array([4, 5, 6]) @@ -3446,13 +3955,11 @@ Enhancements .. ipython:: :verbatim: - In [7]: df = pd.DataFrame({'foo': range(3), - ...: 'x': ['a', 'b', 'b'], - ...: 'y': [0, 0, 1]}) + In [7]: df = pd.DataFrame({"foo": range(3), "x": ["a", "b", "b"], "y": [0, 0, 1]}) - In [8]: s = df.set_index(['x', 'y'])['foo'] + In [8]: s = df.set_index(["x", "y"])["foo"] - In [12]: arr = xray.DataArray(s, dims='z') + In [12]: arr = xray.DataArray(s, dims="z") In [13]: arr Out[13]: @@ -3461,13 +3968,13 @@ Enhancements Coordinates: * z (z) object ('a', 0) ('b', 0) ('b', 1) - In [19]: arr.indexes['z'] + In [19]: arr.indexes["z"] Out[19]: MultiIndex(levels=[[u'a', u'b'], [0, 1]], labels=[[0, 1, 1], [0, 0, 1]], names=[u'x', u'y']) - In [14]: arr.unstack('z') + In [14]: arr.unstack("z") Out[14]: array([[ 0., nan], @@ -3476,7 +3983,7 @@ Enhancements * x (x) object 'a' 'b' * y (y) int64 0 1 - In [26]: arr.unstack('z').stack(z=('x', 'y')) + In [26]: arr.unstack("z").stack(z=("x", "y")) Out[26]: array([ 0., nan, 1., 2.]) @@ -3504,9 +4011,9 @@ Enhancements for shifting/rotating datasets or arrays along a dimension: .. ipython:: python - :okwarning: + :okwarning: - array = xray.DataArray([5, 6, 7, 8], dims='x') + array = xray.DataArray([5, 6, 7, 8], dims="x") array.shift(x=2) array.roll(x=2) @@ -3521,8 +4028,8 @@ Enhancements .. ipython:: python - a = xray.DataArray([1, 2, 3], dims='x') - b = xray.DataArray([5, 6], dims='y') + a = xray.DataArray([1, 2, 3], dims="x") + b = xray.DataArray([5, 6], dims="y") a b a2, b2 = xray.broadcast(a, b) @@ -3592,9 +4099,9 @@ Enhancements .. ipython:: :verbatim: - In [5]: array = xray.DataArray([1, 2, 3], dims='x') + In [5]: array = xray.DataArray([1, 2, 3], dims="x") - In [6]: array.reindex(x=[0.9, 1.5], method='nearest', tolerance=0.2) + In [6]: array.reindex(x=[0.9, 1.5], method="nearest", tolerance=0.2) Out[6]: array([ 2., nan]) @@ -3674,10 +4181,11 @@ Enhancements .. 
ipython:: :verbatim: - In [1]: da = xray.DataArray(np.arange(56).reshape((7, 8)), - ...: coords={'x': list('abcdefg'), - ...: 'y': 10 * np.arange(8)}, - ...: dims=['x', 'y']) + In [1]: da = xray.DataArray( + ...: np.arange(56).reshape((7, 8)), + ...: coords={"x": list("abcdefg"), "y": 10 * np.arange(8)}, + ...: dims=["x", "y"], + ...: ) In [2]: da Out[2]: @@ -3694,7 +4202,7 @@ Enhancements * x (x) |S1 'a' 'b' 'c' 'd' 'e' 'f' 'g' # we can index by position along each dimension - In [3]: da.isel_points(x=[0, 1, 6], y=[0, 1, 0], dim='points') + In [3]: da.isel_points(x=[0, 1, 6], y=[0, 1, 0], dim="points") Out[3]: array([ 0, 9, 48]) @@ -3704,7 +4212,7 @@ Enhancements * points (points) int64 0 1 2 # or equivalently by label - In [9]: da.sel_points(x=['a', 'b', 'g'], y=[0, 10, 0], dim='points') + In [9]: da.sel_points(x=["a", "b", "g"], y=[0, 10, 0], dim="points") Out[9]: array([ 0, 9, 48]) @@ -3718,11 +4226,11 @@ Enhancements .. ipython:: python - ds = xray.Dataset(coords={'x': range(100), 'y': range(100)}) - ds['distance'] = np.sqrt(ds.x ** 2 + ds.y ** 2) + ds = xray.Dataset(coords={"x": range(100), "y": range(100)}) + ds["distance"] = np.sqrt(ds.x ** 2 + ds.y ** 2) - @savefig where_example.png width=4in height=4in - ds.distance.where(ds.distance < 100).plot() + @savefig where_example.png width=4in height=4in + ds.distance.where(ds.distance < 100).plot() - Added new methods ``xray.DataArray.diff`` and ``xray.Dataset.diff`` for finite difference calculations along a given axis. @@ -3732,9 +4240,9 @@ Enhancements .. ipython:: python - da = xray.DataArray(np.random.random_sample(size=(5, 4))) - da.where(da < 0.5) - da.where(da < 0.5).to_masked_array(copy=True) + da = xray.DataArray(np.random.random_sample(size=(5, 4))) + da.where(da < 0.5) + da.where(da < 0.5).to_masked_array(copy=True) - Added new flag "drop_variables" to ``xray.open_dataset`` for excluding variables from being parsed. This may be useful to drop @@ -3792,9 +4300,9 @@ Enhancements .. ipython:: :verbatim: - In [1]: years, datasets = zip(*ds.groupby('time.year')) + In [1]: years, datasets = zip(*ds.groupby("time.year")) - In [2]: paths = ['%s.nc' % y for y in years] + In [2]: paths = ["%s.nc" % y for y in years] In [3]: xray.save_mfdataset(datasets, paths) @@ -3867,9 +4375,9 @@ Backwards incompatible changes .. ipython:: :verbatim: - In [1]: ds = xray.Dataset({'x': 0}) + In [1]: ds = xray.Dataset({"x": 0}) - In [2]: xray.concat([ds, ds], dim='y') + In [2]: xray.concat([ds, ds], dim="y") Out[2]: Dimensions: () @@ -3881,13 +4389,13 @@ Backwards incompatible changes Now, the default always concatenates data variables: .. ipython:: python - :suppress: + :suppress: - ds = xray.Dataset({'x': 0}) + ds = xray.Dataset({"x": 0}) .. ipython:: python - xray.concat([ds, ds], dim='y') + xray.concat([ds, ds], dim="y") To obtain the old behavior, supply the argument ``concat_over=[]``. @@ -3900,17 +4408,20 @@ Enhancements .. ipython:: python - ds = xray.Dataset({'a': 1, 'b': ('x', [1, 2, 3])}, - coords={'c': 42}, attrs={'Conventions': 'None'}) + ds = xray.Dataset( + {"a": 1, "b": ("x", [1, 2, 3])}, + coords={"c": 42}, + attrs={"Conventions": "None"}, + ) ds.to_array() - ds.to_array().to_dataset(dim='variable') + ds.to_array().to_dataset(dim="variable") - New ``xray.Dataset.fillna`` method to fill missing values, modeled off the pandas method of the same name: .. 
ipython:: python - array = xray.DataArray([np.nan, 1, np.nan, 3], dims='x') + array = xray.DataArray([np.nan, 1, np.nan, 3], dims="x") array.fillna(0) ``fillna`` works on both ``Dataset`` and ``DataArray`` objects, and uses @@ -3923,9 +4434,9 @@ Enhancements .. ipython:: python - ds = xray.Dataset({'y': ('x', [1, 2, 3])}) - ds.assign(z = lambda ds: ds.y ** 2) - ds.assign_coords(z = ('x', ['a', 'b', 'c'])) + ds = xray.Dataset({"y": ("x", [1, 2, 3])}) + ds.assign(z=lambda ds: ds.y ** 2) + ds.assign_coords(z=("x", ["a", "b", "c"])) These methods return a new Dataset (or DataArray) with updated data or coordinate variables. @@ -3938,7 +4449,7 @@ Enhancements .. ipython:: :verbatim: - In [12]: ds.sel(x=1.1, method='nearest') + In [12]: ds.sel(x=1.1, method="nearest") Out[12]: Dimensions: () @@ -3947,7 +4458,7 @@ Enhancements Data variables: y int64 2 - In [13]: ds.sel(x=[1.1, 2.1], method='pad') + In [13]: ds.sel(x=[1.1, 2.1], method="pad") Out[13]: Dimensions: (x: 2) @@ -3973,7 +4484,7 @@ Enhancements .. ipython:: python - ds = xray.Dataset({'x': np.arange(1000)}) + ds = xray.Dataset({"x": np.arange(1000)}) with xray.set_options(display_width=40): print(ds) @@ -4011,42 +4522,42 @@ Enhancements need to supply the time dimension explicitly: .. ipython:: python - :verbatim: + :verbatim: - time = pd.date_range('2000-01-01', freq='6H', periods=10) - array = xray.DataArray(np.arange(10), [('time', time)]) - array.resample('1D', dim='time') + time = pd.date_range("2000-01-01", freq="6H", periods=10) + array = xray.DataArray(np.arange(10), [("time", time)]) + array.resample("1D", dim="time") You can specify how to do the resampling with the ``how`` argument and other options such as ``closed`` and ``label`` let you control labeling: .. ipython:: python - :verbatim: + :verbatim: - array.resample('1D', dim='time', how='sum', label='right') + array.resample("1D", dim="time", how="sum", label="right") If the desired temporal resolution is higher than the original data (upsampling), xray will insert missing values: .. ipython:: python - :verbatim: + :verbatim: - array.resample('3H', 'time') + array.resample("3H", "time") - ``first`` and ``last`` methods on groupby objects let you take the first or last examples from each group along the grouped axis: .. ipython:: python - :verbatim: + :verbatim: - array.groupby('time.day').first() + array.groupby("time.day").first() These methods combine well with ``resample``: .. ipython:: python - :verbatim: + :verbatim: - array.resample('1D', dim='time', how='first') + array.resample("1D", dim="time", how="first") - ``xray.Dataset.swap_dims`` allows for easily swapping one dimension @@ -4054,9 +4565,9 @@ Enhancements .. ipython:: python - ds = xray.Dataset({'x': range(3), 'y': ('x', list('abc'))}) - ds - ds.swap_dims({'x': 'y'}) + ds = xray.Dataset({"x": range(3), "y": ("x", list("abc"))}) + ds + ds.swap_dims({"x": "y"}) This was possible in earlier versions of xray, but required some contortions. - ``xray.open_dataset`` and ``xray.Dataset.to_netcdf`` now @@ -4102,8 +4613,8 @@ Breaking changes .. ipython:: python - lhs = xray.DataArray([1, 2, 3], [('x', [0, 1, 2])]) - rhs = xray.DataArray([2, 3, 4], [('x', [1, 2, 3])]) + lhs = xray.DataArray([1, 2, 3], [("x", [0, 1, 2])]) + rhs = xray.DataArray([2, 3, 4], [("x", [1, 2, 3])]) lhs + rhs :ref:`For dataset construction and merging`, we align based on the @@ -4111,14 +4622,14 @@ Breaking changes .. 
ipython:: python - xray.Dataset({'foo': lhs, 'bar': rhs}) + xray.Dataset({"foo": lhs, "bar": rhs}) :ref:`For update and __setitem__`, we align based on the **original** object: .. ipython:: python - lhs.coords['rhs'] = rhs + lhs.coords["rhs"] = rhs lhs - Aggregations like ``mean`` or ``median`` now skip missing values by default: @@ -4141,8 +4652,8 @@ Breaking changes .. ipython:: python - a = xray.DataArray([1, 2], coords={'c': 0}, dims='x') - b = xray.DataArray([1, 2], coords={'c': ('x', [0, 0])}, dims='x') + a = xray.DataArray([1, 2], coords={"c": 0}, dims="x") + b = xray.DataArray([1, 2], coords={"c": ("x", [0, 0])}, dims="x") (a + b).coords This functionality can be controlled through the ``compat`` option, which @@ -4153,9 +4664,10 @@ Breaking changes .. ipython:: python - time = xray.DataArray(pd.date_range('2000-01-01', periods=365), - dims='time', name='time') - counts = time.groupby('time.month').count() + time = xray.DataArray( + pd.date_range("2000-01-01", periods=365), dims="time", name="time" + ) + counts = time.groupby("time.month").count() counts.sel(month=2) Previously, you would need to use something like @@ -4165,8 +4677,8 @@ Breaking changes .. ipython:: python - ds = xray.Dataset({'t': pd.date_range('2000-01-01', periods=12, freq='M')}) - ds['t.season'] + ds = xray.Dataset({"t": pd.date_range("2000-01-01", periods=12, freq="M")}) + ds["t.season"] Previously, it returned numbered seasons 1 through 4. - We have updated our use of the terms of "coordinates" and "variables". What @@ -4189,8 +4701,8 @@ Enhancements .. ipython:: python - data = xray.DataArray([1, 2, 3], [('x', range(3))]) - data.reindex(x=[0.5, 1, 1.5, 2, 2.5], method='pad') + data = xray.DataArray([1, 2, 3], [("x", range(3))]) + data.reindex(x=[0.5, 1, 1.5, 2, 2.5], method="pad") This will be especially useful once pandas 0.16 is released, at which point xray will immediately support reindexing with @@ -4209,15 +4721,15 @@ Enhancements makes it easy to drop explicitly listed variables or index labels: .. ipython:: python - :okwarning: + :okwarning: # drop variables - ds = xray.Dataset({'x': 0, 'y': 1}) - ds.drop('x') + ds = xray.Dataset({"x": 0, "y": 1}) + ds.drop("x") # drop index labels - arr = xray.DataArray([1, 2, 3], coords=[('x', list('abc'))]) - arr.drop(['a', 'c'], dim='x') + arr = xray.DataArray([1, 2, 3], coords=[("x", list("abc"))]) + arr.drop(["a", "c"], dim="x") - ``xray.Dataset.broadcast_equals`` has been added to correspond to the new ``compat`` option. @@ -4285,7 +4797,8 @@ Backwards incompatible changes .. ipython:: python from datetime import datetime - xray.Dataset({'t': [datetime(2000, 1, 1)]}) + + xray.Dataset({"t": [datetime(2000, 1, 1)]}) - xray now has support (including serialization to netCDF) for :py:class:`~pandas.TimedeltaIndex`. :py:class:`datetime.timedelta` objects @@ -4301,8 +4814,8 @@ Enhancements .. ipython:: python - ds = xray.Dataset({'tmin': ([], 25, {'units': 'celsius'})}) - ds.tmin.units + ds = xray.Dataset({"tmin": ([], 25, {"units": "celsius"})}) + ds.tmin.units Tab-completion for these variables should work in editors such as IPython. However, setting variables or attributes in this fashion is not yet @@ -4312,7 +4825,7 @@ Enhancements .. 
ipython:: python - array = xray.DataArray(np.zeros(5), dims=['x']) + array = xray.DataArray(np.zeros(5), dims=["x"]) array[dict(x=slice(3))] = 1 array diff --git a/doc/why-xarray.rst b/doc/why-xarray.rst index 7d14a6c9f9e..a5093a1ff2a 100644 --- a/doc/why-xarray.rst +++ b/doc/why-xarray.rst @@ -49,7 +49,7 @@ Core data structures -------------------- xarray has two core data structures, which build upon and extend the core -strengths of NumPy_ and pandas_. Both data structures are fundamentally N-dimensional: +strengths of NumPy_ and pandas_. Both data structures are fundamentally N-dimensional: - :py:class:`~xarray.DataArray` is our implementation of a labeled, N-dimensional array. It is an N-D generalization of a :py:class:`pandas.Series`. The name diff --git a/licenses/PYTHON_LICENSE b/licenses/PYTHON_LICENSE index 43829c533b9..88251f5b6e8 100644 --- a/licenses/PYTHON_LICENSE +++ b/licenses/PYTHON_LICENSE @@ -251,4 +251,4 @@ FITNESS, IN NO EVENT SHALL STICHTING MATHEMATISCH CENTRUM BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT -OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. \ No newline at end of file +OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. diff --git a/readthedocs.yml b/readthedocs.yml index 173d61ec6f3..072a4b5110c 100644 --- a/readthedocs.yml +++ b/readthedocs.yml @@ -1,9 +1,12 @@ version: 2 build: - image: stable + image: latest conda: environment: ci/requirements/doc.yml +sphinx: + fail_on_warning: true + formats: [] diff --git a/requirements.txt b/requirements.txt index f73887ff5cc..23eff8f07cb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,4 @@ numpy >= 1.15 pandas >= 0.25 -setuptools >= 41.2 +setuptools >= 40.4 diff --git a/setup.cfg b/setup.cfg index 42dc53bb882..a695191bf02 100644 --- a/setup.cfg +++ b/setup.cfg @@ -64,23 +64,70 @@ classifiers = Intended Audience :: Science/Research Programming Language :: Python Programming Language :: Python :: 3 - Programming Language :: Python :: 3.6 Programming Language :: Python :: 3.7 + Programming Language :: Python :: 3.8 Topic :: Scientific/Engineering [options] -packages = xarray +packages = find: zip_safe = False # https://mypy.readthedocs.io/en/latest/installed_packages.html include_package_data = True -python_requires = >=3.6 +python_requires = >=3.7 install_requires = numpy >= 1.15 pandas >= 0.25 - setuptools >= 41.2 # For pkg_resources + setuptools >= 40.4 # For pkg_resources setup_requires = - setuptools >= 41.2 + setuptools >= 40.4 setuptools_scm + +[options.extras_require] +io = + netCDF4 + h5netcdf + scipy + pydap + zarr + fsspec + cftime + rasterio + cfgrib + ## Scitools packages & dependencies (e.g: cartopy, cf-units) can be hard to install + # scitools-iris + +accel = + scipy + bottleneck + numbagg + +parallel = + dask[complete] + +viz = + matplotlib + seaborn + nc-time-axis + ## Cartopy requires 3rd party libraries and only provides source distributions + ## See: https://github.com/SciTools/cartopy/issues/805 + # cartopy + +complete = + %(io)s + %(accel)s + %(parallel)s + %(viz)s + +docs = + %(complete)s + sphinx-autosummary-accessors + sphinx_rtd_theme + ipython + ipykernel + jupyter-client + nbsphinx + scanpydoc + [options.package_data] xarray = py.typed @@ -94,8 +141,6 @@ testpaths = xarray/tests properties # Fixed upstream in https://github.com/pydata/bottleneck/pull/199 filterwarnings = 
ignore:Using a non-tuple sequence for multidimensional indexing is deprecated:FutureWarning -env = - UVCDAT_ANONYMOUS_LOG=no markers = flaky: flaky tests network: tests requiring a network connection @@ -103,27 +148,21 @@ markers = [flake8] ignore = - # whitespace before ':' - doesn't work well with black - E203 - E402 - # line too long - let black worry about that - E501 - # do not assign a lambda expression, use a def - E731 - # line break before binary operator - W503 + E203 # whitespace before ':' - doesn't work well with black + E402 # module level import not at top of file + E501 # line too long - let black worry about that + E731 # do not assign a lambda expression, use a def + W503 # line break before binary operator exclude= .eggs doc [isort] +profile = black +skip_gitignore = true +force_to_top = true default_section = THIRDPARTY known_first_party = xarray -multi_line_output = 3 -include_trailing_comma = True -force_grid_wrap = 0 -use_parentheses = True -line_length = 88 # Most of the numerical computing stack doesn't have type annotations yet. [mypy-affine.*] @@ -138,6 +177,8 @@ ignore_missing_imports = True ignore_missing_imports = True [mypy-cftime.*] ignore_missing_imports = True +[mypy-cupy.*] +ignore_missing_imports = True [mypy-dask.*] ignore_missing_imports = True [mypy-distributed.*] @@ -195,4 +236,4 @@ ignore_errors = True test = pytest [pytest-watch] -nobeep = True \ No newline at end of file +nobeep = True diff --git a/setup.py b/setup.py index 76755a445f7..088d7e4eac6 100755 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ #!/usr/bin/env python from setuptools import setup -setup(use_scm_version=True) +setup(use_scm_version={"fallback_version": "999"}) diff --git a/xarray/__init__.py b/xarray/__init__.py index 0fead57e5fb..3886edc60e6 100644 --- a/xarray/__init__.py +++ b/xarray/__init__.py @@ -13,11 +13,12 @@ from .backends.zarr import open_zarr from .coding.cftime_offsets import cftime_range from .coding.cftimeindex import CFTimeIndex +from .coding.frequencies import infer_freq from .conventions import SerializationWarning, decode_cf from .core.alignment import align, broadcast -from .core.combine import auto_combine, combine_by_coords, combine_nested +from .core.combine import combine_by_coords, combine_nested from .core.common import ALL_DIMS, full_like, ones_like, zeros_like -from .core.computation import apply_ufunc, dot, polyval, where +from .core.computation import apply_ufunc, corr, cov, dot, polyval, where from .core.concat import concat from .core.dataarray import DataArray from .core.dataset import Dataset @@ -46,7 +47,6 @@ "align", "apply_ufunc", "as_variable", - "auto_combine", "broadcast", "cftime_range", "combine_by_coords", @@ -54,7 +54,10 @@ "concat", "decode_cf", "dot", + "cov", + "corr", "full_like", + "infer_freq", "load_dataarray", "load_dataset", "map_blocks", diff --git a/xarray/backends/__init__.py b/xarray/backends/__init__.py index 2a769b1335e..1500ea5061f 100644 --- a/xarray/backends/__init__.py +++ b/xarray/backends/__init__.py @@ -9,6 +9,7 @@ from .h5netcdf_ import H5NetCDFStore from .memory import InMemoryDataStore from .netCDF4_ import NetCDF4DataStore +from .plugins import list_engines from .pseudonetcdf_ import PseudoNetCDFDataStore from .pydap_ import PydapDataStore from .pynio_ import NioDataStore @@ -29,4 +30,5 @@ "H5NetCDFStore", "ZarrStore", "PseudoNetCDFDataStore", + "list_engines", ] diff --git a/xarray/backends/api.py b/xarray/backends/api.py index c7481e22b59..4958062a262 100644 --- a/xarray/backends/api.py +++ 
b/xarray/backends/api.py @@ -1,10 +1,8 @@ -import os.path -import warnings +import os from glob import glob from io import BytesIO from numbers import Number from pathlib import Path -from textwrap import dedent from typing import ( TYPE_CHECKING, Callable, @@ -12,6 +10,7 @@ Hashable, Iterable, Mapping, + MutableMapping, Tuple, Union, ) @@ -23,12 +22,11 @@ from ..core.combine import ( _infer_concat_order_from_positions, _nested_combine, - auto_combine, combine_by_coords, ) from ..core.dataarray import DataArray -from ..core.dataset import Dataset -from ..core.utils import close_on_error, is_grib_path, is_remote_uri +from ..core.dataset import Dataset, _get_chunk, _maybe_chunk +from ..core.utils import close_on_error, is_grib_path, is_remote_uri, read_magic_number from .common import AbstractDataStore, ArrayWriter from .locks import _get_scheduler @@ -42,6 +40,17 @@ DATAARRAY_NAME = "__xarray_dataarray_name__" DATAARRAY_VARIABLE = "__xarray_dataarray_variable__" +ENGINES = { + "netcdf4": backends.NetCDF4DataStore.open, + "scipy": backends.ScipyDataStore, + "pydap": backends.PydapDataStore.open, + "h5netcdf": backends.H5NetCDFStore.open, + "pynio": backends.NioDataStore, + "pseudonetcdf": backends.PseudoNetCDFDataStore.open, + "cfgrib": backends.CfGribDataStore, + "zarr": backends.ZarrStore.open_group, +} + def _get_default_engine_remote_uri(): try: @@ -78,7 +87,7 @@ def _get_default_engine_grib(): if msgs: raise ValueError(" or\n".join(msgs)) else: - raise ValueError("PyNIO or cfgrib is required for accessing " "GRIB files") + raise ValueError("PyNIO or cfgrib is required for accessing GRIB files") def _get_default_engine_gz(): @@ -110,39 +119,22 @@ def _get_default_engine_netcdf(): def _get_engine_from_magic_number(filename_or_obj): - # check byte header to determine file type - if isinstance(filename_or_obj, bytes): - magic_number = filename_or_obj[:8] - else: - if filename_or_obj.tell() != 0: - raise ValueError( - "file-like object read/write pointer not at zero " - "please close and reopen, or use a context " - "manager" - ) - magic_number = filename_or_obj.read(8) - filename_or_obj.seek(0) + magic_number = read_magic_number(filename_or_obj) if magic_number.startswith(b"CDF"): engine = "scipy" elif magic_number.startswith(b"\211HDF\r\n\032\n"): engine = "h5netcdf" - if isinstance(filename_or_obj, bytes): - raise ValueError( - "can't open netCDF4/HDF5 as bytes " - "try passing a path or file-like object" - ) else: - if isinstance(filename_or_obj, bytes) and len(filename_or_obj) > 80: - filename_or_obj = filename_or_obj[:80] + b"..." raise ValueError( - "{} is not a valid netCDF file " - "did you mean to pass a string for a path instead?".format(filename_or_obj) + "cannot guess the engine, " + f"{magic_number} is not the signature of any supported file format " + "did you mean to pass a string for a path instead?" 
) return engine -def _get_default_engine(path, allow_remote=False): +def _get_default_engine(path: str, allow_remote: bool = False): if allow_remote and is_remote_uri(path): engine = _get_default_engine_remote_uri() elif is_grib_path(path): @@ -154,11 +146,35 @@ def _get_default_engine(path, allow_remote=False): return engine -def _normalize_path(path): - if is_remote_uri(path): - return path +def _autodetect_engine(filename_or_obj): + if isinstance(filename_or_obj, AbstractDataStore): + engine = "store" + elif isinstance(filename_or_obj, (str, Path)): + engine = _get_default_engine(str(filename_or_obj), allow_remote=True) else: - return os.path.abspath(os.path.expanduser(path)) + engine = _get_engine_from_magic_number(filename_or_obj) + return engine + + +def _get_backend_cls(engine, engines=ENGINES): + """Select open_dataset method based on current engine""" + try: + return engines[engine] + except KeyError: + raise ValueError( + "unrecognized engine for open_dataset: {}\n" + "must be one of: {}".format(engine, list(ENGINES)) + ) + + +def _normalize_path(path): + if isinstance(path, Path): + path = str(path) + + if isinstance(path, str) and not is_remote_uri(path): + path = os.path.abspath(os.path.expanduser(path)) + + return path def _validate_dataset_names(dataset): @@ -168,14 +184,15 @@ def check_name(name): if isinstance(name, str): if not name: raise ValueError( - "Invalid name for DataArray or Dataset key: " + f"Invalid name {name!r} for DataArray or Dataset key: " "string must be length 1 or greater for " "serialization to netCDF files" ) elif name is not None: raise TypeError( - "DataArray.name or Dataset key must be either a " - "string or None for serialization to netCDF files" + f"Invalid name {name!r} for DataArray or Dataset key: " + "must be either a string or None for serialization to netCDF " + "files" ) for k in dataset.variables: @@ -191,22 +208,22 @@ def check_attr(name, value): if isinstance(name, str): if not name: raise ValueError( - "Invalid name for attr: string must be " + f"Invalid name for attr {name!r}: string must be " "length 1 or greater for serialization to " "netCDF files" ) else: raise TypeError( - "Invalid name for attr: {} must be a string for " - "serialization to netCDF files".format(name) + f"Invalid name for attr: {name!r} must be a string for " + "serialization to netCDF files" ) if not isinstance(value, (str, Number, np.ndarray, np.number, list, tuple)): raise TypeError( - "Invalid value for attr: {} must be a number, " + f"Invalid value for attr {name!r}: {value!r} must be a number, " "a string, an ndarray or a list/tuple of " "numbers/strings for serialization to netCDF " - "files".format(value) + "files" ) # Check attrs on the dataset itself @@ -293,7 +310,6 @@ def open_dataset( decode_cf=True, mask_and_scale=None, decode_times=True, - autoclose=None, concat_characters=True, decode_coords=True, engine=None, @@ -303,12 +319,13 @@ def open_dataset( drop_variables=None, backend_kwargs=None, use_cftime=None, + decode_timedelta=None, ): """Open and decode a dataset from a file or file-like object. 
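[Editor's note] The engine dispatch introduced above (the ``ENGINES`` mapping plus ``_get_backend_cls``) amounts to a dictionary lookup with a friendlier error. A minimal, self-contained sketch of that pattern, using hypothetical ``_demo_*`` names that are not part of the patch:

    # Sketch of the registry-lookup pattern used by _get_backend_cls above.
    # The _demo_* names are hypothetical stand-ins, not part of this patch.
    _DEMO_ENGINES = {
        "netcdf4": lambda path: f"would open {path} with netCDF4",
        "scipy": lambda path: f"would open {path} with scipy.io.netcdf",
    }


    def _demo_get_backend_cls(engine, engines=_DEMO_ENGINES):
        """Select an opener for the requested engine, failing with the valid options."""
        try:
            return engines[engine]
        except KeyError:
            raise ValueError(
                "unrecognized engine for open_dataset: {}\n"
                "must be one of: {}".format(engine, list(engines))
            )


    opener = _demo_get_backend_cls("scipy")
    print(opener("example.nc"))

Unknown engines then fail early with the list of supported names rather than with an opaque KeyError.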
Parameters ---------- - filename_or_obj : str, Path, file or xarray.backends.*DataStore + filename_or_obj : str, Path, file-like or DataStore Strings and Path objects are interpreted as a path to a netCDF file or an OpenDAP URL and opened with python-netCDF4, unless the filename ends with .gz, in which case the file is gunzipped and opened with @@ -332,10 +349,6 @@ def open_dataset( decode_times : bool, optional If True, decode times encoded in the standard NetCDF datetime format into datetime objects. Otherwise, leave them encoded as numbers. - autoclose : bool, optional - If True, automatically close files to avoid OS Error of too many files - being open. However, this option doesn't work with streams, e.g., - BytesIO. concat_characters : bool, optional If True, concatenate along the last dimension of character arrays to form string arrays. Dimensions will only be concatenated over (and @@ -344,16 +357,20 @@ def open_dataset( decode_coords : bool, optional If True, decode the 'coordinates' attribute to identify coordinates in the resulting dataset. - engine : {'netcdf4', 'scipy', 'pydap', 'h5netcdf', 'pynio', 'cfgrib', \ - 'pseudonetcdf'}, optional + engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "pynio", "cfgrib", \ + "pseudonetcdf", "zarr"}, optional Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for - 'netcdf4'. + "netcdf4". chunks : int or dict, optional - If chunks is provided, it used to load the new dataset into dask - arrays. ``chunks={}`` loads the dataset with dask using a single - chunk for all arrays. - lock : False or duck threading.Lock, optional + If chunks is provided, it is used to load the new dataset into dask + arrays. ``chunks=-1`` loads the dataset with dask using a single + chunk for all arrays. `chunks={}`` loads the dataset with dask using + engine preferred chunks if exposed by the backend, otherwise with + a single chunk for all arrays. + ``chunks='auto'`` will use dask ``auto`` chunking taking into account the + engine preferred chunks. See dask chunking for more details. + lock : False or lock-like, optional Resource lock to use when reading data from disk. Only relevant when using dask or another form of parallelism. By default, appropriate locks are chosen to safely read and write files with the currently @@ -365,17 +382,17 @@ def open_dataset( argument to use dask, in which case it defaults to False. Does not change the behavior of coordinates corresponding to dimensions, which always load their data from disk into a ``pandas.Index``. - drop_variables: string or iterable, optional + drop_variables: str or iterable, optional A variable or list of variables to exclude from being parsed from the dataset. This may be useful to drop variables with problems or inconsistent values. - backend_kwargs: dictionary, optional + backend_kwargs: dict, optional A dictionary of keyword arguments to pass on to the backend. This may be useful when backend options would improve performance or allow user control of dataset processing. use_cftime: bool, optional Only relevant if encoded dates come from a standard calendar - (e.g. 'gregorian', 'proleptic_gregorian', 'standard', or not + (e.g. "gregorian", "proleptic_gregorian", "standard", or not specified). If None (default), attempt to decode times to ``np.datetime64[ns]`` objects; if this is not possible, decode times to ``cftime.datetime`` objects. 
If True, always decode times to @@ -383,6 +400,11 @@ def open_dataset( represented using ``np.datetime64[ns]`` objects. If False, always decode times to ``np.datetime64[ns]`` objects; if this is not possible raise an error. + decode_timedelta : bool, optional + If True, decode variables and coordinates with time units in + {"days", "hours", "minutes", "seconds", "milliseconds", "microseconds"} + into timedelta objects. If False, leave them encoded as numbers. + If None (default), assume the same value of decode_time. Returns ------- @@ -400,32 +422,11 @@ def open_dataset( -------- open_mfdataset """ - engines = [ - None, - "netcdf4", - "scipy", - "pydap", - "h5netcdf", - "pynio", - "cfgrib", - "pseudonetcdf", - ] - if engine not in engines: - raise ValueError( - "unrecognized engine for open_dataset: {}\n" - "must be one of: {}".format(engine, engines) - ) + if os.environ.get("XARRAY_BACKEND_API", "v1") == "v2": + kwargs = {k: v for k, v in locals().items() if v is not None} + from . import apiv2 - if autoclose is not None: - warnings.warn( - "The autoclose argument is no longer used by " - "xarray.open_dataset() and is now ignored; it will be removed in " - "a future version of xarray. If necessary, you can control the " - "maximum number of simultaneous open files with " - "xarray.set_options(file_cache_maxsize=...).", - FutureWarning, - stacklevel=2, - ) + return apiv2.open_dataset(**kwargs) if mask_and_scale is None: mask_and_scale = not engine == "pseudonetcdf" @@ -435,6 +436,7 @@ def open_dataset( decode_times = False concat_characters = False decode_coords = False + decode_timedelta = False if cache is None: cache = chunks is None @@ -442,7 +444,7 @@ def open_dataset( if backend_kwargs is None: backend_kwargs = {} - def maybe_decode_store(store, lock=False): + def maybe_decode_store(store, chunks): ds = conventions.decode_cf( store, mask_and_scale=mask_and_scale, @@ -451,11 +453,12 @@ def maybe_decode_store(store, lock=False): decode_coords=decode_coords, drop_variables=drop_variables, use_cftime=use_cftime, + decode_timedelta=decode_timedelta, ) _protect_dataset_variables_inplace(ds, cache) - if chunks is not None: + if chunks is not None and engine != "zarr": from dask.base import tokenize # if passed an actual file path, augment the token with @@ -477,65 +480,76 @@ def maybe_decode_store(store, lock=False): chunks, drop_variables, use_cftime, + decode_timedelta, ) name_prefix = "open_dataset-%s" % token ds2 = ds.chunk(chunks, name_prefix=name_prefix, token=token) - ds2._file_obj = ds._file_obj + + elif engine == "zarr": + # adapted from Dataset.Chunk() and taken from open_zarr + if not (isinstance(chunks, (int, dict)) or chunks is None): + if chunks != "auto": + raise ValueError( + "chunks must be an int, dict, 'auto', or None. " + "Instead found %s. 
" % chunks + ) + + if chunks == "auto": + try: + import dask.array # noqa + except ImportError: + chunks = None + + # auto chunking needs to be here and not in ZarrStore because + # the variable chunks does not survive decode_cf + # return trivial case + if chunks is None: + return ds + + if isinstance(chunks, int): + chunks = dict.fromkeys(ds.dims, chunks) + + variables = { + k: _maybe_chunk( + k, + v, + _get_chunk(v, chunks), + overwrite_encoded_chunks=overwrite_encoded_chunks, + ) + for k, v in ds.variables.items() + } + ds2 = ds._replace(variables) + else: ds2 = ds - + ds2._file_obj = ds._file_obj return ds2 - if isinstance(filename_or_obj, Path): - filename_or_obj = str(filename_or_obj) + filename_or_obj = _normalize_path(filename_or_obj) if isinstance(filename_or_obj, AbstractDataStore): store = filename_or_obj - - elif isinstance(filename_or_obj, str): - filename_or_obj = _normalize_path(filename_or_obj) - + else: if engine is None: - engine = _get_default_engine(filename_or_obj, allow_remote=True) - if engine == "netcdf4": - store = backends.NetCDF4DataStore.open( - filename_or_obj, group=group, lock=lock, **backend_kwargs - ) - elif engine == "scipy": - store = backends.ScipyDataStore(filename_or_obj, **backend_kwargs) - elif engine == "pydap": - store = backends.PydapDataStore.open(filename_or_obj, **backend_kwargs) - elif engine == "h5netcdf": - store = backends.H5NetCDFStore.open( - filename_or_obj, group=group, lock=lock, **backend_kwargs - ) - elif engine == "pynio": - store = backends.NioDataStore(filename_or_obj, lock=lock, **backend_kwargs) - elif engine == "pseudonetcdf": - store = backends.PseudoNetCDFDataStore.open( - filename_or_obj, lock=lock, **backend_kwargs - ) - elif engine == "cfgrib": - store = backends.CfGribDataStore( - filename_or_obj, lock=lock, **backend_kwargs + engine = _autodetect_engine(filename_or_obj) + + extra_kwargs = {} + if group is not None: + extra_kwargs["group"] = group + if lock is not None: + extra_kwargs["lock"] = lock + + if engine == "zarr": + backend_kwargs = backend_kwargs.copy() + overwrite_encoded_chunks = backend_kwargs.pop( + "overwrite_encoded_chunks", None ) - else: - if engine not in [None, "scipy", "h5netcdf"]: - raise ValueError( - "can only read bytes or file-like objects " - "with engine='scipy' or 'h5netcdf'" - ) - engine = _get_engine_from_magic_number(filename_or_obj) - if engine == "scipy": - store = backends.ScipyDataStore(filename_or_obj, **backend_kwargs) - elif engine == "h5netcdf": - store = backends.H5NetCDFStore.open( - filename_or_obj, group=group, lock=lock, **backend_kwargs - ) + opener = _get_backend_cls(engine) + store = opener(filename_or_obj, **extra_kwargs, **backend_kwargs) with close_on_error(store): - ds = maybe_decode_store(store) + ds = maybe_decode_store(store, chunks) # Ensure source filename always stored in dataset object (GH issue #2550) if "source" not in ds.encoding: @@ -551,7 +565,6 @@ def open_dataarray( decode_cf=True, mask_and_scale=None, decode_times=True, - autoclose=None, concat_characters=True, decode_coords=True, engine=None, @@ -561,6 +574,7 @@ def open_dataarray( drop_variables=None, backend_kwargs=None, use_cftime=None, + decode_timedelta=None, ): """Open an DataArray from a file or file-like object containing a single data variable. 
@@ -570,7 +584,7 @@ def open_dataarray( Parameters ---------- - filename_or_obj : str, Path, file or xarray.backends.*DataStore + filename_or_obj : str, Path, file-like or DataStore Strings and Paths are interpreted as a path to a netCDF file or an OpenDAP URL and opened with python-netCDF4, unless the filename ends with .gz, in which case the file is gunzipped and opened with @@ -602,15 +616,15 @@ def open_dataarray( decode_coords : bool, optional If True, decode the 'coordinates' attribute to identify coordinates in the resulting dataset. - engine : {'netcdf4', 'scipy', 'pydap', 'h5netcdf', 'pynio', 'cfgrib'}, \ + engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "pynio", "cfgrib"}, \ optional Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for - 'netcdf4'. + "netcdf4". chunks : int or dict, optional If chunks is provided, it used to load the new dataset into dask arrays. - lock : False or duck threading.Lock, optional + lock : False or lock-like, optional Resource lock to use when reading data from disk. Only relevant when using dask or another form of parallelism. By default, appropriate locks are chosen to safely read and write files with the currently @@ -622,17 +636,17 @@ def open_dataarray( argument to use dask, in which case it defaults to False. Does not change the behavior of coordinates corresponding to dimensions, which always load their data from disk into a ``pandas.Index``. - drop_variables: string or iterable, optional + drop_variables: str or iterable, optional A variable or list of variables to exclude from being parsed from the dataset. This may be useful to drop variables with problems or inconsistent values. - backend_kwargs: dictionary, optional + backend_kwargs: dict, optional A dictionary of keyword arguments to pass on to the backend. This may be useful when backend options would improve performance or allow user control of dataset processing. use_cftime: bool, optional Only relevant if encoded dates come from a standard calendar - (e.g. 'gregorian', 'proleptic_gregorian', 'standard', or not + (e.g. "gregorian", "proleptic_gregorian", "standard", or not specified). If None (default), attempt to decode times to ``np.datetime64[ns]`` objects; if this is not possible, decode times to ``cftime.datetime`` objects. If True, always decode times to @@ -640,6 +654,11 @@ def open_dataarray( represented using ``np.datetime64[ns]`` objects. If False, always decode times to ``np.datetime64[ns]`` objects; if this is not possible raise an error. + decode_timedelta : bool, optional + If True, decode variables and coordinates with time units in + {"days", "hours", "minutes", "seconds", "milliseconds", "microseconds"} + into timedelta objects. If False, leave them encoded as numbers. + If None (default), assume the same value of decode_time. 
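[Editor's note] A short usage sketch of the keywords documented above (the extended ``chunks`` options and the new ``decode_timedelta``), assuming a netCDF backend and dask are installed; ``"demo.nc"`` is an arbitrary scratch file, not taken from the patch:

    import numpy as np
    import xarray as xr

    # write a small file with one ordinary variable and one with time-like units
    xr.Dataset(
        {"a": ("x", np.arange(10)), "lead": ("t", [1, 2, 3], {"units": "days"})}
    ).to_netcdf("demo.nc")

    single = xr.open_dataset("demo.nc", chunks=-1)      # one dask chunk per array
    preferred = xr.open_dataset("demo.nc", chunks={})   # backend-preferred chunks, if any
    auto = xr.open_dataset("demo.nc", chunks="auto")    # dask "auto" chunking
    raw = xr.open_dataset("demo.nc", decode_timedelta=False)  # keep "days" values numeric

    print(single["a"].chunks, raw["lead"].dtype)

With the defaults, the ``"days"`` variable would instead be decoded to ``timedelta64[ns]`` values.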
Notes ----- @@ -661,7 +680,6 @@ def open_dataarray( decode_cf=decode_cf, mask_and_scale=mask_and_scale, decode_times=decode_times, - autoclose=autoclose, concat_characters=concat_characters, decode_coords=decode_coords, engine=engine, @@ -671,6 +689,7 @@ def open_dataarray( drop_variables=drop_variables, backend_kwargs=backend_kwargs, use_cftime=use_cftime, + decode_timedelta=decode_timedelta, ) if len(dataset.data_vars) != 1: @@ -710,15 +729,14 @@ def close(self): def open_mfdataset( paths, chunks=None, - concat_dim="_not_supplied", + concat_dim=None, compat="no_conflicts", preprocess=None, engine=None, lock=None, data_vars="all", coords="different", - combine="_old_auto", - autoclose=None, + combine="by_coords", parallel=False, join="outer", attrs_file=None, @@ -730,9 +748,8 @@ def open_mfdataset( the datasets into one before returning the result, and if combine='nested' then ``combine_nested`` is used. The filepaths must be structured according to which combining function is used, the details of which are given in the documentation for - ``combine_by_coords`` and ``combine_nested``. By default the old (now deprecated) - ``auto_combine`` will be used, please specify either ``combine='by_coords'`` or - ``combine='nested'`` in future. Requires dask to be installed. See documentation for + ``combine_by_coords`` and ``combine_nested``. By default ``combine='by_coords'`` + will be used. Requires dask to be installed. See documentation for details on dask [1]_. Global attributes from the ``attrs_file`` are used for the combined dataset. @@ -742,7 +759,7 @@ def open_mfdataset( Either a string glob in the form ``"path/to/my/files/*.nc"`` or an explicit list of files to open. Paths can be given as strings or as pathlib Paths. If concatenation along more than one dimension is desired, then ``paths`` must be a - nested list-of-lists (see ``manual_combine`` for details). (A string glob will + nested list-of-lists (see ``combine_nested`` for details). (A string glob will be expanded to a 1-dimensional list.) chunks : int or dict, optional Dictionary with keys given by dimension names and values given by chunk sizes. @@ -752,83 +769,84 @@ def open_mfdataset( see the full documentation for more details [2]_. concat_dim : str, or list of str, DataArray, Index or None, optional Dimensions to concatenate files along. You only need to provide this argument - if any of the dimensions along which you want to concatenate is not a dimension - in the original datasets, e.g., if you want to stack a collection of 2D arrays - along a third dimension. Set ``concat_dim=[..., None, ...]`` explicitly to - disable concatenation along a particular dimension. - combine : {'by_coords', 'nested'}, optional + if ``combine='by_coords'``, and if any of the dimensions along which you want to + concatenate is not a dimension in the original datasets, e.g., if you want to + stack a collection of 2D arrays along a third dimension. Set + ``concat_dim=[..., None, ...]`` explicitly to disable concatenation along a + particular dimension. Default is None, which for a 1D list of filepaths is + equivalent to opening the files separately and then merging them with + ``xarray.merge``. + combine : {"by_coords", "nested"}, optional Whether ``xarray.combine_by_coords`` or ``xarray.combine_nested`` is used to - combine all the data. If this argument is not provided, `xarray.auto_combine` is - used, but in the future this behavior will switch to use - `xarray.combine_by_coords` by default. 
- compat : {'identical', 'equals', 'broadcast_equals', - 'no_conflicts', 'override'}, optional + combine all the data. Default is to use ``xarray.combine_by_coords``. + compat : {"identical", "equals", "broadcast_equals", \ + "no_conflicts", "override"}, optional String indicating how to compare variables of the same name for potential conflicts when merging: - * 'broadcast_equals': all values must be equal when variables are + * "broadcast_equals": all values must be equal when variables are broadcast against each other to ensure common dimensions. - * 'equals': all values and dimensions must be the same. - * 'identical': all values, dimensions and attributes must be the + * "equals": all values and dimensions must be the same. + * "identical": all values, dimensions and attributes must be the same. - * 'no_conflicts': only values which are not null in both datasets + * "no_conflicts": only values which are not null in both datasets must be equal. The returned dataset then contains the combination of all non-null values. - * 'override': skip comparing and pick variable from first dataset + * "override": skip comparing and pick variable from first dataset preprocess : callable, optional If provided, call this function on each dataset prior to concatenation. You can find the file-name from which each dataset was loaded in - ``ds.encoding['source']``. - engine : {'netcdf4', 'scipy', 'pydap', 'h5netcdf', 'pynio', 'cfgrib'}, \ + ``ds.encoding["source"]``. + engine : {"netcdf4", "scipy", "pydap", "h5netcdf", "pynio", "cfgrib", "zarr"}, \ optional Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for - 'netcdf4'. - lock : False or duck threading.Lock, optional + "netcdf4". + lock : False or lock-like, optional Resource lock to use when reading data from disk. Only relevant when using dask or another form of parallelism. By default, appropriate locks are chosen to safely read and write files with the currently active dask scheduler. - data_vars : {'minimal', 'different', 'all' or list of str}, optional + data_vars : {"minimal", "different", "all"} or list of str, optional These data variables will be concatenated together: - * 'minimal': Only data variables in which the dimension already + * "minimal": Only data variables in which the dimension already appears are included. - * 'different': Data variables which are not equal (ignoring + * "different": Data variables which are not equal (ignoring attributes) across all datasets are also concatenated (as well as all for which dimension already appears). Beware: this option may load the data payload of data variables into memory if they are not already loaded. - * 'all': All data variables will be concatenated. + * "all": All data variables will be concatenated. * list of str: The listed data variables will be concatenated, in - addition to the 'minimal' data variables. - coords : {'minimal', 'different', 'all' or list of str}, optional + addition to the "minimal" data variables. + coords : {"minimal", "different", "all"} or list of str, optional These coordinate variables will be concatenated together: - * 'minimal': Only coordinates in which the dimension already appears + * "minimal": Only coordinates in which the dimension already appears are included. 
- * 'different': Coordinates which are not equal (ignoring attributes) + * "different": Coordinates which are not equal (ignoring attributes) across all datasets are also concatenated (as well as all for which dimension already appears). Beware: this option may load the data payload of coordinate variables into memory if they are not already loaded. - * 'all': All coordinate variables will be concatenated, except + * "all": All coordinate variables will be concatenated, except those corresponding to other dimensions. * list of str: The listed coordinate variables will be concatenated, - in addition the 'minimal' coordinates. + in addition the "minimal" coordinates. parallel : bool, optional If True, the open and preprocess steps of this function will be performed in parallel using ``dask.delayed``. Default is False. - join : {'outer', 'inner', 'left', 'right', 'exact, 'override'}, optional + join : {"outer", "inner", "left", "right", "exact, "override"}, optional String indicating how to combine differing indexes (excluding concat_dim) in objects - - 'outer': use the union of object indexes - - 'inner': use the intersection of object indexes - - 'left': use indexes from the first object with each dimension - - 'right': use indexes from the last object with each dimension - - 'exact': instead of aligning, raise `ValueError` when indexes to be + - "outer": use the union of object indexes + - "inner": use the intersection of object indexes + - "left": use indexes from the first object with each dimension + - "right": use indexes from the last object with each dimension + - "exact": instead of aligning, raise `ValueError` when indexes to be aligned are not equal - - 'override': if indexes are of same size, rewrite indexes to be + - "override": if indexes are of same size, rewrite indexes to be those of the first object with that dimension. Indexes for the same dimension must have the same size in all objects. attrs_file : str or pathlib.Path, optional @@ -853,7 +871,6 @@ def open_mfdataset( -------- combine_by_coords combine_nested - auto_combine open_dataset References @@ -870,7 +887,7 @@ def open_mfdataset( paths ) ) - paths = sorted(glob(paths)) + paths = sorted(glob(_normalize_path(paths))) else: paths = [str(p) if isinstance(p, Path) else p for p in paths] @@ -881,17 +898,12 @@ def open_mfdataset( # If combine='nested' then this creates a flat list which is easier to # iterate over, while saving the originally-supplied structure as "ids" if combine == "nested": - if str(concat_dim) == "_not_supplied": - raise ValueError("Must supply concat_dim when using " "combine='nested'") - else: - if isinstance(concat_dim, (str, DataArray)) or concat_dim is None: - concat_dim = [concat_dim] + if isinstance(concat_dim, (str, DataArray)) or concat_dim is None: + concat_dim = [concat_dim] combined_ids_paths = _infer_concat_order_from_positions(paths) ids, paths = (list(combined_ids_paths.keys()), list(combined_ids_paths.values())) - open_kwargs = dict( - engine=engine, chunks=chunks or {}, lock=lock, autoclose=autoclose, **kwargs - ) + open_kwargs = dict(engine=engine, chunks=chunks or {}, lock=lock, **kwargs) if parallel: import dask @@ -917,30 +929,7 @@ def open_mfdataset( # Combine all datasets, closing them in case of a ValueError try: - if combine == "_old_auto": - # Use the old auto_combine for now - # Remove this after deprecation cycle from #2616 is complete - basic_msg = dedent( - """\ - In xarray version 0.15 the default behaviour of `open_mfdataset` - will change. 
To retain the existing behavior, pass - combine='nested'. To use future default behavior, pass - combine='by_coords'. See - http://xarray.pydata.org/en/stable/combining.html#combining-multi - """ - ) - warnings.warn(basic_msg, FutureWarning, stacklevel=2) - - combined = auto_combine( - datasets, - concat_dim=concat_dim, - compat=compat, - data_vars=data_vars, - coords=coords, - join=join, - from_openmfds=True, - ) - elif combine == "nested": + if combine == "nested": # Combined nested list by successive concat and merge operations # along each dimension, using structure given by "ids" combined = _nested_combine( @@ -951,12 +940,18 @@ def open_mfdataset( coords=coords, ids=ids, join=join, + combine_attrs="drop", ) elif combine == "by_coords": # Redo ordering from coordinates, ignoring how they were ordered # previously combined = combine_by_coords( - datasets, compat=compat, data_vars=data_vars, coords=coords, join=join + datasets, + compat=compat, + data_vars=data_vars, + coords=coords, + join=join, + combine_attrs="drop", ) else: raise ValueError( @@ -1149,15 +1144,15 @@ def save_mfdataset( Parameters ---------- - datasets : list of xarray.Dataset + datasets : list of Dataset List of datasets to save. - paths : list of str or list of Paths + paths : list of str or list of Path List of paths to which to save each corresponding dataset. - mode : {'w', 'a'}, optional - Write ('w') or append ('a') mode. If mode='w', any existing file at + mode : {"w", "a"}, optional + Write ("w") or append ("a") mode. If mode="w", any existing file at these locations will be overwritten. - format : {'NETCDF4', 'NETCDF4_CLASSIC', 'NETCDF3_64BIT', - 'NETCDF3_CLASSIC'}, optional + format : {"NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", \ + "NETCDF3_CLASSIC"}, optional File format for the resulting netCDF file: @@ -1180,14 +1175,14 @@ def save_mfdataset( NETCDF3_64BIT format (scipy does not support netCDF4). groups : list of str, optional Paths to the netCDF4 group in each corresponding file to which to save - datasets (only works for format='NETCDF4'). The groups will be created + datasets (only works for format="NETCDF4"). The groups will be created if necessary. - engine : {'netcdf4', 'scipy', 'h5netcdf'}, optional + engine : {"netcdf4", "scipy", "h5netcdf"}, optional Engine to use when writing netCDF files. If not provided, the default engine is chosen based on available dependencies, with a - preference for 'netcdf4' if writing to a file on disk. + preference for "netcdf4" if writing to a file on disk. See `Dataset.to_netcdf` for additional information. - compute: boolean + compute : bool If true compute immediately, otherwise return a ``dask.delayed.Delayed`` object that can be computed later. @@ -1196,13 +1191,24 @@ def save_mfdataset( Save a dataset into one netCDF per year of data: + >>> ds = xr.Dataset( + ... {"a": ("time", np.linspace(0, 1, 48))}, + ... coords={"time": pd.date_range("2010-01-01", freq="M", periods=48)}, + ... ) + >>> ds + + Dimensions: (time: 48) + Coordinates: + * time (time) datetime64[ns] 2010-01-31 2010-02-28 ... 2013-12-31 + Data variables: + a (time) float64 0.0 0.02128 0.04255 0.06383 ... 
0.9574 0.9787 1.0 >>> years, datasets = zip(*ds.groupby("time.year")) >>> paths = ["%s.nc" % y for y in years] >>> xr.save_mfdataset(datasets, paths) """ if mode == "w" and len(set(paths)) < len(paths): raise ValueError( - "cannot use mode='w' when writing multiple " "datasets to the same path" + "cannot use mode='w' when writing multiple datasets to the same path" ) for obj in datasets: @@ -1271,50 +1277,149 @@ def check_dtype(var): def _validate_append_dim_and_encoding( - ds_to_append, store, append_dim, encoding, **open_kwargs + ds_to_append, store, append_dim, region, encoding, **open_kwargs ): try: ds = backends.zarr.open_zarr(store, **open_kwargs) except ValueError: # store empty return + if append_dim: if append_dim not in ds.dims: - raise ValueError(f"{append_dim} not a valid dimension in the Dataset") - for data_var in ds_to_append: - if data_var in ds: - if append_dim is None: + raise ValueError( + f"append_dim={append_dim!r} does not match any existing " + f"dataset dimensions {ds.dims}" + ) + if region is not None and append_dim in region: + raise ValueError( + f"cannot list the same dimension in both ``append_dim`` and " + f"``region`` with to_zarr(), got {append_dim} in both" + ) + + if region is not None: + if not isinstance(region, dict): + raise TypeError(f"``region`` must be a dict, got {type(region)}") + for k, v in region.items(): + if k not in ds_to_append.dims: + raise ValueError( + f"all keys in ``region`` are not in Dataset dimensions, got " + f"{list(region)} and {list(ds_to_append.dims)}" + ) + if not isinstance(v, slice): + raise TypeError( + "all values in ``region`` must be slice objects, got " + f"region={region}" + ) + if v.step not in {1, None}: + raise ValueError( + "step on all slices in ``region`` must be 1 or None, got " + f"region={region}" + ) + + non_matching_vars = [ + k + for k, v in ds_to_append.variables.items() + if not set(region).intersection(v.dims) + ] + if non_matching_vars: + raise ValueError( + f"when setting `region` explicitly in to_zarr(), all " + f"variables in the dataset to write must have at least " + f"one dimension in common with the region's dimensions " + f"{list(region.keys())}, but that is not " + f"the case for some variables here. To drop these variables " + f"from this dataset before exporting to zarr, write: " + f".drop({non_matching_vars!r})" + ) + + for var_name, new_var in ds_to_append.variables.items(): + if var_name in ds.variables: + existing_var = ds.variables[var_name] + if new_var.dims != existing_var.dims: + raise ValueError( + f"variable {var_name!r} already exists with different " + f"dimension names {existing_var.dims} != " + f"{new_var.dims}, but changing variable " + f"dimensions is not supported by to_zarr()." + ) + + existing_sizes = {} + for dim, size in existing_var.sizes.items(): + if region is not None and dim in region: + start, stop, stride = region[dim].indices(size) + assert stride == 1 # region was already validated above + size = stop - start + if dim != append_dim: + existing_sizes[dim] = size + + new_sizes = { + dim: size for dim, size in new_var.sizes.items() if dim != append_dim + } + if existing_sizes != new_sizes: raise ValueError( - "variable '{}' already exists, but append_dim " - "was not set".format(data_var) + f"variable {var_name!r} already exists with different " + f"dimension sizes: {existing_sizes} != {new_sizes}. " + f"to_zarr() only supports changing dimension sizes when " + f"explicitly appending, but append_dim={append_dim!r}." 
) - if data_var in encoding.keys(): + if var_name in encoding.keys(): raise ValueError( - "variable '{}' already exists, but encoding was" - "provided".format(data_var) + f"variable {var_name!r} already exists, but encoding was provided" ) def to_zarr( - dataset, - store=None, - mode=None, + dataset: Dataset, + store: Union[MutableMapping, str, Path] = None, + chunk_store=None, + mode: str = None, synchronizer=None, - group=None, - encoding=None, - compute=True, - consolidated=False, - append_dim=None, + group: str = None, + encoding: Mapping = None, + compute: bool = True, + consolidated: bool = False, + append_dim: Hashable = None, + region: Mapping[str, slice] = None, ): """This function creates an appropriate datastore for writing a dataset to a zarr ztore See `Dataset.to_zarr` for full API docs. """ - if isinstance(store, Path): - store = str(store) + + # expand str and Path arguments + store = _normalize_path(store) + chunk_store = _normalize_path(chunk_store) + if encoding is None: encoding = {} + if mode is None: + if append_dim is not None or region is not None: + mode = "a" + else: + mode = "w-" + + if mode != "a" and append_dim is not None: + raise ValueError("cannot set append_dim unless mode='a' or mode=None") + + if mode != "a" and region is not None: + raise ValueError("cannot set region unless mode='a' or mode=None") + + if mode not in ["w", "w-", "a"]: + # TODO: figure out how to handle 'r+' + raise ValueError( + "The only supported options for mode are 'w', " + f"'w-' and 'a', but mode={mode!r}" + ) + + if consolidated and region is not None: + raise ValueError( + "cannot use consolidated=True when the region argument is set. " + "Instead, set consolidated=True when writing to zarr with " + "compute=False before writing data." + ) + # validate Dataset keys, DataArray names, and attr keys/values _validate_dataset_names(dataset) _validate_attrs(dataset) @@ -1327,6 +1432,7 @@ def to_zarr( append_dim, group=group, consolidated=consolidated, + region=region, encoding=encoding, ) @@ -1336,8 +1442,10 @@ def to_zarr( synchronizer=synchronizer, group=group, consolidate_on_close=consolidated, + chunk_store=chunk_store, + append_dim=append_dim, + write_region=region, ) - zstore.append_dim = append_dim writer = ArrayWriter() # TODO: figure out how to properly handle unlimited_dims dump_to_store(dataset, zstore, writer, encoding=encoding) diff --git a/xarray/backends/apiv2.py b/xarray/backends/apiv2.py new file mode 100644 index 00000000000..0f98291983d --- /dev/null +++ b/xarray/backends/apiv2.py @@ -0,0 +1,282 @@ +import os + +from ..core import indexing +from ..core.dataset import _get_chunk, _maybe_chunk +from ..core.utils import is_remote_uri +from . 
import plugins + + +def _protect_dataset_variables_inplace(dataset, cache): + for name, variable in dataset.variables.items(): + if name not in variable.dims: + # no need to protect IndexVariable objects + data = indexing.CopyOnWriteArray(variable._data) + if cache: + data = indexing.MemoryCachedArray(data) + variable.data = data + + +def _get_mtime(filename_or_obj): + # if passed an actual file path, augment the token with + # the file modification time + mtime = None + + try: + path = os.fspath(filename_or_obj) + except TypeError: + path = None + + if path and not is_remote_uri(path): + mtime = os.path.getmtime(filename_or_obj) + + return mtime + + +def _chunk_ds( + backend_ds, + filename_or_obj, + engine, + chunks, + overwrite_encoded_chunks, + **extra_tokens, +): + from dask.base import tokenize + + mtime = _get_mtime(filename_or_obj) + token = tokenize(filename_or_obj, mtime, engine, chunks, **extra_tokens) + name_prefix = "open_dataset-%s" % token + + variables = {} + for name, var in backend_ds.variables.items(): + var_chunks = _get_chunk(var, chunks) + variables[name] = _maybe_chunk( + name, + var, + var_chunks, + overwrite_encoded_chunks=overwrite_encoded_chunks, + name_prefix=name_prefix, + token=token, + ) + ds = backend_ds._replace(variables) + return ds + + +def _dataset_from_backend_dataset( + backend_ds, + filename_or_obj, + engine, + chunks, + cache, + overwrite_encoded_chunks, + **extra_tokens, +): + if not (isinstance(chunks, (int, dict)) or chunks is None): + if chunks != "auto": + raise ValueError( + "chunks must be an int, dict, 'auto', or None. " + "Instead found %s. " % chunks + ) + + _protect_dataset_variables_inplace(backend_ds, cache) + if chunks is None: + ds = backend_ds + else: + ds = _chunk_ds( + backend_ds, + filename_or_obj, + engine, + chunks, + overwrite_encoded_chunks, + **extra_tokens, + ) + + ds._file_obj = backend_ds._file_obj + + # Ensure source filename always stored in dataset object (GH issue #2550) + if "source" not in ds.encoding: + if isinstance(filename_or_obj, str): + ds.encoding["source"] = filename_or_obj + + return ds + + +def _resolve_decoders_kwargs(decode_cf, open_backend_dataset_parameters, **decoders): + for d in list(decoders): + if decode_cf is False and d in open_backend_dataset_parameters: + decoders[d] = False + if decoders[d] is None: + decoders.pop(d) + return decoders + + +def open_dataset( + filename_or_obj, + *, + engine=None, + chunks=None, + cache=None, + decode_cf=None, + mask_and_scale=None, + decode_times=None, + decode_timedelta=None, + use_cftime=None, + concat_characters=None, + decode_coords=None, + drop_variables=None, + backend_kwargs=None, + **kwargs, +): + """Open and decode a dataset from a file or file-like object. + + Parameters + ---------- + filename_or_obj : str, Path, file-like or DataStore + Strings and Path objects are interpreted as a path to a netCDF file + or an OpenDAP URL and opened with python-netCDF4, unless the filename + ends with .gz, in which case the file is unzipped and opened with + scipy.io.netcdf (only netCDF3 supported). Byte-strings or file-like + objects are opened by scipy.io.netcdf (netCDF3) or h5py (netCDF4/HDF). + engine : str, optional + Engine to use when reading files. If not provided, the default engine + is chosen based on available dependencies, with a preference for + "netcdf4". Options are: {"netcdf4", "scipy", "pydap", "h5netcdf",\ + "pynio", "cfgrib", "pseudonetcdf", "zarr"}. 
+ chunks : int or dict, optional + If chunks is provided, it is used to load the new dataset into dask + arrays. ``chunks=-1`` loads the dataset with dask using a single + chunk for all arrays. `chunks={}`` loads the dataset with dask using + engine preferred chunks if exposed by the backend, otherwise with + a single chunk for all arrays. + ``chunks='auto'`` will use dask ``auto`` chunking taking into account the + engine preferred chunks. See dask chunking for more details. + cache : bool, optional + If True, cache data is loaded from the underlying datastore in memory as + NumPy arrays when accessed to avoid reading from the underlying data- + store multiple times. Defaults to True unless you specify the `chunks` + argument to use dask, in which case it defaults to False. Does not + change the behavior of coordinates corresponding to dimensions, which + always load their data from disk into a ``pandas.Index``. + decode_cf : bool, optional + Setting ``decode_cf=False`` will disable ``mask_and_scale``, + ``decode_times``, ``decode_timedelta``, ``concat_characters``, + ``decode_coords``. + mask_and_scale : bool, optional + If True, array values equal to `_FillValue` are replaced with NA and other + values are scaled according to the formula `original_values * scale_factor + + add_offset`, where `_FillValue`, `scale_factor` and `add_offset` are + taken from variable attributes (if they exist). If the `_FillValue` or + `missing_value` attribute contains multiple values, a warning will be + issued and all array values matching one of the multiple values will + be replaced by NA. mask_and_scale defaults to True except for the + pseudonetcdf backend. This keyword may not be supported by all the backends. + decode_times : bool, optional + If True, decode times encoded in the standard NetCDF datetime format + into datetime objects. Otherwise, leave them encoded as numbers. + This keyword may not be supported by all the backends. + decode_timedelta : bool, optional + If True, decode variables and coordinates with time units in + {"days", "hours", "minutes", "seconds", "milliseconds", "microseconds"} + into timedelta objects. If False, they remain encoded as numbers. + If None (default), assume the same value of decode_time. + This keyword may not be supported by all the backends. + use_cftime: bool, optional + Only relevant if encoded dates come from a standard calendar + (e.g. "gregorian", "proleptic_gregorian", "standard", or not + specified). If None (default), attempt to decode times to + ``np.datetime64[ns]`` objects; if this is not possible, decode times to + ``cftime.datetime`` objects. If True, always decode times to + ``cftime.datetime`` objects, regardless of whether or not they can be + represented using ``np.datetime64[ns]`` objects. If False, always + decode times to ``np.datetime64[ns]`` objects; if this is not possible + raise an error. This keyword may not be supported by all the backends. + concat_characters : bool, optional + If True, concatenate along the last dimension of character arrays to + form string arrays. Dimensions will only be concatenated over (and + removed) if they have no corresponding variable and if they are only + used as the last dimension of character arrays. + This keyword may not be supported by all the backends. + decode_coords : bool, optional + If True, decode the 'coordinates' attribute to identify coordinates in + the resulting dataset. This keyword may not be supported by all the + backends. 
+ drop_variables: str or iterable, optional + A variable or list of variables to exclude from the dataset parsing. + This may be useful to drop variables with problems or + inconsistent values. + backend_kwargs: + Additional keyword arguments passed on to the engine open function. + **kwargs: dict + Additional keyword arguments passed on to the engine open function. + For example: + + - 'group': path to the netCDF4 group in the given file to open given as + a str,supported by "netcdf4", "h5netcdf", "zarr". + + - 'lock': resource lock to use when reading data from disk. Only + relevant when using dask or another form of parallelism. By default, + appropriate locks are chosen to safely read and write files with the + currently active dask scheduler. Supported by "netcdf4", "h5netcdf", + "pynio", "pseudonetcdf", "cfgrib". + + See engine open function for kwargs accepted by each specific engine. + + + Returns + ------- + dataset : Dataset + The newly created dataset. + + Notes + ----- + ``open_dataset`` opens the file with read-only access. When you modify + values of a Dataset, even one linked to files on disk, only the in-memory + copy you are manipulating in xarray is modified: the original file on disk + is never touched. + + See Also + -------- + open_mfdataset + """ + + if cache is None: + cache = chunks is None + + if backend_kwargs is not None: + kwargs.update(backend_kwargs) + + if engine is None: + engine = plugins.guess_engine(filename_or_obj) + + backend = plugins.get_backend(engine) + + decoders = _resolve_decoders_kwargs( + decode_cf, + open_backend_dataset_parameters=backend.open_dataset_parameters, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + decode_timedelta=decode_timedelta, + concat_characters=concat_characters, + use_cftime=use_cftime, + decode_coords=decode_coords, + ) + + overwrite_encoded_chunks = kwargs.pop("overwrite_encoded_chunks", None) + backend_ds = backend.open_dataset( + filename_or_obj, + drop_variables=drop_variables, + **decoders, + **kwargs, + ) + ds = _dataset_from_backend_dataset( + backend_ds, + filename_or_obj, + engine, + chunks, + cache, + overwrite_encoded_chunks, + drop_variables=drop_variables, + **decoders, + **kwargs, + ) + + return ds diff --git a/xarray/backends/cfgrib_.py b/xarray/backends/cfgrib_.py index bd946df89b2..d4933e370c7 100644 --- a/xarray/backends/cfgrib_.py +++ b/xarray/backends/cfgrib_.py @@ -1,10 +1,13 @@ +import os + import numpy as np from ..core import indexing -from ..core.utils import Frozen, FrozenDict +from ..core.utils import Frozen, FrozenDict, close_on_error from ..core.variable import Variable -from .common import AbstractDataStore, BackendArray +from .common import AbstractDataStore, BackendArray, BackendEntrypoint from .locks import SerializableLock, ensure_lock +from .store import open_backend_dataset_store # FIXME: Add a dedicated lock, even if ecCodes is supposed to be thread-safe # in most circumstances. 
See: @@ -21,7 +24,7 @@ def __init__(self, datastore, array): def __getitem__(self, key): return indexing.explicit_indexing_adapter( - key, self.shape, indexing.IndexingSupport.OUTER, self._getitem + key, self.shape, indexing.IndexingSupport.BASIC, self._getitem ) def _getitem(self, key): @@ -69,3 +72,60 @@ def get_encoding(self): dims = self.get_dimensions() encoding = {"unlimited_dims": {k for k, v in dims.items() if v is None}} return encoding + + +def guess_can_open_cfgrib(store_spec): + try: + _, ext = os.path.splitext(store_spec) + except TypeError: + return False + return ext in {".grib", ".grib2", ".grb", ".grb2"} + + +def open_backend_dataset_cfgrib( + filename_or_obj, + *, + mask_and_scale=True, + decode_times=None, + concat_characters=None, + decode_coords=None, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, + lock=None, + indexpath="{path}.{short_hash}.idx", + filter_by_keys={}, + read_keys=[], + encode_cf=("parameter", "time", "geography", "vertical"), + squeeze=True, + time_dims=("time", "step"), +): + + store = CfGribDataStore( + filename_or_obj, + indexpath=indexpath, + filter_by_keys=filter_by_keys, + read_keys=read_keys, + encode_cf=encode_cf, + squeeze=squeeze, + time_dims=time_dims, + lock=lock, + ) + + with close_on_error(store): + ds = open_backend_dataset_store( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds + + +cfgrib_backend = BackendEntrypoint( + open_dataset=open_backend_dataset_cfgrib, guess_can_open=guess_can_open_cfgrib +) diff --git a/xarray/backends/common.py b/xarray/backends/common.py index fa3ee19f542..72a63957662 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -1,14 +1,12 @@ import logging import time import traceback -import warnings -from collections.abc import Mapping import numpy as np from ..conventions import cf_encoder from ..core import indexing -from ..core.pycompat import dask_array_type +from ..core.pycompat import is_duck_dask_array from ..core.utils import FrozenDict, NdimSizeLenMixin # Create a logger object, but don't add any handlers. Leave that to user code. 
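[Editor's note] The ``BackendEntrypoint`` objects registered in this patch (``cfgrib_backend`` above, and the h5netcdf one further down) bundle an "open this as a Dataset" callable with an optional "can I open this?" guess. A self-contained sketch of the pattern, using hypothetical ``Demo``/``demo_`` names; the real class lives in ``xarray.backends.common`` in this patch:

    import os


    class DemoBackendEntrypoint:
        __slots__ = ("open_dataset", "open_dataset_parameters", "guess_can_open")

        def __init__(self, open_dataset, open_dataset_parameters=None, guess_can_open=None):
            self.open_dataset = open_dataset
            self.open_dataset_parameters = open_dataset_parameters
            self.guess_can_open = guess_can_open


    def demo_guess_can_open(store_spec):
        # mirror guess_can_open_cfgrib: decide from the file extension alone
        try:
            _, ext = os.path.splitext(store_spec)
        except TypeError:
            return False
        return ext in {".grib", ".grib2", ".grb", ".grb2"}


    def demo_open_dataset(filename_or_obj, *, drop_variables=None):
        raise NotImplementedError("a real backend returns an xarray.Dataset here")


    demo_backend = DemoBackendEntrypoint(
        open_dataset=demo_open_dataset, guess_can_open=demo_guess_can_open
    )
    print(demo_backend.guess_can_open("forecast.grib2"))  # True

``guess_engine`` in ``plugins`` can then try each registered entrypoint's ``guess_can_open`` in turn when no ``engine`` is given.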
@@ -74,18 +72,9 @@ def __array__(self, dtype=None): return np.asarray(self[key], dtype=dtype) -class AbstractDataStore(Mapping): +class AbstractDataStore: __slots__ = () - def __iter__(self): - return iter(self.variables) - - def __getitem__(self, key): - return self.variables[key] - - def __len__(self): - return len(self.variables) - def get_dimensions(self): # pragma: no cover raise NotImplementedError() @@ -125,38 +114,6 @@ def load(self): attributes = FrozenDict(self.get_attrs()) return variables, attributes - @property - def variables(self): # pragma: no cover - warnings.warn( - "The ``variables`` property has been deprecated and " - "will be removed in xarray v0.11.", - FutureWarning, - stacklevel=2, - ) - variables, _ = self.load() - return variables - - @property - def attrs(self): # pragma: no cover - warnings.warn( - "The ``attrs`` property has been deprecated and " - "will be removed in xarray v0.11.", - FutureWarning, - stacklevel=2, - ) - _, attrs = self.load() - return attrs - - @property - def dimensions(self): # pragma: no cover - warnings.warn( - "The ``dimensions`` property has been deprecated and " - "will be removed in xarray v0.11.", - FutureWarning, - stacklevel=2, - ) - return self.get_dimensions() - def close(self): pass @@ -177,7 +134,7 @@ def __init__(self, lock=None): self.lock = lock def add(self, source, target, region=None): - if isinstance(source, dask_array_type): + if is_duck_dask_array(source): self.sources.append(source) self.targets.append(target) self.regions.append(region) @@ -241,7 +198,7 @@ def encode_attribute(self, a): """encode one attribute""" return a - def set_dimension(self, d, l): # pragma: no cover + def set_dimension(self, dim, length): # pragma: no cover raise NotImplementedError() def set_attribute(self, k, v): # pragma: no cover @@ -383,3 +340,12 @@ def encode(self, variables, attributes): variables = {k: self.encode_variable(v) for k, v in variables.items()} attributes = {k: self.encode_attribute(v) for k, v in attributes.items()} return variables, attributes + + +class BackendEntrypoint: + __slots__ = ("guess_can_open", "open_dataset", "open_dataset_parameters") + + def __init__(self, open_dataset, open_dataset_parameters=None, guess_can_open=None): + self.open_dataset = open_dataset + self.open_dataset_parameters = open_dataset_parameters + self.guess_can_open = guess_can_open diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index 4967788a1e7..4b9c95ec792 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ -175,7 +175,8 @@ def acquire(self, needs_lock=True): Returns ------- - An open file object, as returned by ``opener(*args, **kwargs)``. + file-like + An open file object, as returned by ``opener(*args, **kwargs)``. """ file, _ = self._acquire_with_cache_info(needs_lock) return file @@ -313,8 +314,7 @@ def __hash__(self): class DummyFileManager(FileManager): - """FileManager that simply wraps an open file in the FileManager interface. 
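The ``BackendEntrypoint`` container added to ``common.py`` above is what every backend module instantiates; a minimal sketch of a hypothetical external backend (all names here are illustrative, not part of this patch):

```python
import os

import numpy as np
import xarray as xr
from xarray.backends.common import BackendEntrypoint


def open_my_dataset(filename_or_obj, *, drop_variables=None):
    # A toy reader: a backend only has to return an xarray.Dataset.
    ds = xr.Dataset({"x": ("dim", np.arange(3))})
    if drop_variables:
        ds = ds.drop_vars(drop_variables)
    return ds


def guess_can_open_my_format(store_spec):
    # Same extension-based pattern used by the built-in backends below.
    try:
        _, ext = os.path.splitext(store_spec)
    except TypeError:
        return False
    return ext == ".myext"  # hypothetical extension


# Third-party packages would expose this object through the "xarray.backends"
# entrypoint group so that plugins.list_engines() can discover it.
my_backend = BackendEntrypoint(
    open_dataset=open_my_dataset, guess_can_open=guess_can_open_my_format
)
```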
- """ + """FileManager that simply wraps an open file in the FileManager interface.""" def __init__(self, value): self._value = value diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 393db14a7e9..b2996369ee7 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -1,12 +1,14 @@ import functools +import io +import os from distutils.version import LooseVersion import numpy as np from ..core import indexing -from ..core.utils import FrozenDict, is_remote_uri +from ..core.utils import FrozenDict, is_remote_uri, read_magic_number from ..core.variable import Variable -from .common import WritableCFDataStore, find_root_and_group +from .common import BackendEntrypoint, WritableCFDataStore, find_root_and_group from .file_manager import CachingFileManager, DummyFileManager from .locks import HDF5_LOCK, combine_locks, ensure_lock, get_write_lock from .netCDF4_ import ( @@ -16,6 +18,7 @@ _get_datatype, _nc4_require_group, ) +from .store import open_backend_dataset_store class H5NetCDFArrayWrapper(BaseNetCDF4Array): @@ -67,8 +70,7 @@ def _h5netcdf_create_group(dataset, name): class H5NetCDFStore(WritableCFDataStore): - """Store for reading and writing data via h5netcdf - """ + """Store for reading and writing data via h5netcdf""" __slots__ = ( "autoclose", @@ -122,6 +124,18 @@ def open( ): import h5netcdf + if isinstance(filename, bytes): + raise ValueError( + "can't open netCDF4/HDF5 as bytes " + "try passing a path or file-like object" + ) + elif isinstance(filename, io.IOBase): + magic_number = read_magic_number(filename) + if not magic_number.startswith(b"\211HDF\r\n\032\n"): + raise ValueError( + f"{magic_number} is not the signature of a valid netCDF file" + ) + if format not in [None, "NETCDF4"]: raise ValueError("invalid format for h5netcdf backend") @@ -262,7 +276,7 @@ def prepare_variable( and "compression_opts" in encoding and encoding["complevel"] != encoding["compression_opts"] ): - raise ValueError("'complevel' and 'compression_opts' encodings " "mismatch") + raise ValueError("'complevel' and 'compression_opts' encodings mismatch") complevel = encoding.pop("complevel", 0) if complevel != 0: encoding.setdefault("compression_opts", complevel) @@ -303,3 +317,61 @@ def sync(self): def close(self, **kwargs): self._manager.close(**kwargs) + + +def guess_can_open_h5netcdf(store_spec): + try: + return read_magic_number(store_spec).startswith(b"\211HDF\r\n\032\n") + except TypeError: + pass + + try: + _, ext = os.path.splitext(store_spec) + except TypeError: + return False + + return ext in {".nc", ".nc4", ".cdf"} + + +def open_backend_dataset_h5netcdf( + filename_or_obj, + *, + mask_and_scale=True, + decode_times=None, + concat_characters=None, + decode_coords=None, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, + format=None, + group=None, + lock=None, + invalid_netcdf=None, + phony_dims=None, +): + + store = H5NetCDFStore.open( + filename_or_obj, + format=format, + group=group, + lock=lock, + invalid_netcdf=invalid_netcdf, + phony_dims=phony_dims, + ) + + ds = open_backend_dataset_store( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds + + +h5netcdf_backend = BackendEntrypoint( + open_dataset=open_backend_dataset_h5netcdf, guess_can_open=guess_can_open_h5netcdf +) diff --git a/xarray/backends/locks.py 
b/xarray/backends/locks.py index 435690f2079..bb876a432c8 100644 --- a/xarray/backends/locks.py +++ b/xarray/backends/locks.py @@ -72,12 +72,15 @@ def _get_scheduler(get=None, collection=None) -> Optional[str]: dask.base.get_scheduler """ try: - import dask # noqa: F401 + # Fix for bug caused by dask installation that doesn't involve the toolz library + # Issue: 4164 + import dask + from dask.base import get_scheduler # noqa: F401 + + actual_get = get_scheduler(get, collection) except ImportError: return None - actual_get = dask.base.get_scheduler(get, collection) - try: from dask.distributed import Client diff --git a/xarray/backends/lru_cache.py b/xarray/backends/lru_cache.py index 56062256001..5ca49a0311a 100644 --- a/xarray/backends/lru_cache.py +++ b/xarray/backends/lru_cache.py @@ -55,8 +55,7 @@ def __getitem__(self, key: K) -> V: return value def _enforce_size_limit(self, capacity: int) -> None: - """Shrink the cache if necessary, evicting the oldest items. - """ + """Shrink the cache if necessary, evicting the oldest items.""" while len(self._cache) > capacity: key, value = self._cache.popitem(last=False) if self._on_evict is not None: diff --git a/xarray/backends/memory.py b/xarray/backends/memory.py index bee6521bce2..17095d09651 100644 --- a/xarray/backends/memory.py +++ b/xarray/backends/memory.py @@ -40,6 +40,6 @@ def set_attribute(self, k, v): # copy to imitate writing to disk. self._attributes[k] = copy.deepcopy(v) - def set_dimension(self, d, l, unlimited_dims=None): + def set_dimension(self, dim, length, unlimited_dims=None): # in this model, dimensions are accounted for in the variables pass diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 0a917cde4d7..0e35270ea9a 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -1,5 +1,7 @@ import functools import operator +import os +import pathlib from contextlib import suppress import numpy as np @@ -7,10 +9,11 @@ from .. import coding from ..coding.variables import pop_to from ..core import indexing -from ..core.utils import FrozenDict, is_remote_uri +from ..core.utils import FrozenDict, close_on_error, is_remote_uri from ..core.variable import Variable from .common import ( BackendArray, + BackendEntrypoint, WritableCFDataStore, find_root_and_group, robust_getitem, @@ -18,6 +21,7 @@ from .file_manager import CachingFileManager, DummyFileManager from .locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock, get_write_lock from .netcdf3 import encode_nc3_attr_value, encode_nc3_variable +from .store import open_backend_dataset_store # This lookup table maps from dtype.byteorder to a readable endian # string used by netCDF4. 
@@ -333,6 +337,15 @@ def open( ): import netCDF4 + if isinstance(filename, pathlib.Path): + filename = os.fspath(filename) + + if not isinstance(filename, str): + raise ValueError( + "can only read bytes or file-like objects " + "with engine='scipy' or 'h5netcdf'" + ) + if format is None: format = "NETCDF4" @@ -490,3 +503,63 @@ def sync(self): def close(self, **kwargs): self._manager.close(**kwargs) + + +def guess_can_open_netcdf4(store_spec): + if isinstance(store_spec, str) and is_remote_uri(store_spec): + return True + try: + _, ext = os.path.splitext(store_spec) + except TypeError: + return False + return ext in {".nc", ".nc4", ".cdf"} + + +def open_backend_dataset_netcdf4( + filename_or_obj, + mask_and_scale=True, + decode_times=None, + concat_characters=None, + decode_coords=None, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, + group=None, + mode="r", + format="NETCDF4", + clobber=True, + diskless=False, + persist=False, + lock=None, + autoclose=False, +): + + store = NetCDF4DataStore.open( + filename_or_obj, + mode=mode, + format=format, + group=group, + clobber=clobber, + diskless=diskless, + persist=persist, + lock=lock, + autoclose=autoclose, + ) + + with close_on_error(store): + ds = open_backend_dataset_store( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds + + +netcdf4_backend = BackendEntrypoint( + open_dataset=open_backend_dataset_netcdf4, guess_can_open=guess_can_open_netcdf4 +) diff --git a/xarray/backends/netcdf3.py b/xarray/backends/netcdf3.py index c9c4baf9b01..001af0bf8e1 100644 --- a/xarray/backends/netcdf3.py +++ b/xarray/backends/netcdf3.py @@ -20,7 +20,8 @@ "uint", "int64", "uint64", - "float" "real", + "float", + "real", "double", "bool", "string", @@ -28,7 +29,14 @@ # These data-types aren't supported by netCDF3, so they are automatically # coerced instead as indicated by the "coerce_nc3_dtype" function -_nc3_dtype_coercions = {"int64": "int32", "bool": "int8"} +_nc3_dtype_coercions = { + "int64": "int32", + "uint64": "int32", + "uint32": "int32", + "uint16": "int16", + "uint8": "int8", + "bool": "int8", +} # encode all strings as UTF-8 STRING_ENCODING = "utf-8" @@ -37,12 +45,17 @@ def coerce_nc3_dtype(arr): """Coerce an array to a data type that can be stored in a netCDF-3 file - This function performs the following dtype conversions: - int64 -> int32 - bool -> int8 - - Data is checked for equality, or equivalence (non-NaN values) with - `np.allclose` with the default keyword arguments. + This function performs the dtype conversions as specified by the + ``_nc3_dtype_coercions`` mapping: + int64 -> int32 + uint64 -> int32 + uint32 -> int32 + uint16 -> int16 + uint8 -> int8 + bool -> int8 + + Data is checked for equality, or equivalence (non-NaN values) using the + ``(cast_array == original_array).all()``. 
""" dtype = str(arr.dtype) if dtype in _nc3_dtype_coercions: diff --git a/xarray/backends/plugins.py b/xarray/backends/plugins.py new file mode 100644 index 00000000000..d5799a78f91 --- /dev/null +++ b/xarray/backends/plugins.py @@ -0,0 +1,124 @@ +import functools +import inspect +import itertools +import logging +import typing as T +import warnings + +import pkg_resources + +from .cfgrib_ import cfgrib_backend +from .common import BackendEntrypoint +from .h5netcdf_ import h5netcdf_backend +from .netCDF4_ import netcdf4_backend +from .pseudonetcdf_ import pseudonetcdf_backend +from .pydap_ import pydap_backend +from .pynio_ import pynio_backend +from .scipy_ import scipy_backend +from .store import store_backend +from .zarr import zarr_backend + +BACKEND_ENTRYPOINTS: T.Dict[str, BackendEntrypoint] = { + "store": store_backend, + "netcdf4": netcdf4_backend, + "h5netcdf": h5netcdf_backend, + "scipy": scipy_backend, + "pseudonetcdf": pseudonetcdf_backend, + "zarr": zarr_backend, + "cfgrib": cfgrib_backend, + "pydap": pydap_backend, + "pynio": pynio_backend, +} + + +def remove_duplicates(backend_entrypoints): + + # sort and group entrypoints by name + backend_entrypoints = sorted(backend_entrypoints, key=lambda ep: ep.name) + backend_entrypoints_grouped = itertools.groupby( + backend_entrypoints, key=lambda ep: ep.name + ) + # check if there are multiple entrypoints for the same name + unique_backend_entrypoints = [] + for name, matches in backend_entrypoints_grouped: + matches = list(matches) + unique_backend_entrypoints.append(matches[0]) + matches_len = len(matches) + if matches_len > 1: + selected_module_name = matches[0].module_name + all_module_names = [e.module_name for e in matches] + warnings.warn( + f"Found {matches_len} entrypoints for the engine name {name}:" + f"\n {all_module_names}.\n It will be used: {selected_module_name}.", + RuntimeWarning, + ) + return unique_backend_entrypoints + + +def detect_parameters(open_dataset): + signature = inspect.signature(open_dataset) + parameters = signature.parameters + for name, param in parameters.items(): + if param.kind in ( + inspect.Parameter.VAR_KEYWORD, + inspect.Parameter.VAR_POSITIONAL, + ): + raise TypeError( + f"All the parameters in {open_dataset!r} signature should be explicit. 
" + "*args and **kwargs is not supported" + ) + return tuple(parameters) + + +def create_engines_dict(backend_entrypoints): + engines = {} + for backend_ep in backend_entrypoints: + name = backend_ep.name + backend = backend_ep.load() + engines[name] = backend + return engines + + +def set_missing_parameters(engines): + for name, backend in engines.items(): + if backend.open_dataset_parameters is None: + open_dataset = backend.open_dataset + backend.open_dataset_parameters = detect_parameters(open_dataset) + + +def build_engines(entrypoints): + backend_entrypoints = BACKEND_ENTRYPOINTS.copy() + pkg_entrypoints = remove_duplicates(entrypoints) + external_backend_entrypoints = create_engines_dict(pkg_entrypoints) + backend_entrypoints.update(external_backend_entrypoints) + set_missing_parameters(backend_entrypoints) + return backend_entrypoints + + +@functools.lru_cache(maxsize=1) +def list_engines(): + entrypoints = pkg_resources.iter_entry_points("xarray.backends") + return build_engines(entrypoints) + + +def guess_engine(store_spec): + engines = list_engines() + + for engine, backend in engines.items(): + try: + if backend.guess_can_open and backend.guess_can_open(store_spec): + return engine + except Exception: + logging.exception(f"{engine!r} fails while guessing") + + raise ValueError("cannot guess the engine, try passing one explicitly") + + +def get_backend(engine): + """Select open_dataset method based on current engine""" + engines = list_engines() + if engine not in engines: + raise ValueError( + f"unrecognized engine {engine} must be one of: {list(engines)}" + ) + return engines[engine] diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py index 17a4eb8f6bf..d9128d1d503 100644 --- a/xarray/backends/pseudonetcdf_.py +++ b/xarray/backends/pseudonetcdf_.py @@ -1,11 +1,12 @@ import numpy as np from ..core import indexing -from ..core.utils import Frozen, FrozenDict +from ..core.utils import Frozen, FrozenDict, close_on_error from ..core.variable import Variable -from .common import AbstractDataStore, BackendArray +from .common import AbstractDataStore, BackendArray, BackendEntrypoint from .file_manager import CachingFileManager from .locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock +from .store import open_backend_dataset_store # psuedonetcdf can invoke netCDF libraries internally PNETCDF_LOCK = combine_locks([HDF5_LOCK, NETCDFC_LOCK]) @@ -35,8 +36,7 @@ def _getitem(self, key): class PseudoNetCDFDataStore(AbstractDataStore): - """Store for accessing datasets via PseudoNetCDF - """ + """Store for accessing datasets via PseudoNetCDF""" @classmethod def open(cls, filename, lock=None, mode=None, **format_kwargs): @@ -86,3 +86,55 @@ def get_encoding(self): def close(self): self._manager.close() + + +def open_backend_dataset_pseudonetcdf( + filename_or_obj, + mask_and_scale=False, + decode_times=None, + concat_characters=None, + decode_coords=None, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, + mode=None, + lock=None, + **format_kwargs, +): + + store = PseudoNetCDFDataStore.open( + filename_or_obj, lock=lock, mode=mode, **format_kwargs + ) + + with close_on_error(store): + ds = open_backend_dataset_store( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds + + +# *args and **kwargs are not allowed in open_backend_dataset_ kwargs, 
+# unless the open_dataset_parameters are explicitly defined like this:
+open_dataset_parameters = (
+    "filename_or_obj",
+    "mask_and_scale",
+    "decode_times",
+    "concat_characters",
+    "decode_coords",
+    "drop_variables",
+    "use_cftime",
+    "decode_timedelta",
+    "mode",
+    "lock",
+)
+pseudonetcdf_backend = BackendEntrypoint(
+    open_dataset=open_backend_dataset_pseudonetcdf,
+    open_dataset_parameters=open_dataset_parameters,
+)
diff --git a/xarray/backends/pydap_.py b/xarray/backends/pydap_.py
index 20e943ab561..4995045a739 100644
--- a/xarray/backends/pydap_.py
+++ b/xarray/backends/pydap_.py
@@ -2,9 +2,10 @@
 
 from ..core import indexing
 from ..core.pycompat import integer_types
-from ..core.utils import Frozen, FrozenDict, is_dict_like
+from ..core.utils import Frozen, FrozenDict, close_on_error, is_dict_like, is_remote_uri
 from ..core.variable import Variable
-from .common import AbstractDataStore, BackendArray, robust_getitem
+from .common import AbstractDataStore, BackendArray, BackendEntrypoint, robust_getitem
+from .store import open_backend_dataset_store
 
 
 class PydapArrayWrapper(BackendArray):
@@ -92,3 +93,43 @@ def get_attrs(self):
 
     def get_dimensions(self):
         return Frozen(self.ds.dimensions)
+
+
+def guess_can_open_pydap(store_spec):
+    return isinstance(store_spec, str) and is_remote_uri(store_spec)
+
+
+def open_backend_dataset_pydap(
+    filename_or_obj,
+    mask_and_scale=True,
+    decode_times=None,
+    concat_characters=None,
+    decode_coords=None,
+    drop_variables=None,
+    use_cftime=None,
+    decode_timedelta=None,
+    session=None,
+):
+
+    store = PydapDataStore.open(
+        filename_or_obj,
+        session=session,
+    )
+
+    with close_on_error(store):
+        ds = open_backend_dataset_store(
+            store,
+            mask_and_scale=mask_and_scale,
+            decode_times=decode_times,
+            concat_characters=concat_characters,
+            decode_coords=decode_coords,
+            drop_variables=drop_variables,
+            use_cftime=use_cftime,
+            decode_timedelta=decode_timedelta,
+        )
+    return ds
+
+
+pydap_backend = BackendEntrypoint(
+    open_dataset=open_backend_dataset_pydap, guess_can_open=guess_can_open_pydap
+)
diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py
index 1c66ff1ee48..dc6c47935e8 100644
--- a/xarray/backends/pynio_.py
+++ b/xarray/backends/pynio_.py
@@ -1,11 +1,12 @@
 import numpy as np
 
 from ..core import indexing
-from ..core.utils import Frozen, FrozenDict
+from ..core.utils import Frozen, FrozenDict, close_on_error
 from ..core.variable import Variable
-from .common import AbstractDataStore, BackendArray
+from .common import AbstractDataStore, BackendArray, BackendEntrypoint
 from .file_manager import CachingFileManager
 from .locks import HDF5_LOCK, NETCDFC_LOCK, SerializableLock, combine_locks, ensure_lock
+from .store import open_backend_dataset_store
 
 # PyNIO can invoke netCDF libraries internally
 # Add a dedicated lock just in case NCL as well isn't thread-safe.
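Stepping back from the individual backends: the ``guess_can_open_*`` helpers above are consulted by the plugin machinery in ``plugins.py`` when no engine is given. A small usage sketch (the file name is hypothetical):

```python
from xarray.backends import plugins

# With an explicit engine, the registered BackendEntrypoint is looked up directly.
backend = plugins.get_backend("netcdf4")
print(backend.open_dataset_parameters)

# Without an engine, each backend's guess_can_open is tried in turn, so a
# ".grib2" path is routed to the cfgrib entrypoint by its file extension.
print(plugins.guess_engine("observations.grib2"))  # -> "cfgrib"
```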
@@ -41,8 +42,7 @@ def _getitem(self, key): class NioDataStore(AbstractDataStore): - """Store for accessing datasets via PyNIO - """ + """Store for accessing datasets via PyNIO""" def __init__(self, filename, mode="r", lock=None, **kwargs): import Nio @@ -83,3 +83,39 @@ def get_encoding(self): def close(self): self._manager.close() + + +def open_backend_dataset_pynio( + filename_or_obj, + mask_and_scale=True, + decode_times=None, + concat_characters=None, + decode_coords=None, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, + mode="r", + lock=None, +): + + store = NioDataStore( + filename_or_obj, + mode=mode, + lock=lock, + ) + + with close_on_error(store): + ds = open_backend_dataset_store( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds + + +pynio_backend = BackendEntrypoint(open_dataset=open_backend_dataset_pynio) diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index 77beffd09b1..a0500c7e1c2 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -50,7 +50,7 @@ def shape(self): return self._shape def _get_indexer(self, key): - """ Get indexer for rasterio array. + """Get indexer for rasterio array. Parameter --------- @@ -221,14 +221,17 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, loc vrt = filename filename = vrt.src_dataset.name vrt_params = dict( + src_crs=vrt.src_crs.to_string(), crs=vrt.crs.to_string(), resampling=vrt.resampling, + tolerance=vrt.tolerance, src_nodata=vrt.src_nodata, nodata=vrt.nodata, - tolerance=vrt.tolerance, - transform=vrt.transform, width=vrt.width, height=vrt.height, + src_transform=vrt.src_transform, + transform=vrt.transform, + dtype=vrt.working_dtype, warp_extras=vrt.warp_extras, ) diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index 9863285d6de..873a91f9c07 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -1,14 +1,16 @@ -from io import BytesIO +import io +import os import numpy as np from ..core.indexing import NumpyIndexingAdapter -from ..core.utils import Frozen, FrozenDict +from ..core.utils import Frozen, FrozenDict, close_on_error, read_magic_number from ..core.variable import Variable -from .common import BackendArray, WritableCFDataStore +from .common import BackendArray, BackendEntrypoint, WritableCFDataStore from .file_manager import CachingFileManager, DummyFileManager from .locks import ensure_lock, get_write_lock from .netcdf3 import encode_nc3_attr_value, encode_nc3_variable, is_valid_nc3_name +from .store import open_backend_dataset_store def _decode_string(s): @@ -57,9 +59,10 @@ def __setitem__(self, key, value): def _open_scipy_netcdf(filename, mode, mmap, version): - import scipy.io import gzip + import scipy.io + # if the string ends with .gz, then gunzip and open as netcdf file if isinstance(filename, str) and filename.endswith(".gz"): try: @@ -69,15 +72,13 @@ def _open_scipy_netcdf(filename, mode, mmap, version): except TypeError as e: # TODO: gzipped loading only works with NetCDF3 files. if "is not a valid NetCDF 3 file" in e.message: - raise ValueError( - "gzipped file loading only supports " "NetCDF 3 files." 
- ) + raise ValueError("gzipped file loading only supports NetCDF 3 files.") else: raise if isinstance(filename, bytes) and filename.startswith(b"CDF"): # it's a NetCDF3 bytestring - filename = BytesIO(filename) + filename = io.BytesIO(filename) try: return scipy.io.netcdf_file(filename, mode=mode, mmap=mmap, version=version) @@ -109,9 +110,7 @@ def __init__( self, filename_or_obj, mode="r", format=None, group=None, mmap=None, lock=None ): if group is not None: - raise ValueError( - "cannot save to a group with the " "scipy.io.netcdf backend" - ) + raise ValueError("cannot save to a group with the scipy.io.netcdf backend") if format is None or format == "NETCDF3_64BIT": version = 2 @@ -221,3 +220,54 @@ def sync(self): def close(self): self._manager.close() + + +def guess_can_open_scipy(store_spec): + try: + return read_magic_number(store_spec).startswith(b"CDF") + except TypeError: + pass + + try: + _, ext = os.path.splitext(store_spec) + except TypeError: + return False + return ext in {".nc", ".nc4", ".cdf", ".gz"} + + +def open_backend_dataset_scipy( + filename_or_obj, + mask_and_scale=True, + decode_times=None, + concat_characters=None, + decode_coords=None, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, + mode="r", + format=None, + group=None, + mmap=None, + lock=None, +): + + store = ScipyDataStore( + filename_or_obj, mode=mode, format=format, group=group, mmap=mmap, lock=lock + ) + with close_on_error(store): + ds = open_backend_dataset_store( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds + + +scipy_backend = BackendEntrypoint( + open_dataset=open_backend_dataset_scipy, guess_can_open=guess_can_open_scipy +) diff --git a/xarray/backends/store.py b/xarray/backends/store.py new file mode 100644 index 00000000000..d314a9c3ca9 --- /dev/null +++ b/xarray/backends/store.py @@ -0,0 +1,47 @@ +from .. import conventions +from ..core.dataset import Dataset +from .common import AbstractDataStore, BackendEntrypoint + + +def guess_can_open_store(store_spec): + return isinstance(store_spec, AbstractDataStore) + + +def open_backend_dataset_store( + store, + *, + mask_and_scale=True, + decode_times=True, + concat_characters=True, + decode_coords=True, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, +): + vars, attrs = store.load() + file_obj = store + encoding = store.get_encoding() + + vars, attrs, coord_names = conventions.decode_cf_variables( + vars, + attrs, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + + ds = Dataset(vars, attrs=attrs) + ds = ds.set_coords(coord_names.intersection(vars)) + ds._file_obj = file_obj + ds.encoding = encoding + + return ds + + +store_backend = BackendEntrypoint( + open_dataset=open_backend_dataset_store, guess_can_open=guess_can_open_store +) diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index c262dae2811..3b4b3a3d9d5 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -1,13 +1,20 @@ -import warnings +import os +import pathlib import numpy as np from .. 
import coding, conventions from ..core import indexing from ..core.pycompat import integer_types -from ..core.utils import FrozenDict, HiddenKeyDict +from ..core.utils import FrozenDict, HiddenKeyDict, close_on_error from ..core.variable import Variable -from .common import AbstractWritableDataStore, BackendArray, _encode_variable_name +from .common import ( + AbstractWritableDataStore, + BackendArray, + BackendEntrypoint, + _encode_variable_name, +) +from .store import open_backend_dataset_store # need some special secret attributes to tell us the dimensions DIMENSION_KEY = "_ARRAY_DIMENSIONS" @@ -65,7 +72,7 @@ def __getitem__(self, key): # could possibly have a work-around for 0d data here -def _determine_zarr_chunks(enc_chunks, var_chunks, ndim): +def _determine_zarr_chunks(enc_chunks, var_chunks, ndim, name): """ Given encoding chunks (possibly None) and variable chunks (possibly None) """ @@ -88,15 +95,16 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim): if var_chunks and enc_chunks is None: if any(len(set(chunks[:-1])) > 1 for chunks in var_chunks): raise ValueError( - "Zarr requires uniform chunk sizes except for final chunk." - " Variable dask chunks %r are incompatible. Consider " - "rechunking using `chunk()`." % (var_chunks,) + "Zarr requires uniform chunk sizes except for final chunk. " + f"Variable named {name!r} has incompatible dask chunks: {var_chunks!r}. " + "Consider rechunking using `chunk()`." ) if any((chunks[0] < chunks[-1]) for chunks in var_chunks): raise ValueError( "Final chunk of Zarr array must be the same size or smaller " - "than the first. Variable Dask chunks %r are incompatible. " - "Consider rechunking using `chunk()`." % var_chunks + f"than the first. Variable named {name!r} has incompatible Dask chunks {var_chunks!r}." + "Consider either rechunking using `chunk()` or instead deleting " + "or modifying `encoding['chunks']`." ) # return the first chunk for each dimension return tuple(chunk[0] for chunk in var_chunks) @@ -114,13 +122,15 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim): if len(enc_chunks_tuple) != ndim: # throw away encoding chunks, start over - return _determine_zarr_chunks(None, var_chunks, ndim) + return _determine_zarr_chunks(None, var_chunks, ndim, name) for x in enc_chunks_tuple: if not isinstance(x, int): raise TypeError( - "zarr chunks must be an int or a tuple of ints. " - "Instead found %r" % (enc_chunks_tuple,) + "zarr chunk sizes specified in `encoding['chunks']` " + "must be an int or a tuple of ints. " + f"Instead found encoding['chunks']={enc_chunks_tuple!r} " + f"for variable named {name!r}." ) # if there are chunks in encoding and the variable data is a numpy array, @@ -139,22 +149,27 @@ def _determine_zarr_chunks(enc_chunks, var_chunks, ndim): # threads if var_chunks and enc_chunks_tuple: for zchunk, dchunks in zip(enc_chunks_tuple, var_chunks): + if len(dchunks) == 1: + continue for dchunk in dchunks[:-1]: if dchunk % zchunk: raise NotImplementedError( - "Specified zarr chunks %r would overlap multiple dask " - "chunks %r. This is not implemented in xarray yet. " - " Consider rechunking the data using " - "`chunk()` or specifying different chunks in encoding." - % (enc_chunks_tuple, var_chunks) + f"Specified zarr chunks encoding['chunks']={enc_chunks_tuple!r} for " + f"variable named {name!r} would overlap multiple dask chunks {var_chunks!r}. " + "This is not implemented in xarray yet. " + "Consider either rechunking using `chunk()` or instead deleting " + "or modifying `encoding['chunks']`." 
) if dchunks[-1] > zchunk: raise ValueError( "Final chunk of Zarr array must be the same size or " - "smaller than the first. The specified Zarr chunk " - "encoding is %r, but %r in variable Dask chunks %r is " - "incompatible. Consider rechunking using `chunk()`." - % (enc_chunks_tuple, dchunks, var_chunks) + "smaller than the first. " + f"Specified Zarr chunk encoding['chunks']={enc_chunks_tuple}, " + f"for variable named {name!r} " + f"but {dchunks} in the variable's Dask chunks {var_chunks} is " + "incompatible with this encoding. " + "Consider either rechunking using `chunk()` or instead deleting " + "or modifying `encoding['chunks']`." ) return enc_chunks_tuple @@ -177,13 +192,13 @@ def _get_zarr_dims_and_attrs(zarr_obj, dimension_key): return dimensions, attributes -def extract_zarr_variable_encoding(variable, raise_on_invalid=False): +def extract_zarr_variable_encoding(variable, raise_on_invalid=False, name=None): """ Extract zarr encoding dictionary from xarray Variable Parameters ---------- - variable : xarray.Variable + variable : Variable raise_on_invalid : bool, optional Returns @@ -199,7 +214,7 @@ def extract_zarr_variable_encoding(variable, raise_on_invalid=False): invalid = [k for k in encoding if k not in valid_encodings] if invalid: raise ValueError( - "unexpected encoding parameters for zarr " "backend: %r" % invalid + "unexpected encoding parameters for zarr backend: %r" % invalid ) else: for k in list(encoding): @@ -207,7 +222,7 @@ def extract_zarr_variable_encoding(variable, raise_on_invalid=False): del encoding[k] chunks = _determine_zarr_chunks( - encoding.get("chunks"), variable.chunks, variable.ndim + encoding.get("chunks"), variable.chunks, variable.ndim, name ) encoding["chunks"] = chunks return encoding @@ -227,12 +242,12 @@ def encode_zarr_variable(var, needs_copy=True, name=None): Parameters ---------- - var : xarray.Variable + var : Variable A variable holding un-encoded data. Returns ------- - out : xarray.Variable + out : Variable A variable which has been encoded as described above. """ @@ -249,16 +264,16 @@ def encode_zarr_variable(var, needs_copy=True, name=None): class ZarrStore(AbstractWritableDataStore): - """Store for reading and writing data via zarr - """ + """Store for reading and writing data via zarr""" __slots__ = ( - "append_dim", "ds", + "_append_dim", "_consolidate_on_close", "_group", "_read_only", "_synchronizer", + "_write_region", ) @classmethod @@ -270,24 +285,37 @@ def open_group( group=None, consolidated=False, consolidate_on_close=False, + chunk_store=None, + append_dim=None, + write_region=None, ): import zarr + # zarr doesn't support pathlib.Path objects yet. 
zarr-python#601 + if isinstance(store, pathlib.Path): + store = os.fspath(store) + open_kwargs = dict(mode=mode, synchronizer=synchronizer, path=group) + if chunk_store: + open_kwargs["chunk_store"] = chunk_store + if consolidated: # TODO: an option to pass the metadata_key keyword zarr_group = zarr.open_consolidated(store, **open_kwargs) else: zarr_group = zarr.open_group(store, **open_kwargs) - return cls(zarr_group, consolidate_on_close) + return cls(zarr_group, consolidate_on_close, append_dim, write_region) - def __init__(self, zarr_group, consolidate_on_close=False): + def __init__( + self, zarr_group, consolidate_on_close=False, append_dim=None, write_region=None + ): self.ds = zarr_group self._read_only = self.ds.read_only self._synchronizer = self.ds.synchronizer self._group = self.ds.path self._consolidate_on_close = consolidate_on_close - self.append_dim = None + self._append_dim = append_dim + self._write_region = write_region def open_store_variable(self, name, zarr_array): data = indexing.LazilyOuterIndexedArray(ZarrArrayWrapper(name, self)) @@ -295,6 +323,7 @@ def open_store_variable(self, name, zarr_array): attributes = dict(attributes) encoding = { "chunks": zarr_array.chunks, + "preferred_chunks": dict(zip(dimensions, zarr_array.chunks)), "compressor": zarr_array.compressor, "filters": zarr_array.filters, } @@ -380,6 +409,7 @@ def store( dimension on which the zarray will be appended only needed in append mode """ + import zarr existing_variables = { vn for vn in variables if _encode_variable_name(vn) in self.ds @@ -401,11 +431,14 @@ def store( variables_with_encoding, _ = self.encode(variables_with_encoding, {}) variables_encoded.update(variables_with_encoding) - self.set_attributes(attributes) - self.set_dimensions(variables_encoded, unlimited_dims=unlimited_dims) + if self._write_region is None: + self.set_attributes(attributes) + self.set_dimensions(variables_encoded, unlimited_dims=unlimited_dims) self.set_variables( variables_encoded, check_encoding_set, writer, unlimited_dims=unlimited_dims ) + if self._consolidate_on_close: + zarr.consolidate_metadata(self.ds.store) def sync(self): pass @@ -439,21 +472,15 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No fill_value = attrs.pop("_FillValue", None) if v.encoding == {"_FillValue": None} and fill_value is None: v.encoding = {} + if name in self.ds: + # existing variable zarr_array = self.ds[name] - if self.append_dim in dims: - # this is the DataArray that has append_dim as a - # dimension - append_axis = dims.index(self.append_dim) - new_shape = list(zarr_array.shape) - new_shape[append_axis] += v.shape[append_axis] - new_region = [slice(None)] * len(new_shape) - new_region[append_axis] = slice(zarr_array.shape[append_axis], None) - zarr_array.resize(new_shape) - writer.add(v.data, zarr_array, region=tuple(new_region)) else: # new variable - encoding = extract_zarr_variable_encoding(v, raise_on_invalid=check) + encoding = extract_zarr_variable_encoding( + v, raise_on_invalid=check, name=vn + ) encoded_attrs = {} # the magic for storing the hidden dimension data encoded_attrs[DIMENSION_KEY] = dims @@ -466,13 +493,27 @@ def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=No name, shape=shape, dtype=dtype, fill_value=fill_value, **encoding ) zarr_array.attrs.put(encoded_attrs) - writer.add(v.data, zarr_array) - def close(self): - if self._consolidate_on_close: - import zarr + write_region = self._write_region if self._write_region is not None else {} + 
write_region = {dim: write_region.get(dim, slice(None)) for dim in dims} - zarr.consolidate_metadata(self.ds.store) + if self._append_dim is not None and self._append_dim in dims: + # resize existing variable + append_axis = dims.index(self._append_dim) + assert write_region[self._append_dim] == slice(None) + write_region[self._append_dim] = slice( + zarr_array.shape[append_axis], None + ) + + new_shape = list(zarr_array.shape) + new_shape[append_axis] += v.shape[append_axis] + zarr_array.resize(new_shape) + + region = tuple(write_region[dim] for dim in dims) + writer.add(v.data, zarr_array, region) + + def close(self): + pass def open_zarr( @@ -488,6 +529,9 @@ def open_zarr( drop_variables=None, consolidated=False, overwrite_encoded_chunks=False, + chunk_store=None, + decode_timedelta=None, + use_cftime=None, **kwargs, ): """Load and decode a dataset from a Zarr store. @@ -540,13 +584,30 @@ def open_zarr( decode_coords : bool, optional If True, decode the 'coordinates' attribute to identify coordinates in the resulting dataset. - drop_variables : string or iterable, optional + drop_variables : str or iterable, optional A variable or list of variables to exclude from being parsed from the dataset. This may be useful to drop variables with problems or inconsistent values. consolidated : bool, optional Whether to open the store using zarr's consolidated metadata capability. Only works for stores that have already been consolidated. + chunk_store : MutableMapping, optional + A separate Zarr store only for chunk data. + decode_timedelta : bool, optional + If True, decode variables and coordinates with time units in + {'days', 'hours', 'minutes', 'seconds', 'milliseconds', 'microseconds'} + into timedelta objects. If False, leave them encoded as numbers. + If None (default), assume the same value of decode_time. + use_cftime: bool, optional + Only relevant if encoded dates come from a standard calendar + (e.g. "gregorian", "proleptic_gregorian", "standard", or not + specified). If None (default), attempt to decode times to + ``np.datetime64[ns]`` objects; if this is not possible, decode times to + ``cftime.datetime`` objects. If True, always decode times to + ``cftime.datetime`` objects, regardless of whether or not they can be + represented using ``np.datetime64[ns]`` objects. If False, always + decode times to ``np.datetime64[ns]`` objects; if this is not possible + raise an error. Returns ------- @@ -561,126 +622,86 @@ def open_zarr( ---------- http://zarr.readthedocs.io/ """ - if "auto_chunk" in kwargs: - auto_chunk = kwargs.pop("auto_chunk") - if auto_chunk: - chunks = "auto" # maintain backwards compatibility - else: - chunks = None + from .api import open_dataset - warnings.warn( - "auto_chunk is deprecated. Use chunks='auto' instead.", - FutureWarning, - stacklevel=2, - ) + if chunks == "auto": + try: + import dask.array # noqa + + chunks = {} + except ImportError: + chunks = None if kwargs: raise TypeError( "open_zarr() got unexpected keyword arguments " + ",".join(kwargs.keys()) ) - if not isinstance(chunks, (int, dict)): - if chunks != "auto" and chunks is not None: - raise ValueError( - "chunks must be an int, dict, 'auto', or None. " - "Instead found %s. 
" % chunks - ) + backend_kwargs = { + "synchronizer": synchronizer, + "consolidated": consolidated, + "overwrite_encoded_chunks": overwrite_encoded_chunks, + "chunk_store": chunk_store, + } - if chunks == "auto": - try: - import dask.array # noqa - except ImportError: - chunks = None + ds = open_dataset( + filename_or_obj=store, + group=group, + decode_cf=decode_cf, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + engine="zarr", + chunks=chunks, + drop_variables=drop_variables, + backend_kwargs=backend_kwargs, + decode_timedelta=decode_timedelta, + use_cftime=use_cftime, + ) + + return ds + + +def open_backend_dataset_zarr( + filename_or_obj, + mask_and_scale=True, + decode_times=None, + concat_characters=None, + decode_coords=None, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, + group=None, + mode="r", + synchronizer=None, + consolidated=False, + consolidate_on_close=False, + chunk_store=None, +): - if not decode_cf: - mask_and_scale = False - decode_times = False - concat_characters = False - decode_coords = False + store = ZarrStore.open_group( + filename_or_obj, + group=group, + mode=mode, + synchronizer=synchronizer, + consolidated=consolidated, + consolidate_on_close=consolidate_on_close, + chunk_store=chunk_store, + ) - def maybe_decode_store(store, lock=False): - ds = conventions.decode_cf( + with close_on_error(store): + ds = open_backend_dataset_store( store, mask_and_scale=mask_and_scale, decode_times=decode_times, concat_characters=concat_characters, decode_coords=decode_coords, drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, ) + return ds - # TODO: this is where we would apply caching - - return ds - - # Zarr supports a wide range of access modes, but for now xarray either - # reads or writes from a store, never both. For open_zarr, we only read - mode = "r" - zarr_store = ZarrStore.open_group( - store, - mode=mode, - synchronizer=synchronizer, - group=group, - consolidated=consolidated, - ) - ds = maybe_decode_store(zarr_store) - - # auto chunking needs to be here and not in ZarrStore because variable - # chunks do not survive decode_cf - # return trivial case - if not chunks: - return ds - - # adapted from Dataset.Chunk() - if isinstance(chunks, int): - chunks = dict.fromkeys(ds.dims, chunks) - - if isinstance(chunks, tuple) and len(chunks) == len(ds.dims): - chunks = dict(zip(ds.dims, chunks)) - - def get_chunk(name, var, chunks): - chunk_spec = dict(zip(var.dims, var.encoding.get("chunks"))) - - # Coordinate labels aren't chunked - if var.ndim == 1 and var.dims[0] == name: - return chunk_spec - - if chunks == "auto": - return chunk_spec - - for dim in var.dims: - if dim in chunks: - spec = chunks[dim] - if isinstance(spec, int): - spec = (spec,) - if isinstance(spec, (tuple, list)) and chunk_spec[dim]: - if any(s % chunk_spec[dim] for s in spec): - warnings.warn( - "Specified Dask chunks %r would " - "separate Zarr chunk shape %r for " - "dimension %r. This significantly " - "degrades performance. Consider " - "rechunking after loading instead." - % (chunks[dim], chunk_spec[dim], dim), - stacklevel=2, - ) - chunk_spec[dim] = chunks[dim] - return chunk_spec - - def maybe_chunk(name, var, chunks): - from dask.base import tokenize - - chunk_spec = get_chunk(name, var, chunks) - - if (var.ndim > 0) and (chunk_spec is not None): - # does this cause any data to be read? 
- token2 = tokenize(name, var._data) - name2 = "zarr-%s" % token2 - var = var.chunk(chunk_spec, name=name2, lock=None) - if overwrite_encoded_chunks and var.chunks is not None: - var.encoding["chunks"] = tuple(x[0] for x in var.chunks) - return var - else: - return var - variables = {k: maybe_chunk(k, v, chunks) for k, v in ds.variables.items()} - return ds._replace_vars_and_dims(variables) +zarr_backend = BackendEntrypoint(open_dataset=open_backend_dataset_zarr) diff --git a/xarray/coding/cftime_offsets.py b/xarray/coding/cftime_offsets.py index a2306331ca7..3c92c816e12 100644 --- a/xarray/coding/cftime_offsets.py +++ b/xarray/coding/cftime_offsets.py @@ -102,7 +102,7 @@ def __sub__(self, other): import cftime if isinstance(other, cftime.datetime): - raise TypeError("Cannot subtract a cftime.datetime " "from a time offset.") + raise TypeError("Cannot subtract a cftime.datetime from a time offset.") elif type(other) == type(self): return type(self)(self.n - other.n) else: @@ -122,7 +122,7 @@ def __radd__(self, other): def __rsub__(self, other): if isinstance(other, BaseCFTimeOffset) and type(self) != type(other): - raise TypeError("Cannot subtract cftime offsets of differing " "types") + raise TypeError("Cannot subtract cftime offsets of differing types") return -self + other def __apply__(self): @@ -221,8 +221,7 @@ def _adjust_n_years(other, n, month, reference_day): def _shift_month(date, months, day_option="start"): - """Shift the date to a month start or end a given number of months away. - """ + """Shift the date to a month start or end a given number of months away.""" import cftime delta_year = (date.month + months) // 12 @@ -354,8 +353,7 @@ def onOffset(self, date): class QuarterOffset(BaseCFTimeOffset): - """Quarter representation copied off of pandas/tseries/offsets.py - """ + """Quarter representation copied off of pandas/tseries/offsets.py""" _freq: ClassVar[str] _default_month: ClassVar[int] @@ -795,19 +793,19 @@ def cftime_range( Left bound for generating dates. end : str or cftime.datetime, optional Right bound for generating dates. - periods : integer, optional + periods : int, optional Number of periods to generate. - freq : str, default 'D', BaseCFTimeOffset, or None - Frequency strings can have multiples, e.g. '5H'. - normalize : bool, default False + freq : str or None, default: "D" + Frequency strings can have multiples, e.g. "5H". + normalize : bool, default: False Normalize start/end dates to midnight before generating date range. - name : str, default None + name : str, default: None Name of the resulting index - closed : {None, 'left', 'right'}, optional + closed : {"left", "right"} or None, default: None Make the interval closed with respect to the given frequency to the - 'left', 'right', or both sides (None, the default). - calendar : str - Calendar type for the datetimes (default 'standard'). + "left", "right", or both sides (None). + calendar : str, default: "standard" + Calendar type for the datetimes. Returns ------- @@ -941,12 +939,12 @@ def cftime_range( >>> xr.cftime_range(start="2000", periods=6, freq="2MS", calendar="noleap") CFTimeIndex([2000-01-01 00:00:00, 2000-03-01 00:00:00, 2000-05-01 00:00:00, 2000-07-01 00:00:00, 2000-09-01 00:00:00, 2000-11-01 00:00:00], - dtype='object') + dtype='object', length=6, calendar='noleap', freq='2MS') As in the standard pandas function, three of the ``start``, ``end``, ``periods``, or ``freq`` arguments must be specified at a given time, with the other set to ``None``. 
See the `pandas documentation - `_ + `_ for more examples of the behavior of ``date_range`` with each of the parameters. diff --git a/xarray/coding/cftimeindex.py b/xarray/coding/cftimeindex.py index 6fc28d213dd..e414740d420 100644 --- a/xarray/coding/cftimeindex.py +++ b/xarray/coding/cftimeindex.py @@ -50,8 +50,14 @@ from xarray.core.utils import is_scalar from ..core.common import _contains_cftime_datetimes +from ..core.options import OPTIONS from .times import _STANDARD_CALENDARS, cftime_to_nptime, infer_calendar_name +# constants for cftimeindex.repr +CFTIME_REPR_LENGTH = 19 +ITEMS_IN_REPR_MAX_ELSE_ELLIPSIS = 100 +REPR_ELLIPSIS_SHOW_ITEMS_FRONT_END = 10 + def named(name, pattern): return "(?P<" + name + ">" + pattern + ")" @@ -85,22 +91,25 @@ def build_pattern(date_sep=r"\-", datetime_sep=r"T", time_sep=r"\:"): _BASIC_PATTERN = build_pattern(date_sep="", time_sep="") _EXTENDED_PATTERN = build_pattern() -_PATTERNS = [_BASIC_PATTERN, _EXTENDED_PATTERN] +_CFTIME_PATTERN = build_pattern(datetime_sep=" ") +_PATTERNS = [_BASIC_PATTERN, _EXTENDED_PATTERN, _CFTIME_PATTERN] -def parse_iso8601(datetime_string): +def parse_iso8601_like(datetime_string): for pattern in _PATTERNS: match = re.match(pattern, datetime_string) if match: return match.groupdict() - raise ValueError("no ISO-8601 match for string: %s" % datetime_string) + raise ValueError( + f"no ISO-8601 or cftime-string-like match for string: {datetime_string}" + ) def _parse_iso8601_with_reso(date_type, timestr): import cftime default = date_type(1, 1, 1) - result = parse_iso8601(timestr) + result = parse_iso8601_like(timestr) replace = {} for attr in ["year", "month", "day", "hour", "minute", "second"]: @@ -215,6 +224,49 @@ def assert_all_valid_date_type(data): ) +def format_row(times, indent=0, separator=", ", row_end=",\n"): + """Format a single row from format_times.""" + return indent * " " + separator.join(map(str, times)) + row_end + + +def format_times( + index, + max_width, + offset, + separator=", ", + first_row_offset=0, + intermediate_row_end=",\n", + last_row_end="", +): + """Format values of cftimeindex as pd.Index.""" + n_per_row = max(max_width // (CFTIME_REPR_LENGTH + len(separator)), 1) + n_rows = int(np.ceil(len(index) / n_per_row)) + + representation = "" + for row in range(n_rows): + indent = first_row_offset if row == 0 else offset + row_end = last_row_end if row == n_rows - 1 else intermediate_row_end + times_for_row = index[row * n_per_row : (row + 1) * n_per_row] + representation = representation + format_row( + times_for_row, indent=indent, separator=separator, row_end=row_end + ) + + return representation + + +def format_attrs(index, separator=", "): + """Format attributes of CFTimeIndex for __repr__.""" + attrs = { + "dtype": f"'{index.dtype}'", + "length": f"{len(index)}", + "calendar": f"'{index.calendar}'", + } + attrs["freq"] = f"'{index.freq}'" if len(index) >= 3 else None + attrs_str = [f"{k}={v}" for k, v in attrs.items()] + attrs_str = f"{separator}".join(attrs_str) + return attrs_str + + class CFTimeIndex(pd.Index): """Custom Index for working with CF calendars and dates @@ -224,7 +276,7 @@ class CFTimeIndex(pd.Index): ---------- data : array or CFTimeIndex Sequence of cftime.datetime objects to use in index - name : str, default None + name : str, default: None Name of the resulting index See Also @@ -259,6 +311,46 @@ def __new__(cls, data, name=None): result._cache = {} return result + def __repr__(self): + """ + Return a string representation for this object. 
+ """ + klass_name = type(self).__name__ + display_width = OPTIONS["display_width"] + offset = len(klass_name) + 2 + + if len(self) <= ITEMS_IN_REPR_MAX_ELSE_ELLIPSIS: + datastr = format_times( + self.values, display_width, offset=offset, first_row_offset=0 + ) + else: + front_str = format_times( + self.values[:REPR_ELLIPSIS_SHOW_ITEMS_FRONT_END], + display_width, + offset=offset, + first_row_offset=0, + last_row_end=",", + ) + end_str = format_times( + self.values[-REPR_ELLIPSIS_SHOW_ITEMS_FRONT_END:], + display_width, + offset=offset, + first_row_offset=offset, + ) + datastr = "\n".join([front_str, f"{' '*offset}...", end_str]) + + attrs_str = format_attrs(self) + # oneliner only if smaller than display_width + full_repr_str = f"{klass_name}([{datastr}], {attrs_str})" + if len(full_repr_str) <= display_width: + return full_repr_str + else: + # if attrs_str too long, one per line + if len(attrs_str) >= display_width - offset: + attrs_str = attrs_str.replace(",", f",\n{' '*(offset-2)}") + full_repr_str = f"{klass_name}([{datastr}],\n{' '*(offset-1)}{attrs_str})" + return full_repr_str + def _partial_date_slice(self, resolution, parsed): """Adapted from pandas.tseries.index.DatetimeIndex._partial_date_slice @@ -432,9 +524,11 @@ def shift(self, n, freq): -------- >>> index = xr.cftime_range("2000", periods=1, freq="M") >>> index - CFTimeIndex([2000-01-31 00:00:00], dtype='object') + CFTimeIndex([2000-01-31 00:00:00], + dtype='object', length=1, calendar='gregorian', freq=None) >>> index.shift(1, "M") - CFTimeIndex([2000-02-29 00:00:00], dtype='object') + CFTimeIndex([2000-02-29 00:00:00], + dtype='object', length=1, calendar='gregorian', freq=None) """ from .cftime_offsets import to_offset @@ -520,7 +614,8 @@ def to_datetimeindex(self, unsafe=False): >>> import xarray as xr >>> times = xr.cftime_range("2000", periods=2, calendar="gregorian") >>> times - CFTimeIndex([2000-01-01 00:00:00, 2000-01-02 00:00:00], dtype='object') + CFTimeIndex([2000-01-01 00:00:00, 2000-01-02 00:00:00], + dtype='object', length=2, calendar='gregorian', freq=None) >>> times.to_datetimeindex() DatetimeIndex(['2000-01-01', '2000-01-02'], dtype='datetime64[ns]', freq=None) """ @@ -578,9 +673,24 @@ def asi8(self): [ _total_microseconds(exact_cftime_datetime_difference(epoch, date)) for date in self.values - ] + ], + dtype=np.int64, ) + @property + def calendar(self): + """The calendar used by the datetimes in the index.""" + from .times import infer_calendar_name + + return infer_calendar_name(self) + + @property + def freq(self): + """The frequency used by the dates in the index.""" + from .frequencies import infer_freq + + return infer_freq(self) + def _round_via_method(self, freq, method): """Round dates using a specified method.""" from .cftime_offsets import CFTIME_TICKS, to_offset @@ -599,7 +709,7 @@ def floor(self, freq): Parameters ---------- - freq : str or CFTimeOffset + freq : str The frequency level to round the index to. Must be a fixed frequency like 'S' (second) not 'ME' (month end). See `frequency aliases `_ @@ -616,7 +726,7 @@ def ceil(self, freq): Parameters ---------- - freq : str or CFTimeOffset + freq : str The frequency level to round the index to. Must be a fixed frequency like 'S' (second) not 'ME' (month end). See `frequency aliases `_ @@ -633,7 +743,7 @@ def round(self, freq): Parameters ---------- - freq : str or CFTimeOffset + freq : str The frequency level to round the index to. Must be a fixed frequency like 'S' (second) not 'ME' (month end). 
See `frequency aliases `_ diff --git a/xarray/coding/frequencies.py b/xarray/coding/frequencies.py new file mode 100644 index 00000000000..fa11d05923f --- /dev/null +++ b/xarray/coding/frequencies.py @@ -0,0 +1,272 @@ +"""FrequencyInferer analog for cftime.datetime objects""" +# The infer_freq method and the _CFTimeFrequencyInferer +# subclass defined here were copied and adapted for +# use with cftime.datetime objects based on the source code in +# pandas.tseries.Frequencies._FrequencyInferer + +# For reference, here is a copy of the pandas copyright notice: + +# (c) 2011-2012, Lambda Foundry, Inc. and PyData Development Team +# All rights reserved. + +# Copyright (c) 2008-2011 AQR Capital Management, LLC +# All rights reserved. + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: + +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. + +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following +# disclaimer in the documentation and/or other materials provided +# with the distribution. + +# * Neither the name of the copyright holder nor the names of any +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import numpy as np +import pandas as pd + +from ..core.common import _contains_datetime_like_objects +from .cftime_offsets import _MONTH_ABBREVIATIONS +from .cftimeindex import CFTimeIndex + +_ONE_MICRO = 1 +_ONE_MILLI = _ONE_MICRO * 1000 +_ONE_SECOND = _ONE_MILLI * 1000 +_ONE_MINUTE = 60 * _ONE_SECOND +_ONE_HOUR = 60 * _ONE_MINUTE +_ONE_DAY = 24 * _ONE_HOUR + + +def infer_freq(index): + """ + Infer the most likely frequency given the input index. + + Parameters + ---------- + index : CFTimeIndex, DataArray, DatetimeIndex, TimedeltaIndex, Series + If not passed a CFTimeIndex, this simply calls `pandas.infer_freq`. + If passed a Series or a DataArray will use the values of the series (NOT THE INDEX). + + Returns + ------- + str or None + None if no discernible frequency. + + Raises + ------ + TypeError + If the index is not datetime-like. + ValueError + If there are fewer than three values or the index is not 1D. 
+ """ + from xarray.core.dataarray import DataArray + + if isinstance(index, (DataArray, pd.Series)): + if index.ndim != 1: + raise ValueError("'index' must be 1D") + elif not _contains_datetime_like_objects(DataArray(index)): + raise ValueError("'index' must contain datetime-like objects") + dtype = np.asarray(index).dtype + if dtype == "datetime64[ns]": + index = pd.DatetimeIndex(index.values) + elif dtype == "timedelta64[ns]": + index = pd.TimedeltaIndex(index.values) + else: + index = CFTimeIndex(index.values) + + if isinstance(index, CFTimeIndex): + inferer = _CFTimeFrequencyInferer(index) + return inferer.get_freq() + + return pd.infer_freq(index) + + +class _CFTimeFrequencyInferer: # (pd.tseries.frequencies._FrequencyInferer): + def __init__(self, index): + self.index = index + self.values = index.asi8 + + if len(index) < 3: + raise ValueError("Need at least 3 dates to infer frequency") + + self.is_monotonic = ( + self.index.is_monotonic_decreasing or self.index.is_monotonic_increasing + ) + + self._deltas = None + self._year_deltas = None + self._month_deltas = None + + def get_freq(self): + """Find the appropriate frequency string to describe the inferred frequency of self.index + + Adapted from `pandas.tsseries.frequencies._FrequencyInferer.get_freq` for CFTimeIndexes. + + Returns + ------- + str or None + """ + if not self.is_monotonic or not self.index.is_unique: + return None + + delta = self.deltas[0] # Smallest delta + if _is_multiple(delta, _ONE_DAY): + return self._infer_daily_rule() + # There is no possible intraday frequency with a non-unique delta + # Different from pandas: we don't need to manage DST and business offsets in cftime + elif not len(self.deltas) == 1: + return None + + if _is_multiple(delta, _ONE_HOUR): + return _maybe_add_count("H", delta / _ONE_HOUR) + elif _is_multiple(delta, _ONE_MINUTE): + return _maybe_add_count("T", delta / _ONE_MINUTE) + elif _is_multiple(delta, _ONE_SECOND): + return _maybe_add_count("S", delta / _ONE_SECOND) + elif _is_multiple(delta, _ONE_MILLI): + return _maybe_add_count("L", delta / _ONE_MILLI) + else: + return _maybe_add_count("U", delta / _ONE_MICRO) + + def _infer_daily_rule(self): + annual_rule = self._get_annual_rule() + if annual_rule: + nyears = self.year_deltas[0] + month = _MONTH_ABBREVIATIONS[self.index[0].month] + alias = f"{annual_rule}-{month}" + return _maybe_add_count(alias, nyears) + + quartely_rule = self._get_quartely_rule() + if quartely_rule: + nquarters = self.month_deltas[0] / 3 + mod_dict = {0: 12, 2: 11, 1: 10} + month = _MONTH_ABBREVIATIONS[mod_dict[self.index[0].month % 3]] + alias = f"{quartely_rule}-{month}" + return _maybe_add_count(alias, nquarters) + + monthly_rule = self._get_monthly_rule() + if monthly_rule: + return _maybe_add_count(monthly_rule, self.month_deltas[0]) + + if len(self.deltas) == 1: + # Daily as there is no "Weekly" offsets with CFTime + days = self.deltas[0] / _ONE_DAY + return _maybe_add_count("D", days) + + # CFTime has no business freq and no "week of month" (WOM) + return None + + def _get_annual_rule(self): + if len(self.year_deltas) > 1: + return None + + if len(np.unique(self.index.month)) > 1: + return None + + return {"cs": "AS", "ce": "A"}.get(month_anchor_check(self.index)) + + def _get_quartely_rule(self): + if len(self.month_deltas) > 1: + return None + + if not self.month_deltas[0] % 3 == 0: + return None + + return {"cs": "QS", "ce": "Q"}.get(month_anchor_check(self.index)) + + def _get_monthly_rule(self): + if len(self.month_deltas) > 1: + return None + + return 
{"cs": "MS", "ce": "M"}.get(month_anchor_check(self.index)) + + @property + def deltas(self): + """Sorted unique timedeltas as microseconds.""" + if self._deltas is None: + self._deltas = _unique_deltas(self.values) + return self._deltas + + @property + def year_deltas(self): + """Sorted unique year deltas.""" + if self._year_deltas is None: + self._year_deltas = _unique_deltas(self.index.year) + return self._year_deltas + + @property + def month_deltas(self): + """Sorted unique month deltas.""" + if self._month_deltas is None: + self._month_deltas = _unique_deltas(self.index.year * 12 + self.index.month) + return self._month_deltas + + +def _unique_deltas(arr): + """Sorted unique deltas of numpy array""" + return np.sort(np.unique(np.diff(arr))) + + +def _is_multiple(us, mult: int): + """Whether us is a multiple of mult""" + return us % mult == 0 + + +def _maybe_add_count(base: str, count: float): + """If count is greater than 1, add it to the base offset string""" + if count != 1: + assert count == int(count) + count = int(count) + return f"{count}{base}" + else: + return base + + +def month_anchor_check(dates): + """Return the monthly offset string. + + Return "cs" if all dates are the first days of the month, + "ce" if all dates are the last day of the month, + None otherwise. + + Replicated pandas._libs.tslibs.resolution.month_position_check + but without business offset handling. + """ + calendar_end = True + calendar_start = True + + for date in dates: + if calendar_start: + calendar_start &= date.day == 1 + + if calendar_end: + cal = date.day == date.daysinmonth + if calendar_end: + calendar_end &= cal + elif not calendar_start: + break + + if calendar_end: + return "ce" + elif calendar_start: + return "cs" + else: + return None diff --git a/xarray/coding/strings.py b/xarray/coding/strings.py index 35cc190ffe3..e16e983fd8a 100644 --- a/xarray/coding/strings.py +++ b/xarray/coding/strings.py @@ -4,7 +4,7 @@ import numpy as np from ..core import indexing -from ..core.pycompat import dask_array_type +from ..core.pycompat import is_duck_dask_array from ..core.variable import Variable from .variables import ( VariableCoder, @@ -130,7 +130,7 @@ def bytes_to_char(arr): if arr.dtype.kind != "S": raise ValueError("argument must have a fixed-width bytes dtype") - if isinstance(arr, dask_array_type): + if is_duck_dask_array(arr): import dask.array as da return da.map_blocks( @@ -145,8 +145,7 @@ def bytes_to_char(arr): def _numpy_bytes_to_char(arr): - """Like netCDF4.stringtochar, but faster and more flexible. - """ + """Like netCDF4.stringtochar, but faster and more flexible.""" # ensure the array is contiguous arr = np.array(arr, copy=False, order="C", dtype=np.string_) return arr.reshape(arr.shape + (1,)).view("S1") @@ -167,7 +166,7 @@ def char_to_bytes(arr): # can't make an S0 dtype return np.zeros(arr.shape[:-1], dtype=np.string_) - if isinstance(arr, dask_array_type): + if is_duck_dask_array(arr): import dask.array as da if len(arr.chunks[-1]) > 1: @@ -189,8 +188,7 @@ def char_to_bytes(arr): def _numpy_char_to_bytes(arr): - """Like netCDF4.chartostring, but faster and more flexible. 
- """ + """Like netCDF4.chartostring, but faster and more flexible.""" # based on: http://stackoverflow.com/a/10984878/809705 arr = np.array(arr, copy=False, order="C") dtype = "S" + str(arr.shape[-1]) @@ -201,9 +199,9 @@ class StackedBytesArray(indexing.ExplicitlyIndexedNDArrayMixin): """Wrapper around array-like objects to create a new indexable object where values, when accessed, are automatically stacked along the last dimension. - >>> StackedBytesArray(np.array(["a", "b", "c"]))[:] - array('abc', - dtype='|S3') + >>> indexer = indexing.BasicIndexer((slice(None),)) + >>> StackedBytesArray(np.array(["a", "b", "c"], dtype="S1"))[indexer] + array(b'abc', dtype='|S3') """ def __init__(self, array): diff --git a/xarray/coding/times.py b/xarray/coding/times.py index 965ddd8f043..3d877a169f5 100644 --- a/xarray/coding/times.py +++ b/xarray/coding/times.py @@ -26,6 +26,7 @@ _STANDARD_CALENDARS = {"standard", "gregorian", "proleptic_gregorian"} _NS_PER_TIME_DELTA = { + "ns": 1, "us": int(1e3), "ms": int(1e6), "s": int(1e9), @@ -35,7 +36,15 @@ } TIME_UNITS = frozenset( - ["days", "hours", "minutes", "seconds", "milliseconds", "microseconds"] + [ + "days", + "hours", + "minutes", + "seconds", + "milliseconds", + "microseconds", + "nanoseconds", + ] ) @@ -44,6 +53,7 @@ def _netcdf_to_numpy_timeunit(units): if not units.endswith("s"): units = "%ss" % units return { + "nanoseconds": "ns", "microseconds": "us", "milliseconds": "ms", "seconds": "s", @@ -53,14 +63,50 @@ def _netcdf_to_numpy_timeunit(units): }[units] +def _ensure_padded_year(ref_date): + # Reference dates without a padded year (e.g. since 1-1-1 or since 2-3-4) + # are ambiguous (is it YMD or DMY?). This can lead to some very odd + # behaviour e.g. pandas (via dateutil) passes '1-1-1 00:00:0.0' as + # '2001-01-01 00:00:00' (because it assumes a) DMY and b) that year 1 is + # shorthand for 2001 (like 02 would be shorthand for year 2002)). + + # Here we ensure that there is always a four-digit year, with the + # assumption being that year comes first if we get something ambiguous. + matches_year = re.match(r".*\d{4}.*", ref_date) + if matches_year: + # all good, return + return ref_date + + # No four-digit strings, assume the first digits are the year and pad + # appropriately + matches_start_digits = re.match(r"(\d+)(.*)", ref_date) + ref_year, everything_else = [s for s in matches_start_digits.groups()] + ref_date_padded = "{:04d}{}".format(int(ref_year), everything_else) + + warning_msg = ( + f"Ambiguous reference date string: {ref_date}. The first value is " + "assumed to be the year hence will be padded with zeros to remove " + f"the ambiguity (the padded reference date string is: {ref_date_padded}). " + "To remove this message, remove the ambiguity by padding your reference " + "date strings with zeros." + ) + warnings.warn(warning_msg, SerializationWarning) + + return ref_date_padded + + def _unpack_netcdf_time_units(units): # CF datetime units follow the format: "UNIT since DATE" # this parses out the unit and date allowing for extraneous - # whitespace. - matches = re.match("(.+) since (.+)", units) + # whitespace. It also ensures that the year is padded with zeros + # so it will be correctly understood by pandas (via dateutil). 
+ matches = re.match(r"(.+) since (.+)", units) if not matches: - raise ValueError("invalid time units: %s" % units) + raise ValueError(f"invalid time units: {units}") + delta_units, ref_date = [s.strip() for s in matches.groups()] + ref_date = _ensure_padded_year(ref_date) + return delta_units, ref_date @@ -80,8 +126,9 @@ def _decode_cf_datetime_dtype(data, units, calendar, use_cftime): "the default calendar" if calendar is None else "calendar %r" % calendar ) msg = ( - "unable to decode time units %r with %s. Try " - "opening your dataset with decode_times=False." % (units, calendar_msg) + f"unable to decode time units {units!r} with {calendar_msg!r}. Try " + "opening your dataset with decode_times=False or installing cftime " + "if it is not installed." ) raise ValueError(msg) else: @@ -114,21 +161,22 @@ def _decode_datetime_with_pandas(flat_num_dates, units, calendar): # strings, in which case we fall back to using cftime raise OutOfBoundsDatetime - # fixes: https://github.com/pydata/pandas/issues/14068 - # these lines check if the the lowest or the highest value in dates - # cause an OutOfBoundsDatetime (Overflow) error - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "invalid value encountered", RuntimeWarning) - pd.to_timedelta(flat_num_dates.min(), delta) + ref_date - pd.to_timedelta(flat_num_dates.max(), delta) + ref_date - - # Cast input dates to integers of nanoseconds because `pd.to_datetime` - # works much faster when dealing with integers - # make _NS_PER_TIME_DELTA an array to ensure type upcasting - flat_num_dates_ns_int = ( - flat_num_dates.astype(np.float64) * _NS_PER_TIME_DELTA[delta] - ).astype(np.int64) + # To avoid integer overflow when converting to nanosecond units for integer + # dtypes smaller than np.int64 cast all integer-dtype arrays to np.int64 + # (GH 2002). + if flat_num_dates.dtype.kind == "i": + flat_num_dates = flat_num_dates.astype(np.int64) + + # Cast input ordinals to integers of nanoseconds because pd.to_timedelta + # works much faster when dealing with integers (GH 1399). + flat_num_dates_ns_int = (flat_num_dates * _NS_PER_TIME_DELTA[delta]).astype( + np.int64 + ) + # Use pd.to_timedelta to safely cast integer values to timedeltas, + # and add those to a Timestamp to safely produce a DatetimeIndex. This + # ensures that we do not encounter integer overflow at any point in the + # process without raising OutOfBoundsDatetime. 
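The comments above describe the overflow-safe decoding strategy; a self-contained sketch of that pattern (plain NumPy/pandas, not the actual xarray call) looks like this:

```python
import numpy as np
import pandas as pd

# e.g. raw values for "days since 2000-01-01" stored as a small integer dtype
num_dates = np.array([0, 1, 2], dtype=np.int32)
ref_date = pd.Timestamp("2000-01-01")
ns_per_day = 24 * 60 * 60 * 10**9

# Cast to int64 first, scale to nanoseconds, then let pandas do the rest;
# pd.to_timedelta raises instead of silently overflowing.
offsets_ns = num_dates.astype(np.int64) * ns_per_day
decoded = (pd.to_timedelta(offsets_ns, "ns") + ref_date).values
print(decoded.dtype)  # datetime64[ns]
```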
return (pd.to_timedelta(flat_num_dates_ns_int, "ns") + ref_date).values @@ -155,9 +203,9 @@ def decode_cf_datetime(num_dates, units, calendar=None, use_cftime=None): if use_cftime is None: try: dates = _decode_datetime_with_pandas(flat_num_dates, units, calendar) - except (OutOfBoundsDatetime, OverflowError): + except (KeyError, OutOfBoundsDatetime, OverflowError): dates = _decode_datetime_with_cftime( - flat_num_dates.astype(np.float), units, calendar + flat_num_dates.astype(float), units, calendar ) if ( @@ -178,7 +226,7 @@ def decode_cf_datetime(num_dates, units, calendar=None, use_cftime=None): dates = cftime_to_nptime(dates) elif use_cftime: dates = _decode_datetime_with_cftime( - flat_num_dates.astype(np.float), units, calendar + flat_num_dates.astype(float), units, calendar ) else: dates = _decode_datetime_with_pandas(flat_num_dates, units, calendar) @@ -215,11 +263,24 @@ def decode_cf_timedelta(num_timedeltas, units): def _infer_time_units_from_diff(unique_timedeltas): - for time_unit in ["days", "hours", "minutes", "seconds"]: + # Note that the modulus operator was only implemented for np.timedelta64 + # arrays as of NumPy version 1.16.0. Once our minimum version of NumPy + # supported is greater than or equal to this we will no longer need to cast + # unique_timedeltas to a TimedeltaIndex. In the meantime, however, the + # modulus operator works for TimedeltaIndex objects. + unique_deltas_as_index = pd.TimedeltaIndex(unique_timedeltas) + for time_unit in [ + "days", + "hours", + "minutes", + "seconds", + "milliseconds", + "microseconds", + "nanoseconds", + ]: delta_ns = _NS_PER_TIME_DELTA[_netcdf_to_numpy_timeunit(time_unit)] unit_delta = np.timedelta64(delta_ns, "ns") - diffs = unique_timedeltas / unit_delta - if np.all(diffs == diffs.astype(int)): + if np.all(unique_deltas_as_index % unit_delta == np.timedelta64(0, "ns")): return time_unit return "seconds" @@ -379,7 +440,15 @@ def encode_cf_datetime(dates, units=None, calendar=None): # Wrap the dates in a DatetimeIndex to do the subtraction to ensure # an OverflowError is raised if the ref_date is too far away from # dates to be encoded (GH 2272). - num = (pd.DatetimeIndex(dates.ravel()) - ref_date) / time_delta + dates_as_index = pd.DatetimeIndex(dates.ravel()) + time_deltas = dates_as_index - ref_date + + # Use floor division if time_delta evenly divides all differences + # to preserve integer dtype if possible (GH 4045). 
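A small sketch of the integer-preserving branch described in the comment above, using plain pandas objects:

```python
import numpy as np
import pandas as pd

dates = pd.DatetimeIndex(["2000-01-01", "2000-01-02", "2000-01-03"])
ref_date = pd.Timestamp("2000-01-01")
time_delta = np.timedelta64(1, "D")

time_deltas = dates - ref_date
if np.all(time_deltas % time_delta == np.timedelta64(0, "ns")):
    num = time_deltas // time_delta   # stays integer: [0, 1, 2]
else:
    num = time_deltas / time_delta    # falls back to float
print(num.values)
```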
+ if np.all(time_deltas % time_delta == np.timedelta64(0, "ns")): + num = time_deltas // time_delta + else: + num = time_deltas / time_delta num = num.values.reshape(dates.shape) except (OutOfBoundsDatetime, OverflowError): diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index 28ead397461..b035ff82086 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -7,7 +7,7 @@ import pandas as pd from ..core import dtypes, duck_array_ops, indexing -from ..core.pycompat import dask_array_type +from ..core.pycompat import is_duck_dask_array from ..core.variable import Variable @@ -35,15 +35,13 @@ class VariableCoder: def encode( self, variable: Variable, name: Hashable = None ) -> Variable: # pragma: no cover - """Convert an encoded variable to a decoded variable - """ + """Convert an encoded variable to a decoded variable""" raise NotImplementedError() def decode( self, variable: Variable, name: Hashable = None ) -> Variable: # pragma: no cover - """Convert an decoded variable to a encoded variable - """ + """Convert an decoded variable to a encoded variable""" raise NotImplementedError() @@ -56,7 +54,7 @@ class _ElementwiseFunctionArray(indexing.ExplicitlyIndexedNDArrayMixin): """ def __init__(self, array, func, dtype): - assert not isinstance(array, dask_array_type) + assert not is_duck_dask_array(array) self.array = indexing.as_indexable(array) self.func = func self._dtype = dtype @@ -93,8 +91,10 @@ def lazy_elemwise_func(array, func, dtype): ------- Either a dask.array.Array or _ElementwiseFunctionArray. """ - if isinstance(array, dask_array_type): - return array.map_blocks(func, dtype=dtype) + if is_duck_dask_array(array): + import dask.array as da + + return da.map_blocks(func, array, dtype=dtype) else: return _ElementwiseFunctionArray(array, func, dtype) @@ -269,6 +269,10 @@ def decode(self, variable, name=None): scale_factor = pop_to(attrs, encoding, "scale_factor", name=name) add_offset = pop_to(attrs, encoding, "add_offset", name=name) dtype = _choose_float_dtype(data.dtype, "add_offset" in attrs) + if np.ndim(scale_factor) > 0: + scale_factor = np.asarray(scale_factor).item() + if np.ndim(add_offset) > 0: + add_offset = np.asarray(add_offset).item() transform = partial( _scale_offset_decoding, scale_factor=scale_factor, diff --git a/xarray/conventions.py b/xarray/conventions.py index df24d0d3d8d..bb0b92c77a1 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -8,7 +8,7 @@ from .coding.variables import SerializationWarning, pop_to from .core import duck_array_ops, indexing from .core.common import contains_cftime_datetimes -from .core.pycompat import dask_array_type +from .core.pycompat import is_duck_dask_array from .core.variable import IndexVariable, Variable, as_variable @@ -24,10 +24,11 @@ class NativeEndiannessArray(indexing.ExplicitlyIndexedNDArrayMixin): >>> x.dtype dtype('>i2') - >>> NativeEndianArray(x).dtype + >>> NativeEndiannessArray(x).dtype dtype('int16') - >>> NativeEndianArray(x)[:].dtype + >>> indexer = indexing.BasicIndexer((slice(None),)) + >>> NativeEndiannessArray(x)[indexer].dtype dtype('int16') """ @@ -53,12 +54,13 @@ class BoolTypeArray(indexing.ExplicitlyIndexedNDArrayMixin): >>> x = np.array([1, 0, 1, 1, 0], dtype="i1") >>> x.dtype - dtype('>i2') + dtype('int8') >>> BoolTypeArray(x).dtype dtype('bool') - >>> BoolTypeArray(x)[:].dtype + >>> indexer = indexing.BasicIndexer((slice(None),)) + >>> BoolTypeArray(x)[indexer].dtype dtype('bool') """ @@ -116,7 +118,7 @@ def maybe_default_fill_value(var): def 
maybe_encode_bools(var): if ( - (var.dtype == np.bool) + (var.dtype == bool) and ("dtype" not in var.encoding) and ("dtype" not in var.attrs) ): @@ -178,7 +180,7 @@ def ensure_dtype_not_object(var, name=None): if var.dtype.kind == "O": dims, data, attrs, encoding = _var_as_tuple(var) - if isinstance(data, dask_array_type): + if is_duck_dask_array(data): warnings.warn( "variable {} has data in the form of a dask array with " "dtype=object, which means it is being loaded into memory " @@ -230,12 +232,12 @@ def encode_cf_variable(var, needs_copy=True, name=None): Parameters ---------- - var : xarray.Variable + var : Variable A variable holding un-encoded data. Returns ------- - out : xarray.Variable + out : Variable A variable which has been encoded as described above. """ ensure_not_multiindex(var, name=name) @@ -266,6 +268,7 @@ def decode_cf_variable( decode_endianness=True, stack_char_dim=True, use_cftime=None, + decode_timedelta=None, ): """ Decodes a variable which may hold CF encoded information. @@ -277,28 +280,28 @@ def decode_cf_variable( Parameters ---------- - name: str + name : str Name of the variable. Used for better error messages. var : Variable A variable holding potentially CF encoded information. concat_characters : bool Should character arrays be concatenated to strings, for - example: ['h', 'e', 'l', 'l', 'o'] -> 'hello' - mask_and_scale: bool + example: ["h", "e", "l", "l", "o"] -> "hello" + mask_and_scale : bool Lazily scale (using scale_factor and add_offset) and mask (using _FillValue). If the _Unsigned attribute is present treat integer arrays as unsigned. decode_times : bool - Decode cf times ('hours since 2000-01-01') to np.datetime64. + Decode cf times ("hours since 2000-01-01") to np.datetime64. decode_endianness : bool Decode arrays from non-native to native endianness. stack_char_dim : bool Whether to stack characters into bytes along the last dimension of this array. Passed as an argument because we need to look at the full dataset to figure out if this is appropriate. - use_cftime: bool, optional + use_cftime : bool, optional Only relevant if encoded dates come from a standard calendar - (e.g. 'gregorian', 'proleptic_gregorian', 'standard', or not + (e.g. "gregorian", "proleptic_gregorian", "standard", or not specified). If None (default), attempt to decode times to ``np.datetime64[ns]`` objects; if this is not possible, decode times to ``cftime.datetime`` objects. 
If True, always decode times to @@ -315,6 +318,9 @@ def decode_cf_variable( var = as_variable(var) original_dtype = var.dtype + if decode_timedelta is None: + decode_timedelta = decode_times + if concat_characters: if stack_char_dim: var = strings.CharacterArrayCoder().decode(var, name=name) @@ -328,12 +334,10 @@ def decode_cf_variable( ]: var = coder.decode(var, name=name) + if decode_timedelta: + var = times.CFTimedeltaCoder().decode(var, name=name) if decode_times: - for coder in [ - times.CFTimedeltaCoder(), - times.CFDatetimeCoder(use_cftime=use_cftime), - ]: - var = coder.decode(var, name=name) + var = times.CFDatetimeCoder(use_cftime=use_cftime).decode(var, name=name) dimensions, data, attributes, encoding = variables.unpack_for_decoding(var) # TODO(shoyer): convert everything below to use coders @@ -349,7 +353,7 @@ def decode_cf_variable( del attributes["dtype"] data = BoolTypeArray(data) - if not isinstance(data, dask_array_type): + if not is_duck_dask_array(data): data = indexing.LazilyOuterIndexedArray(data) return Variable(dimensions, data, attributes, encoding=encoding) @@ -442,6 +446,7 @@ def decode_cf_variables( decode_coords=True, drop_variables=None, use_cftime=None, + decode_timedelta=None, ): """ Decode several CF encoded variables. @@ -492,6 +497,7 @@ def stackable(dim): decode_times=decode_times, stack_char_dim=stack_char_dim, use_cftime=use_cftime, + decode_timedelta=decode_timedelta, ) if decode_coords: var_attrs = new_vars[k].attrs @@ -518,6 +524,7 @@ def decode_cf( decode_coords=True, drop_variables=None, use_cftime=None, + decode_timedelta=None, ): """Decode the given Dataset or Datastore according to CF conventions into a new Dataset. @@ -528,23 +535,23 @@ def decode_cf( Object to decode. concat_characters : bool, optional Should character arrays be concatenated to strings, for - example: ['h', 'e', 'l', 'l', 'o'] -> 'hello' - mask_and_scale: bool, optional + example: ["h", "e", "l", "l", "o"] -> "hello" + mask_and_scale : bool, optional Lazily scale (using scale_factor and add_offset) and mask (using _FillValue). decode_times : bool, optional - Decode cf times (e.g., integers since 'hours since 2000-01-01') to + Decode cf times (e.g., integers since "hours since 2000-01-01") to np.datetime64. decode_coords : bool, optional Use the 'coordinates' attribute on variable (or the dataset itself) to identify coordinates. - drop_variables: string or iterable, optional + drop_variables : str or iterable, optional A variable or list of variables to exclude from being parsed from the dataset. This may be useful to drop variables with problems or inconsistent values. - use_cftime: bool, optional + use_cftime : bool, optional Only relevant if encoded dates come from a standard calendar - (e.g. 'gregorian', 'proleptic_gregorian', 'standard', or not + (e.g. "gregorian", "proleptic_gregorian", "standard", or not specified). If None (default), attempt to decode times to ``np.datetime64[ns]`` objects; if this is not possible, decode times to ``cftime.datetime`` objects. If True, always decode times to @@ -552,13 +559,18 @@ def decode_cf( represented using ``np.datetime64[ns]`` objects. If False, always decode times to ``np.datetime64[ns]`` objects; if this is not possible raise an error. + decode_timedelta : bool, optional + If True, decode variables and coordinates with time units in + {"days", "hours", "minutes", "seconds", "milliseconds", "microseconds"} + into timedelta objects. If False, leave them encoded as numbers. 
+ If None (default), assume the same value of decode_time. Returns ------- decoded : Dataset """ - from .core.dataset import Dataset from .backends.common import AbstractDataStore + from .core.dataset import Dataset if isinstance(obj, Dataset): vars = obj._variables @@ -583,6 +595,7 @@ def decode_cf( decode_coords, drop_variables=drop_variables, use_cftime=use_cftime, + decode_timedelta=decode_timedelta, ) ds = Dataset(vars, attrs=attrs) ds = ds.set_coords(coord_names.union(extra_coords).intersection(vars)) @@ -610,12 +623,12 @@ def cf_decoder( A dictionary mapping from attribute name to value concat_characters : bool Should character arrays be concatenated to strings, for - example: ['h', 'e', 'l', 'l', 'o'] -> 'hello' + example: ["h", "e", "l", "l", "o"] -> "hello" mask_and_scale: bool Lazily scale (using scale_factor and add_offset) and mask (using _FillValue). decode_times : bool - Decode cf times ('hours since 2000-01-01') to np.datetime64. + Decode cf times ("hours since 2000-01-01") to np.datetime64. Returns ------- diff --git a/xarray/convert.py b/xarray/convert.py index 4974a55d8e2..0fbd1e13163 100644 --- a/xarray/convert.py +++ b/xarray/convert.py @@ -10,6 +10,7 @@ from .core import duck_array_ops from .core.dataarray import DataArray from .core.dtypes import get_fill_value +from .core.pycompat import dask_array_type cdms2_ignored_attrs = {"name", "tileIndex"} iris_forbidden_keys = { @@ -55,14 +56,12 @@ def encode(var): def _filter_attrs(attrs, ignored_attrs): - """ Return attrs that are not in ignored_attrs - """ + """Return attrs that are not in ignored_attrs""" return {k: v for k, v in attrs.items() if k not in ignored_attrs} def from_cdms2(variable): - """Convert a cdms2 variable into an DataArray - """ + """Convert a cdms2 variable into an DataArray""" values = np.asarray(variable) name = variable.id dims = variable.getAxisIds() @@ -89,8 +88,7 @@ def from_cdms2(variable): def to_cdms2(dataarray, copy=True): - """Convert a DataArray into a cdms2 variable - """ + """Convert a DataArray into a cdms2 variable""" # we don't want cdms2 to be a hard dependency import cdms2 @@ -151,14 +149,12 @@ def set_cdms2_attrs(var, attrs): def _pick_attrs(attrs, keys): - """ Return attrs with keys in keys list - """ + """Return attrs with keys in keys list""" return {k: v for k, v in attrs.items() if k in keys} def _get_iris_args(attrs): - """ Converts the xarray attrs into args that can be passed into Iris - """ + """Converts the xarray attrs into args that can be passed into Iris""" # iris.unit is deprecated in Iris v1.9 import cf_units @@ -172,8 +168,7 @@ def _get_iris_args(attrs): # TODO: Add converting bounds from xarray to Iris and back def to_iris(dataarray): - """ Convert a DataArray into a Iris Cube - """ + """Convert a DataArray into a Iris Cube""" # Iris not a hard dependency import iris from iris.fileformats.netcdf import parse_cell_methods @@ -213,8 +208,7 @@ def to_iris(dataarray): def _iris_obj_to_attrs(obj): - """ Return a dictionary of attrs when given a Iris object - """ + """Return a dictionary of attrs when given a Iris object""" attrs = {"standard_name": obj.standard_name, "long_name": obj.long_name} if obj.units.calendar: attrs["calendar"] = obj.units.calendar @@ -225,15 +219,14 @@ def _iris_obj_to_attrs(obj): def _iris_cell_methods_to_str(cell_methods_obj): - """ Converts a Iris cell methods into a string - """ + """Converts a Iris cell methods into a string""" cell_methods = [] for cell_method in cell_methods_obj: - names = "".join([f"{n}: " for n in 
cell_method.coord_names]) + names = "".join(f"{n}: " for n in cell_method.coord_names) intervals = " ".join( - [f"interval: {interval}" for interval in cell_method.intervals] + f"interval: {interval}" for interval in cell_method.intervals ) - comments = " ".join([f"comment: {comment}" for comment in cell_method.comments]) + comments = " ".join(f"comment: {comment}" for comment in cell_method.comments) extra = " ".join([intervals, comments]).strip() if extra: extra = f" ({extra})" @@ -242,7 +235,7 @@ def _iris_cell_methods_to_str(cell_methods_obj): def _name(iris_obj, default="unknown"): - """ Mimicks `iris_obj.name()` but with different name resolution order. + """Mimicks `iris_obj.name()` but with different name resolution order. Similar to iris_obj.name() method, but using iris_obj.var_name first to enable roundtripping. @@ -251,10 +244,8 @@ def _name(iris_obj, default="unknown"): def from_iris(cube): - """ Convert a Iris cube into an DataArray - """ + """Convert a Iris cube into an DataArray""" import iris.exceptions - from xarray.core.pycompat import dask_array_type name = _name(cube) if name == "unknown": diff --git a/xarray/core/accessor_dt.py b/xarray/core/accessor_dt.py index 2977596036c..3fc682f8c32 100644 --- a/xarray/core/accessor_dt.py +++ b/xarray/core/accessor_dt.py @@ -1,3 +1,6 @@ +import warnings +from distutils.version import LooseVersion + import numpy as np import pandas as pd @@ -6,12 +9,11 @@ is_np_datetime_like, is_np_timedelta_like, ) -from .pycompat import dask_array_type +from .pycompat import is_duck_dask_array def _season_from_months(months): - """Compute season (DJF, MAM, JJA, SON) from month ordinal - """ + """Compute season (DJF, MAM, JJA, SON) from month ordinal""" # TODO: Move "season" accessor upstream into pandas seasons = np.array(["DJF", "MAM", "JJA", "SON"]) months = np.asarray(months) @@ -41,6 +43,10 @@ def _access_through_series(values, name): if name == "season": months = values_as_series.dt.month.values field_values = _season_from_months(months) + elif name == "isocalendar": + # isocalendar returns iso- year, week, and weekday -> reshape + field_values = np.array(values_as_series.dt.isocalendar(), dtype=np.int64) + return field_values.T.reshape(3, *values.shape) else: field_values = getattr(values_as_series.dt, name).values return field_values.reshape(values.shape) @@ -70,10 +76,18 @@ def _get_date_field(values, name, dtype): else: access_method = _access_through_cftimeindex - if isinstance(values, dask_array_type): + if is_duck_dask_array(values): from dask.array import map_blocks - return map_blocks(access_method, values, name, dtype=dtype) + new_axis = chunks = None + # isocalendar adds adds an axis + if name == "isocalendar": + chunks = (3,) + values.chunksize + new_axis = 0 + + return map_blocks( + access_method, values, name, dtype=dtype, new_axis=new_axis, chunks=chunks + ) else: return access_method(values, name) @@ -104,9 +118,10 @@ def _round_field(values, name, freq): ---------- values : np.ndarray or dask.array-like Array-like container of datetime-like values - name : str (ceil, floor, round) + name : {"ceil", "floor", "round"} Name of rounding function - freq : a freq string indicating the rounding resolution + freq : str + a freq string indicating the rounding resolution Returns ------- @@ -114,7 +129,7 @@ def _round_field(values, name, freq): Array-like of datetime fields accessed for each element in values """ - if isinstance(values, dask_array_type): + if is_duck_dask_array(values): from dask.array import map_blocks dtype = 
np.datetime64 if is_np_datetime_like(values.dtype) else np.dtype("O") @@ -151,7 +166,7 @@ def _strftime(values, date_format): access_method = _strftime_through_series else: access_method = _strftime_through_cftimeindex - if isinstance(values, dask_array_type): + if is_duck_dask_array(values): from dask.array import map_blocks return map_blocks(access_method, values, date_format) @@ -190,8 +205,8 @@ def floor(self, freq): Parameters ---------- - freq : a freq string indicating the rounding resolution - e.g. 'D' for daily resolution + freq : str + a freq string indicating the rounding resolution e.g. "D" for daily resolution Returns ------- @@ -207,8 +222,8 @@ def ceil(self, freq): Parameters ---------- - freq : a freq string indicating the rounding resolution - e.g. 'D' for daily resolution + freq : str + a freq string indicating the rounding resolution e.g. "D" for daily resolution Returns ------- @@ -223,8 +238,8 @@ def round(self, freq): Parameters ---------- - freq : a freq string indicating the rounding resolution - e.g. 'D' for daily resolution + freq : str + a freq string indicating the rounding resolution e.g. "D" for daily resolution Returns ------- @@ -240,12 +255,6 @@ class DatetimeAccessor(Properties): Fields can be accessed through the `.dt` attribute for applicable DataArrays. - Notes - ------ - Note that these fields are not calendar-aware; if your datetimes are encoded - with a non-Gregorian calendar (e.g. a 360-day calendar) using cftime, - then some fields like `dayofyear` may not be accurate. - Examples --------- >>> import xarray as xr @@ -255,30 +264,30 @@ class DatetimeAccessor(Properties): >>> ts array(['2000-01-01T00:00:00.000000000', '2000-01-02T00:00:00.000000000', - '2000-01-03T00:00:00.000000000', '2000-01-04T00:00:00.000000000', - '2000-01-05T00:00:00.000000000', '2000-01-06T00:00:00.000000000', - '2000-01-07T00:00:00.000000000', '2000-01-08T00:00:00.000000000', - '2000-01-09T00:00:00.000000000', '2000-01-10T00:00:00.000000000'], - dtype='datetime64[ns]') + '2000-01-03T00:00:00.000000000', '2000-01-04T00:00:00.000000000', + '2000-01-05T00:00:00.000000000', '2000-01-06T00:00:00.000000000', + '2000-01-07T00:00:00.000000000', '2000-01-08T00:00:00.000000000', + '2000-01-09T00:00:00.000000000', '2000-01-10T00:00:00.000000000'], + dtype='datetime64[ns]') Coordinates: - * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-10 - >>> ts.dt - + * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-10 + >>> ts.dt # doctest: +ELLIPSIS + >>> ts.dt.dayofyear array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) Coordinates: - * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-10 + * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-10 >>> ts.dt.quarter array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]) Coordinates: - * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-10 + * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-10 """ def strftime(self, date_format): - ''' + """ Return an array of formatted strings specified by date_format, which supports the same string format as the python standard library. 
Details of the string format can be found in `python string format doc @@ -296,13 +305,12 @@ def strftime(self, date_format): Examples -------- + >>> import datetime >>> rng = xr.Dataset({"time": datetime.datetime(2000, 1, 1)}) >>> rng["time"].dt.strftime("%B %d, %Y, %r") array('January 01, 2000, 12:00:00 AM', dtype=object) """ - - ''' obj_type = type(self._obj) result = _strftime(self._obj.data, date_format) @@ -311,6 +319,33 @@ def strftime(self, date_format): result, name="strftime", coords=self._obj.coords, dims=self._obj.dims ) + def isocalendar(self): + """Dataset containing ISO year, week number, and weekday. + + Note + ---- + The iso year and weekday differ from the nominal year and weekday. + """ + + from .dataset import Dataset + + if not is_np_datetime_like(self._obj.data.dtype): + raise AttributeError("'CFTimeIndex' object has no attribute 'isocalendar'") + + if LooseVersion(pd.__version__) < "1.1.0": + raise AttributeError("'isocalendar' not available in pandas < 1.1.0") + + values = _get_date_field(self._obj.data, "isocalendar", np.int64) + + obj_type = type(self._obj) + data_vars = {} + for i, name in enumerate(["year", "week", "weekday"]): + data_vars[name] = obj_type( + values[i], name=name, coords=self._obj.coords, dims=self._obj.dims + ) + + return Dataset(data_vars) + year = Properties._tslib_field_accessor( "year", "The year of the datetime", np.int64 ) @@ -333,9 +368,26 @@ def strftime(self, date_format): nanosecond = Properties._tslib_field_accessor( "nanosecond", "The nanoseconds of the datetime", np.int64 ) - weekofyear = Properties._tslib_field_accessor( - "weekofyear", "The week ordinal of the year", np.int64 - ) + + @property + def weekofyear(self): + "The week ordinal of the year" + + warnings.warn( + "dt.weekofyear and dt.week have been deprecated. Please use " + "dt.isocalendar().week instead.", + FutureWarning, + ) + + if LooseVersion(pd.__version__) < "1.1.0": + weekofyear = Properties._tslib_field_accessor( + "weekofyear", "The week ordinal of the year", np.int64 + ).fget(self) + else: + weekofyear = self.isocalendar().week + + return weekofyear + week = weekofyear dayofweek = Properties._tslib_field_accessor( "dayofweek", "The day of the week with Monday=0, Sunday=6", np.int64 @@ -404,32 +456,32 @@ class TimedeltaAccessor(Properties): >>> ts array([ 86400000000000, 108000000000000, 129600000000000, 151200000000000, - 172800000000000, 194400000000000, 216000000000000, 237600000000000, - 259200000000000, 280800000000000, 302400000000000, 324000000000000, - 345600000000000, 367200000000000, 388800000000000, 410400000000000, - 432000000000000, 453600000000000, 475200000000000, 496800000000000], - dtype='timedelta64[ns]') + 172800000000000, 194400000000000, 216000000000000, 237600000000000, + 259200000000000, 280800000000000, 302400000000000, 324000000000000, + 345600000000000, 367200000000000, 388800000000000, 410400000000000, + 432000000000000, 453600000000000, 475200000000000, 496800000000000], + dtype='timedelta64[ns]') Coordinates: - * time (time) timedelta64[ns] 1 days 00:00:00 ... 5 days 18:00:00 - >>> ts.dt - + * time (time) timedelta64[ns] 1 days 00:00:00 ... 5 days 18:00:00 + >>> ts.dt # doctest: +ELLIPSIS + >>> ts.dt.days array([1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5]) Coordinates: - * time (time) timedelta64[ns] 1 days 00:00:00 ... 5 days 18:00:00 + * time (time) timedelta64[ns] 1 days 00:00:00 ... 
5 days 18:00:00 >>> ts.dt.microseconds array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) Coordinates: - * time (time) timedelta64[ns] 1 days 00:00:00 ... 5 days 18:00:00 + * time (time) timedelta64[ns] 1 days 00:00:00 ... 5 days 18:00:00 >>> ts.dt.seconds array([ 0, 21600, 43200, 64800, 0, 21600, 43200, 64800, 0, - 21600, 43200, 64800, 0, 21600, 43200, 64800, 0, 21600, - 43200, 64800]) + 21600, 43200, 64800, 0, 21600, 43200, 64800, 0, 21600, + 43200, 64800]) Coordinates: - * time (time) timedelta64[ns] 1 days 00:00:00 ... 5 days 18:00:00 + * time (time) timedelta64[ns] 1 days 00:00:00 ... 5 days 18:00:00 """ days = Properties._tslib_field_accessor( diff --git a/xarray/core/accessor_str.py b/xarray/core/accessor_str.py index 5502ba72855..02d8ca00bf9 100644 --- a/xarray/core/accessor_str.py +++ b/xarray/core/accessor_str.py @@ -68,7 +68,7 @@ class StringAccessor: for applicable DataArrays. >>> da = xr.DataArray(["some", "text", "in", "an", "array"]) - >>> ds.str.len() + >>> da.str.len() array([4, 4, 2, 2, 5]) Dimensions without coordinates: dim_0 @@ -90,7 +90,7 @@ def _apply(self, f, dtype=None): def len(self): """ - Compute the length of each element in the array. + Compute the length of each string in the array. Returns ------- @@ -104,9 +104,9 @@ def __getitem__(self, key): else: return self.get(key) - def get(self, i): + def get(self, i, default=""): """ - Extract element from indexable in each element in the array. + Extract character number `i` from each string in the array. Parameters ---------- @@ -118,14 +118,20 @@ def get(self, i): Returns ------- - items : array of objects + items : array of object """ - obj = slice(-1, None) if i == -1 else slice(i, i + 1) - return self._apply(lambda x: x[obj]) + s = slice(-1, None) if i == -1 else slice(i, i + 1) + + def f(x): + item = x[s] + + return item if item else default + + return self._apply(f) def slice(self, start=None, stop=None, step=None): """ - Slice substrings from each element in the array. + Slice substrings from each string in the array. Parameters ---------- @@ -338,15 +344,15 @@ def count(self, pat, flags=0): This function is used to count the number of times a particular regex pattern is repeated in each of the string elements of the - :class:`~xarray.DatArray`. + :class:`~xarray.DataArray`. Parameters ---------- pat : str Valid regular expression. - flags : int, default 0, meaning no flags - Flags for the `re` module. For a complete list, `see here - `_. + flags : int, default: 0 + Flags for the `re` module. Use 0 for no flags. For a complete list, + `see here `_. Returns ------- @@ -359,7 +365,7 @@ def count(self, pat, flags=0): def startswith(self, pat): """ - Test if the start of each string element matches a pattern. + Test if the start of each string in the array matches a pattern. Parameters ---------- @@ -378,7 +384,7 @@ def startswith(self, pat): def endswith(self, pat): """ - Test if the end of each string element matches a pattern. + Test if the end of each string in the array matches a pattern. Parameters ---------- @@ -404,9 +410,9 @@ def pad(self, width, side="left", fillchar=" "): width : int Minimum width of resulting string; additional characters will be filled with character defined in `fillchar`. - side : {'left', 'right', 'both'}, default 'left' + side : {"left", "right", "both"}, default: "left" Side from which to fill resulting string. - fillchar : str, default ' ' + fillchar : str, default: " " Additional character for filling, default is whitespace. 
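An illustrative sketch of the string-accessor behaviour touched above, including the new `default` argument added to `get`:

```python
import xarray as xr

da = xr.DataArray(["some", "text", "in", "an", "array"])
print(da.str.len().values)                 # [4 4 2 2 5]
# Character 3 does not exist for "in" and "an", so `default` is returned.
print(da.str.get(3, default="-").values)   # ['e' 't' '-' '-' 'a']
```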
Returns @@ -432,15 +438,14 @@ def pad(self, width, side="left", fillchar=" "): def center(self, width, fillchar=" "): """ - Filling left and right side of strings in the array with an - additional character. + Pad left and right side of each string in the array. Parameters ---------- width : int Minimum width of resulting string; additional characters will be filled with ``fillchar`` - fillchar : str + fillchar : str, default: " " Additional character for filling, default is whitespace Returns @@ -451,15 +456,14 @@ def center(self, width, fillchar=" "): def ljust(self, width, fillchar=" "): """ - Filling right side of strings in the array with an additional - character. + Pad right side of each string in the array. Parameters ---------- width : int Minimum width of resulting string; additional characters will be filled with ``fillchar`` - fillchar : str + fillchar : str, default: " " Additional character for filling, default is whitespace Returns @@ -470,14 +474,14 @@ def ljust(self, width, fillchar=" "): def rjust(self, width, fillchar=" "): """ - Filling left side of strings in the array with an additional character. + Pad left side of each string in the array. Parameters ---------- width : int Minimum width of resulting string; additional characters will be filled with ``fillchar`` - fillchar : str + fillchar : str, default: " " Additional character for filling, default is whitespace Returns @@ -488,11 +492,11 @@ def rjust(self, width, fillchar=" "): def zfill(self, width): """ - Pad strings in the array by prepending '0' characters. + Pad each string in the array by prepending '0' characters. Strings in the array are padded with '0' characters on the left of the string to reach a total string length `width`. Strings - in the array with length greater or equal to `width` are unchanged. + in the array with length greater or equal to `width` are unchanged. Parameters ---------- @@ -508,7 +512,7 @@ def zfill(self, width): def contains(self, pat, case=True, flags=0, regex=True): """ - Test if pattern or regex is contained within a string of the array. + Test if pattern or regex is contained within each string of the array. Return boolean array based on whether a given pattern or regex is contained within a string of the array. @@ -517,11 +521,12 @@ def contains(self, pat, case=True, flags=0, regex=True): ---------- pat : str Character sequence or regular expression. - case : bool, default True + case : bool, default: True If True, case sensitive. - flags : int, default 0 (no flags) + flags : int, default: 0 Flags to pass through to the re module, e.g. re.IGNORECASE. - regex : bool, default True + ``0`` means no flags. + regex : bool, default: True If True, assumes the pat is a regular expression. If False, treats the pat as a literal string. @@ -554,16 +559,16 @@ def contains(self, pat, case=True, flags=0, regex=True): def match(self, pat, case=True, flags=0): """ - Determine if each string matches a regular expression. + Determine if each string in the array matches a regular expression. Parameters ---------- - pat : string + pat : str Character sequence or regular expression - case : boolean, default True + case : bool, default: True If True, case sensitive - flags : int, default 0 (no flags) - re module flags, e.g. re.IGNORECASE + flags : int, default: 0 + re module flags, e.g. re.IGNORECASE. 
``0`` means no flags Returns ------- @@ -586,11 +591,11 @@ def strip(self, to_strip=None, side="both"): Parameters ---------- - to_strip : str or None, default None + to_strip : str or None, default: None Specifying the set of characters to be removed. All combinations of this set of characters will be stripped. If None then whitespaces are removed. - side : {'left', 'right', 'both'}, default 'left' + side : {"left", "right", "both"}, default: "left" Side from which to strip. Returns @@ -613,14 +618,14 @@ def strip(self, to_strip=None, side="both"): def lstrip(self, to_strip=None): """ - Remove leading and trailing characters. + Remove leading characters. Strip whitespaces (including newlines) or a set of specified characters from each string in the array from the left side. Parameters ---------- - to_strip : str or None, default None + to_strip : str or None, default: None Specifying the set of characters to be removed. All combinations of this set of characters will be stripped. If None then whitespaces are removed. @@ -633,14 +638,14 @@ def lstrip(self, to_strip=None): def rstrip(self, to_strip=None): """ - Remove leading and trailing characters. + Remove trailing characters. Strip whitespaces (including newlines) or a set of specified characters from each string in the array from the right side. Parameters ---------- - to_strip : str or None, default None + to_strip : str or None, default: None Specifying the set of characters to be removed. All combinations of this set of characters will be stripped. If None then whitespaces are removed. @@ -653,8 +658,7 @@ def rstrip(self, to_strip=None): def wrap(self, width, **kwargs): """ - Wrap long strings in the array to be formatted in paragraphs with - length less than a given width. + Wrap long strings in the array in paragraphs with length less than `width`. This method has the same keyword parameters and defaults as :class:`textwrap.TextWrapper`. @@ -663,38 +667,20 @@ def wrap(self, width, **kwargs): ---------- width : int Maximum line-width - expand_tabs : bool, optional - If true, tab characters will be expanded to spaces (default: True) - replace_whitespace : bool, optional - If true, each whitespace character (as defined by - string.whitespace) remaining after tab expansion will be replaced - by a single space (default: True) - drop_whitespace : bool, optional - If true, whitespace that, after wrapping, happens to end up at the - beginning or end of a line is dropped (default: True) - break_long_words : bool, optional - If true, then words longer than width will be broken in order to - ensure that no lines are longer than width. If it is false, long - words will not be broken, and some lines may be longer than width. - (default: True) - break_on_hyphens : bool, optional - If true, wrapping will occur preferably on whitespace and right - after hyphens in compound words, as it is customary in English. If - false, only whitespaces will be considered as potentially good - places for line breaks, but you need to set break_long_words to - false if you want truly insecable words. (default: True) + **kwargs + keyword arguments passed into :class:`textwrap.TextWrapper`. Returns ------- wrapped : same type as values """ - tw = textwrap.TextWrapper(width=width) + tw = textwrap.TextWrapper(width=width, **kwargs) f = lambda x: "\n".join(tw.wrap(x)) return self._apply(f) def translate(self, table): """ - Map all characters in the string through the given mapping table. + Map characters of each string through the given mapping table. 
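A brief sketch of the `wrap` fix above, where keyword arguments are now actually forwarded to `textwrap.TextWrapper`:

```python
import xarray as xr

da = xr.DataArray(["a rather long sentence that needs wrapping"])
# Keyword arguments such as break_long_words now reach TextWrapper.
print(da.str.wrap(12, break_long_words=False).values[0])
```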
Parameters ---------- @@ -742,12 +728,12 @@ def find(self, sub, start=0, end=None, side="left"): Left edge index end : int Right edge index - side : {'left', 'right'}, default 'left' + side : {"left", "right"}, default: "left" Starting side for search. Returns ------- - found : array of integer values + found : array of int """ sub = self._obj.dtype.type(sub) @@ -782,7 +768,7 @@ def rfind(self, sub, start=0, end=None): Returns ------- - found : array of integer values + found : array of int """ return self.find(sub, start=start, end=end, side="right") @@ -801,12 +787,12 @@ def index(self, sub, start=0, end=None, side="left"): Left edge index end : int Right edge index - side : {'left', 'right'}, default 'left' + side : {"left", "right"}, default: "left" Starting side for search. Returns ------- - found : array of integer values + found : array of int """ sub = self._obj.dtype.type(sub) @@ -842,7 +828,7 @@ def rindex(self, sub, start=0, end=None): Returns ------- - found : array of integer values + found : array of int """ return self.index(sub, start=start, end=end, side="right") @@ -852,22 +838,22 @@ def replace(self, pat, repl, n=-1, case=None, flags=0, regex=True): Parameters ---------- - pat : string or compiled regex + pat : str or re.Pattern String can be a character sequence or regular expression. - repl : string or callable + repl : str or callable Replacement string or a callable. The callable is passed the regex match object and must return a replacement string to be used. See :func:`re.sub`. - n : int, default -1 (all) - Number of replacements to make from start - case : boolean, default None + n : int, default: -1 + Number of replacements to make from start. Use ``-1`` to replace all. + case : bool, default: None - If True, case sensitive (the default if `pat` is a string) - Set to False for case insensitive - Cannot be set if `pat` is a compiled regex - flags : int, default 0 (no flags) - - re module flags, e.g. re.IGNORECASE + flags : int, default: 0 + - re module flags, e.g. re.IGNORECASE. Use ``0`` for no flags. - Cannot be set if `pat` is a compiled regex - regex : boolean, default True + regex : bool, default: True - If True, assumes the passed-in pattern is a regular expression. - If False, treats the pattern as a literal string - Cannot be set to False if `pat` is a compiled regex or `repl` is @@ -893,7 +879,7 @@ def replace(self, pat, repl, n=-1, case=None, flags=0, regex=True): if is_compiled_re: if (case is not None) or (flags != 0): raise ValueError( - "case and flags cannot be set" " when pat is a compiled regex" + "case and flags cannot be set when pat is a compiled regex" ) else: # not a compiled regex @@ -917,9 +903,7 @@ def replace(self, pat, repl, n=-1, case=None, flags=0, regex=True): "pattern with regex=False" ) if callable(repl): - raise ValueError( - "Cannot use a callable replacement when " "regex=False" - ) + raise ValueError("Cannot use a callable replacement when regex=False") f = lambda x: x.replace(pat, repl, n) return self._apply(f) diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index abc180e049c..debf3aad96a 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -2,26 +2,39 @@ import operator from collections import defaultdict from contextlib import suppress -from typing import TYPE_CHECKING, Any, Dict, Hashable, Mapping, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Hashable, + Mapping, + Optional, + Tuple, + TypeVar, + Union, +) import numpy as np import pandas as pd from . 
import dtypes, utils from .indexing import get_indexer_nd -from .utils import is_dict_like, is_full_slice +from .utils import is_dict_like, is_full_slice, maybe_coerce_to_str from .variable import IndexVariable, Variable if TYPE_CHECKING: + from .common import DataWithCoords from .dataarray import DataArray from .dataset import Dataset + DataAlignable = TypeVar("DataAlignable", bound=DataWithCoords) + def _get_joiner(join): if join == "outer": - return functools.partial(functools.reduce, operator.or_) + return functools.partial(functools.reduce, pd.Index.union) elif join == "inner": - return functools.partial(functools.reduce, operator.and_) + return functools.partial(functools.reduce, pd.Index.intersection) elif join == "left": return operator.itemgetter(0) elif join == "right": @@ -59,13 +72,13 @@ def _override_indexes(objects, all_indexes, exclude): def align( - *objects, + *objects: "DataAlignable", join="inner", copy=True, indexes=None, exclude=frozenset(), fill_value=dtypes.NA, -): +) -> Tuple["DataAlignable", ...]: """ Given any number of Dataset and/or DataArray objects, returns new objects with aligned indexes and dimension sizes. @@ -80,17 +93,17 @@ def align( ---------- *objects : Dataset or DataArray Objects to align. - join : {'outer', 'inner', 'left', 'right', 'exact', 'override'}, optional + join : {"outer", "inner", "left", "right", "exact", "override"}, optional Method for joining the indexes of the passed objects along each dimension: - - 'outer': use the union of object indexes - - 'inner': use the intersection of object indexes - - 'left': use indexes from the first object with each dimension - - 'right': use indexes from the last object with each dimension - - 'exact': instead of aligning, raise `ValueError` when indexes to be + - "outer": use the union of object indexes + - "inner": use the intersection of object indexes + - "left": use indexes from the first object with each dimension + - "right": use indexes from the last object with each dimension + - "exact": instead of aligning, raise `ValueError` when indexes to be aligned are not equal - - 'override': if indexes are of same size, rewrite indexes to be + - "override": if indexes are of same size, rewrite indexes to be those of the first object with that dimension. Indexes for the same dimension must have the same size in all objects. copy : bool, optional @@ -103,13 +116,16 @@ def align( used in preference to the aligned indexes. exclude : sequence of str, optional Dimensions that must be excluded from alignment - fill_value : scalar, optional - Value to use for newly missing values + fill_value : scalar or dict-like, optional + Value to use for newly missing values. If a dict-like, maps + variable names to fill values. Use a data array's name to + refer to its values. Returns ------- - aligned : same as `*objects` - Tuple of objects with aligned coordinates. + aligned : DataArray or Dataset + Tuple of objects with the same type as `*objects` with aligned + coordinates. 
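A sketch of the dict-valued `fill_value` described above, where missing values are filled per variable name (illustrative; the exact result dtypes depend on the fill values chosen):

```python
import xarray as xr

x = xr.DataArray([1, 2], dims="lat", coords={"lat": [35.0, 40.0]}, name="x")
y = xr.DataArray([3, 4], dims="lat", coords={"lat": [40.0, 42.0]}, name="y")

# Each array's name selects its own fill value.
x2, y2 = xr.align(x, y, join="outer", fill_value={"x": -1, "y": -99})
print(x2.values)  # expected: [ 1  2 -1]
print(y2.values)  # expected: [-99   3   4]
```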
Raises ------ @@ -137,30 +153,30 @@ def align( array([[25, 35], [10, 24]]) Coordinates: - * lat (lat) float64 35.0 40.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 35.0 40.0 + * lon (lon) float64 100.0 120.0 >>> y array([[20, 5], [ 7, 13]]) Coordinates: - * lat (lat) float64 35.0 42.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 35.0 42.0 + * lon (lon) float64 100.0 120.0 >>> a, b = xr.align(x, y) >>> a array([[25, 35]]) Coordinates: - * lat (lat) float64 35.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 35.0 + * lon (lon) float64 100.0 120.0 >>> b array([[20, 5]]) Coordinates: - * lat (lat) float64 35.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 35.0 + * lon (lon) float64 100.0 120.0 >>> a, b = xr.align(x, y, join="outer") >>> a @@ -169,16 +185,16 @@ def align( [10., 24.], [nan, nan]]) Coordinates: - * lat (lat) float64 35.0 40.0 42.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 35.0 40.0 42.0 + * lon (lon) float64 100.0 120.0 >>> b array([[20., 5.], [nan, nan], [ 7., 13.]]) Coordinates: - * lat (lat) float64 35.0 40.0 42.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 35.0 40.0 42.0 + * lon (lon) float64 100.0 120.0 >>> a, b = xr.align(x, y, join="outer", fill_value=-999) >>> a @@ -187,16 +203,16 @@ def align( [ 10, 24], [-999, -999]]) Coordinates: - * lat (lat) float64 35.0 40.0 42.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 35.0 40.0 42.0 + * lon (lon) float64 100.0 120.0 >>> b array([[ 20, 5], [-999, -999], [ 7, 13]]) Coordinates: - * lat (lat) float64 35.0 40.0 42.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 35.0 40.0 42.0 + * lon (lon) float64 100.0 120.0 >>> a, b = xr.align(x, y, join="left") >>> a @@ -204,15 +220,15 @@ def align( array([[25, 35], [10, 24]]) Coordinates: - * lat (lat) float64 35.0 40.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 35.0 40.0 + * lon (lon) float64 100.0 120.0 >>> b array([[20., 5.], [nan, nan]]) Coordinates: - * lat (lat) float64 35.0 40.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 35.0 40.0 + * lon (lon) float64 100.0 120.0 >>> a, b = xr.align(x, y, join="right") >>> a @@ -220,15 +236,15 @@ def align( array([[25., 35.], [nan, nan]]) Coordinates: - * lat (lat) float64 35.0 42.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 35.0 42.0 + * lon (lon) float64 100.0 120.0 >>> b array([[20, 5], [ 7, 13]]) Coordinates: - * lat (lat) float64 35.0 42.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 35.0 42.0 + * lon (lon) float64 100.0 120.0 >>> a, b = xr.align(x, y, join="exact") Traceback (most recent call last): @@ -242,15 +258,15 @@ def align( array([[25, 35], [10, 24]]) Coordinates: - * lat (lat) float64 35.0 40.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 35.0 40.0 + * lon (lon) float64 100.0 120.0 >>> b array([[20, 5], [ 7, 13]]) Coordinates: - * lat (lat) float64 35.0 40.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 35.0 40.0 + * lon (lon) float64 100.0 120.0 """ if indexes is None: @@ -262,10 +278,12 @@ def align( return (obj.copy(deep=copy),) all_indexes = defaultdict(list) + all_coords = defaultdict(list) unlabeled_dim_sizes = defaultdict(set) for obj in objects: for dim in obj.dims: if dim not in exclude: + all_coords[dim].append(obj.coords[dim]) try: index = obj.indexes[dim] except KeyError: @@ -290,7 +308,7 @@ def align( any(not index.equals(other) for other in matching_indexes) or dim in unlabeled_dim_sizes ): - joined_indexes[dim] = index + joined_indexes[dim] = indexes[dim] else: if ( 
any( @@ -302,9 +320,11 @@ def align( if join == "exact": raise ValueError(f"indexes along dimension {dim!r} are not equal") index = joiner(matching_indexes) + # make sure str coords are not cast to object + index = maybe_coerce_to_str(index, all_coords[dim]) joined_indexes[dim] = index else: - index = matching_indexes[0] + index = all_coords[dim][0] if dim in unlabeled_dim_sizes: unlabeled_sizes = unlabeled_dim_sizes[dim] @@ -334,7 +354,9 @@ def align( # fast path for no reindexing necessary new_obj = obj.copy(deep=copy) else: - new_obj = obj.reindex(copy=copy, fill_value=fill_value, **valid_indexers) + new_obj = obj.reindex( + copy=copy, fill_value=fill_value, indexers=valid_indexers + ) new_obj.encoding = obj.encoding result.append(new_obj) @@ -565,7 +587,7 @@ def reindex_variables( args: tuple = (var.attrs, var.encoding) else: args = () - reindexed[dim] = IndexVariable((dim,), target, *args) + reindexed[dim] = IndexVariable((dim,), indexers[dim], *args) for dim in sizes: if dim not in indexes and dim in indexers: @@ -580,8 +602,13 @@ def reindex_variables( for name, var in variables.items(): if name not in indexers: + if isinstance(fill_value, dict): + fill_value_ = fill_value.get(name, dtypes.NA) + else: + fill_value_ = fill_value + if sparse: - var = var._as_sparse(fill_value=fill_value) + var = var._as_sparse(fill_value=fill_value_) key = tuple( slice(None) if d in unchanged_dims else int_indexers.get(d, slice(None)) for d in var.dims @@ -589,7 +616,7 @@ def reindex_variables( needs_masking = any(d in masked_dims for d in var.dims) if needs_masking: - new_var = var._getitem_with_mask(key, fill_value=fill_value) + new_var = var._getitem_with_mask(key, fill_value=fill_value_) elif all(is_full_slice(k) for k in key): # no reindexing necessary # here we need to manually deal with copying data, since @@ -664,14 +691,14 @@ def broadcast(*args, exclude=None): Parameters ---------- - *args : DataArray or Dataset objects + *args : DataArray or Dataset Arrays to broadcast against each other. exclude : sequence of str, optional Dimensions that must not be broadcasted Returns ------- - broadcast : tuple of xarray objects + broadcast : tuple of DataArray or tuple of Dataset The same data as the input arrays, but with additional dimensions inserted so that all data arrays have the same dimensions and shape. 
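As a rough illustration of the per-variable `fill_value` handling added in the `reindex_variables` hunk above (a dict-like maps variable names to fills, and unlisted variables fall back to NA), here is a minimal sketch; the dataset, variable names and fill values are invented for the example and are not part of the patch:

```python
import xarray as xr

# Hypothetical two-variable dataset; the dict-valued fill_value is the point.
ds = xr.Dataset(
    {"a": ("x", [1, 2]), "b": ("x", [10, 20])},
    coords={"x": [0, 1]},
)

# Each variable gets its own fill value; variables missing from the dict
# fall back to the usual NA fill.
expanded = ds.reindex(x=[0, 1, 2], fill_value={"a": -1, "b": -999})
print(expanded)
```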
@@ -685,30 +712,24 @@ def broadcast(*args, exclude=None): >>> a array([1, 2, 3]) - Coordinates: - * x (x) int64 0 1 2 + Dimensions without coordinates: x >>> b array([5, 6]) - Coordinates: - * y (y) int64 0 1 + Dimensions without coordinates: y >>> a2, b2 = xr.broadcast(a, b) >>> a2 array([[1, 1], [2, 2], [3, 3]]) - Coordinates: - * x (x) int64 0 1 2 - * y (y) int64 0 1 + Dimensions without coordinates: x, y >>> b2 array([[5, 6], [5, 6], [5, 6]]) - Coordinates: - * y (y) int64 0 1 - * x (x) int64 0 1 2 + Dimensions without coordinates: x, y Fill out the dimensions of all data variables in a dataset: @@ -717,9 +738,7 @@ def broadcast(*args, exclude=None): >>> ds2 Dimensions: (x: 3, y: 2) - Coordinates: - * x (x) int64 0 1 2 - * y (y) int64 0 1 + Dimensions without coordinates: x, y Data variables: a (x, y) int64 1 1 2 2 3 3 b (x, y) int64 5 6 5 6 5 6 diff --git a/xarray/core/arithmetic.py b/xarray/core/arithmetic.py index 571dfbe70ed..8eba0fe7919 100644 --- a/xarray/core/arithmetic.py +++ b/xarray/core/arithmetic.py @@ -3,7 +3,7 @@ import numpy as np -from .options import OPTIONS +from .options import OPTIONS, _get_keep_attrs from .pycompat import dask_array_type from .utils import not_implemented @@ -77,6 +77,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): dataset_fill_value=np.nan, kwargs=kwargs, dask="allowed", + keep_attrs=_get_keep_attrs(default=True), ) # this has no runtime function - these are listed so IDEs know these diff --git a/xarray/core/combine.py b/xarray/core/combine.py index 1f990457798..86ed1870302 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -1,7 +1,5 @@ import itertools -import warnings from collections import Counter -from textwrap import dedent import pandas as pd @@ -95,7 +93,9 @@ def _infer_concat_order_from_coords(datasets): # position indices - they should be concatenated along another # dimension, not along this one series = first_items.to_series() - rank = series.rank(method="dense", ascending=ascending) + rank = series.rank( + method="dense", ascending=ascending, numeric_only=False + ) order = rank.astype(int).values - 1 # Append positions along extra dimension to structure which @@ -364,7 +364,7 @@ def combine_nested( Parameters ---------- - datasets : list or nested list of xarray.Dataset objects. + datasets : list or nested list of Dataset Dataset objects to combine. If concatenation or merging along more than one dimension is desired, then datasets must be supplied in a nested list-of-lists. @@ -377,48 +377,50 @@ def combine_nested( nested-list input along which to merge. Must be the same length as the depth of the list passed to ``datasets``. - compat : {'identical', 'equals', 'broadcast_equals', - 'no_conflicts', 'override'}, optional + compat : {"identical", "equals", "broadcast_equals", \ + "no_conflicts", "override"}, optional String indicating how to compare variables of the same name for potential merge conflicts: - - 'broadcast_equals': all values must be equal when variables are + - "broadcast_equals": all values must be equal when variables are broadcast against each other to ensure common dimensions. - - 'equals': all values and dimensions must be the same. - - 'identical': all values, dimensions and attributes must be the + - "equals": all values and dimensions must be the same. + - "identical": all values, dimensions and attributes must be the same. - - 'no_conflicts': only values which are not null in both datasets + - "no_conflicts": only values which are not null in both datasets must be equal. 
The returned dataset then contains the combination of all non-null values. - - 'override': skip comparing and pick variable from first dataset - data_vars : {'minimal', 'different', 'all' or list of str}, optional + - "override": skip comparing and pick variable from first dataset + data_vars : {"minimal", "different", "all" or list of str}, optional Details are in the documentation of concat - coords : {'minimal', 'different', 'all' or list of str}, optional + coords : {"minimal", "different", "all" or list of str}, optional Details are in the documentation of concat - fill_value : scalar, optional - Value to use for newly missing values - join : {'outer', 'inner', 'left', 'right', 'exact'}, optional + fill_value : scalar or dict-like, optional + Value to use for newly missing values. If a dict-like, maps + variable names to fill values. Use a data array's name to + refer to its values. + join : {"outer", "inner", "left", "right", "exact"}, optional String indicating how to combine differing indexes (excluding concat_dim) in objects - - 'outer': use the union of object indexes - - 'inner': use the intersection of object indexes - - 'left': use indexes from the first object with each dimension - - 'right': use indexes from the last object with each dimension - - 'exact': instead of aligning, raise `ValueError` when indexes to be + - "outer": use the union of object indexes + - "inner": use the intersection of object indexes + - "left": use indexes from the first object with each dimension + - "right": use indexes from the last object with each dimension + - "exact": instead of aligning, raise `ValueError` when indexes to be aligned are not equal - - 'override': if indexes are of same size, rewrite indexes to be + - "override": if indexes are of same size, rewrite indexes to be those of the first object with that dimension. Indexes for the same dimension must have the same size in all objects. - combine_attrs : {'drop', 'identical', 'no_conflicts', 'override'}, - default 'drop' + combine_attrs : {"drop", "identical", "no_conflicts", "override"}, \ + default: "drop" String indicating how to combine attrs of the objects being merged: - - 'drop': empty attrs on returned Dataset. - - 'identical': all attrs must be the same on every object. - - 'no_conflicts': attrs from all objects are combined, any that have + - "drop": empty attrs on returned Dataset. + - "identical": all attrs must be the same on every object. + - "no_conflicts": attrs from all objects are combined, any that have the same name must also have the same value. - - 'override': skip comparing and copy attrs from the first dataset to + - "override": skip comparing and copy attrs from the first dataset to the result. Returns @@ -433,22 +435,48 @@ def combine_nested( into 4 parts, 2 each along both the x and y axes, requires organising the datasets into a doubly-nested list, e.g: + >>> x1y1 = xr.Dataset( + ... { + ... "temperature": (("x", "y"), np.random.randn(2, 2)), + ... "precipitation": (("x", "y"), np.random.randn(2, 2)), + ... } + ... ) >>> x1y1 - Dimensions: (x: 2, y: 2) + Dimensions: (x: 2, y: 2) Dimensions without coordinates: x, y Data variables: - temperature (x, y) float64 11.04 23.57 20.77 ... - precipitation (x, y) float64 5.904 2.453 3.404 ... + temperature (x, y) float64 1.764 0.4002 0.9787 2.241 + precipitation (x, y) float64 1.868 -0.9773 0.9501 -0.1514 + >>> x1y2 = xr.Dataset( + ... { + ... "temperature": (("x", "y"), np.random.randn(2, 2)), + ... "precipitation": (("x", "y"), np.random.randn(2, 2)), + ... 
} + ... ) + >>> x2y1 = xr.Dataset( + ... { + ... "temperature": (("x", "y"), np.random.randn(2, 2)), + ... "precipitation": (("x", "y"), np.random.randn(2, 2)), + ... } + ... ) + >>> x2y2 = xr.Dataset( + ... { + ... "temperature": (("x", "y"), np.random.randn(2, 2)), + ... "precipitation": (("x", "y"), np.random.randn(2, 2)), + ... } + ... ) + >>> ds_grid = [[x1y1, x1y2], [x2y1, x2y2]] >>> combined = xr.combine_nested(ds_grid, concat_dim=["x", "y"]) + >>> combined - Dimensions: (x: 4, y: 4) + Dimensions: (x: 4, y: 4) Dimensions without coordinates: x, y Data variables: - temperature (x, y) float64 11.04 23.57 20.77 ... - precipitation (x, y) float64 5.904 2.453 3.404 ... + temperature (x, y) float64 1.764 0.4002 -0.1032 ... 0.04576 -0.1872 + precipitation (x, y) float64 1.868 -0.9773 0.761 ... -0.7422 0.1549 0.3782 ``manual_combine`` can also be used to explicitly merge datasets with different variables. For example if we have 4 datasets, which are divided @@ -456,34 +484,40 @@ def combine_nested( to ``concat_dim`` to specify the dimension of the nested list over which we wish to use ``merge`` instead of ``concat``: + >>> t1temp = xr.Dataset({"temperature": ("t", np.random.randn(5))}) >>> t1temp - Dimensions: (t: 5) + Dimensions: (t: 5) Dimensions without coordinates: t Data variables: - temperature (t) float64 11.04 23.57 20.77 ... + temperature (t) float64 -0.8878 -1.981 -0.3479 0.1563 1.23 + >>> t1precip = xr.Dataset({"precipitation": ("t", np.random.randn(5))}) >>> t1precip - Dimensions: (t: 5) + Dimensions: (t: 5) Dimensions without coordinates: t Data variables: - precipitation (t) float64 5.904 2.453 3.404 ... + precipitation (t) float64 1.202 -0.3873 -0.3023 -1.049 -1.42 + + >>> t2temp = xr.Dataset({"temperature": ("t", np.random.randn(5))}) + >>> t2precip = xr.Dataset({"precipitation": ("t", np.random.randn(5))}) + >>> ds_grid = [[t1temp, t1precip], [t2temp, t2precip]] >>> combined = xr.combine_nested(ds_grid, concat_dim=["t", None]) + >>> combined - Dimensions: (t: 10) + Dimensions: (t: 10) Dimensions without coordinates: t Data variables: - temperature (t) float64 11.04 23.57 20.77 ... - precipitation (t) float64 5.904 2.453 3.404 ... + temperature (t) float64 -0.8878 -1.981 -0.3479 ... -0.5097 -0.4381 -1.253 + precipitation (t) float64 1.202 -0.3873 -0.3023 ... -0.2127 -0.8955 0.3869 See also -------- concat merge - auto_combine """ if isinstance(concat_dim, (str, DataArray)) or concat_dim is None: concat_dim = [concat_dim] @@ -543,61 +577,63 @@ def combine_by_coords( ---------- datasets : sequence of xarray.Dataset Dataset objects to combine. - compat : {'identical', 'equals', 'broadcast_equals', 'no_conflicts', 'override'}, optional + compat : {"identical", "equals", "broadcast_equals", "no_conflicts", "override"}, optional String indicating how to compare variables of the same name for potential conflicts: - - 'broadcast_equals': all values must be equal when variables are + - "broadcast_equals": all values must be equal when variables are broadcast against each other to ensure common dimensions. - - 'equals': all values and dimensions must be the same. - - 'identical': all values, dimensions and attributes must be the + - "equals": all values and dimensions must be the same. + - "identical": all values, dimensions and attributes must be the same. - - 'no_conflicts': only values which are not null in both datasets + - "no_conflicts": only values which are not null in both datasets must be equal. The returned dataset then contains the combination of all non-null values. 
- - 'override': skip comparing and pick variable from first dataset - data_vars : {'minimal', 'different', 'all' or list of str}, optional + - "override": skip comparing and pick variable from first dataset + data_vars : {"minimal", "different", "all" or list of str}, optional These data variables will be concatenated together: - * 'minimal': Only data variables in which the dimension already + * "minimal": Only data variables in which the dimension already appears are included. - * 'different': Data variables which are not equal (ignoring + * "different": Data variables which are not equal (ignoring attributes) across all datasets are also concatenated (as well as all for which dimension already appears). Beware: this option may load the data payload of data variables into memory if they are not already loaded. - * 'all': All data variables will be concatenated. + * "all": All data variables will be concatenated. * list of str: The listed data variables will be concatenated, in - addition to the 'minimal' data variables. - - If objects are DataArrays, `data_vars` must be 'all'. - coords : {'minimal', 'different', 'all' or list of str}, optional - As per the 'data_vars' kwarg, but for coordinate variables. - fill_value : scalar, optional - Value to use for newly missing values. If None, raises a ValueError if + addition to the "minimal" data variables. + + If objects are DataArrays, `data_vars` must be "all". + coords : {"minimal", "different", "all"} or list of str, optional + As per the "data_vars" kwarg, but for coordinate variables. + fill_value : scalar or dict-like, optional + Value to use for newly missing values. If a dict-like, maps + variable names to fill values. Use a data array's name to + refer to its values. If None, raises a ValueError if the passed Datasets do not create a complete hypercube. - join : {'outer', 'inner', 'left', 'right', 'exact'}, optional + join : {"outer", "inner", "left", "right", "exact"}, optional String indicating how to combine differing indexes (excluding concat_dim) in objects - - 'outer': use the union of object indexes - - 'inner': use the intersection of object indexes - - 'left': use indexes from the first object with each dimension - - 'right': use indexes from the last object with each dimension - - 'exact': instead of aligning, raise `ValueError` when indexes to be + - "outer": use the union of object indexes + - "inner": use the intersection of object indexes + - "left": use indexes from the first object with each dimension + - "right": use indexes from the last object with each dimension + - "exact": instead of aligning, raise `ValueError` when indexes to be aligned are not equal - - 'override': if indexes are of same size, rewrite indexes to be + - "override": if indexes are of same size, rewrite indexes to be those of the first object with that dimension. Indexes for the same dimension must have the same size in all objects. - combine_attrs : {'drop', 'identical', 'no_conflicts', 'override'}, - default 'drop' + combine_attrs : {"drop", "identical", "no_conflicts", "override"}, \ + default: "drop" String indicating how to combine attrs of the objects being merged: - - 'drop': empty attrs on returned Dataset. - - 'identical': all attrs must be the same on every object. - - 'no_conflicts': attrs from all objects are combined, any that have + - "drop": empty attrs on returned Dataset. + - "identical": all attrs must be the same on every object. 
+ - "no_conflicts": attrs from all objects are combined, any that have the same name must also have the same value. - - 'override': skip comparing and copy attrs from the first dataset to + - "override": skip comparing and copy attrs from the first dataset to the result. Returns @@ -646,71 +682,71 @@ def combine_by_coords( Dimensions: (x: 3, y: 2) Coordinates: - * y (y) int64 0 1 - * x (x) int64 10 20 30 + * y (y) int64 0 1 + * x (x) int64 10 20 30 Data variables: - temperature (y, x) float64 1.654 10.63 7.015 2.543 13.93 9.436 - precipitation (y, x) float64 0.2136 0.9974 0.7603 0.4679 0.3115 0.945 + temperature (y, x) float64 10.98 14.3 12.06 10.9 8.473 12.92 + precipitation (y, x) float64 0.4376 0.8918 0.9637 0.3834 0.7917 0.5289 >>> x2 Dimensions: (x: 3, y: 2) Coordinates: - * y (y) int64 2 3 - * x (x) int64 10 20 30 + * y (y) int64 2 3 + * x (x) int64 10 20 30 Data variables: - temperature (y, x) float64 9.341 0.1251 6.269 7.709 8.82 2.316 - precipitation (y, x) float64 0.1728 0.1178 0.03018 0.6509 0.06938 0.3792 + temperature (y, x) float64 11.36 18.51 1.421 1.743 0.4044 16.65 + precipitation (y, x) float64 0.7782 0.87 0.9786 0.7992 0.4615 0.7805 >>> x3 Dimensions: (x: 3, y: 2) Coordinates: - * y (y) int64 2 3 - * x (x) int64 40 50 60 + * y (y) int64 2 3 + * x (x) int64 40 50 60 Data variables: - temperature (y, x) float64 2.789 2.446 6.551 12.46 2.22 15.96 - precipitation (y, x) float64 0.4804 0.1902 0.2457 0.6125 0.4654 0.5953 + temperature (y, x) float64 2.365 12.8 2.867 18.89 10.44 8.293 + precipitation (y, x) float64 0.2646 0.7742 0.4562 0.5684 0.01879 0.6176 >>> xr.combine_by_coords([x2, x1]) Dimensions: (x: 3, y: 4) Coordinates: - * x (x) int64 10 20 30 - * y (y) int64 0 1 2 3 + * y (y) int64 0 1 2 3 + * x (x) int64 10 20 30 Data variables: - temperature (y, x) float64 1.654 10.63 7.015 2.543 ... 7.709 8.82 2.316 - precipitation (y, x) float64 0.2136 0.9974 0.7603 ... 0.6509 0.06938 0.3792 + temperature (y, x) float64 10.98 14.3 12.06 10.9 ... 1.743 0.4044 16.65 + precipitation (y, x) float64 0.4376 0.8918 0.9637 ... 0.7992 0.4615 0.7805 >>> xr.combine_by_coords([x3, x1]) Dimensions: (x: 6, y: 4) Coordinates: - * x (x) int64 10 20 30 40 50 60 - * y (y) int64 0 1 2 3 + * x (x) int64 10 20 30 40 50 60 + * y (y) int64 0 1 2 3 Data variables: - temperature (y, x) float64 1.654 10.63 7.015 nan ... nan 12.46 2.22 15.96 - precipitation (y, x) float64 0.2136 0.9974 0.7603 ... 0.6125 0.4654 0.5953 + temperature (y, x) float64 10.98 14.3 12.06 nan ... nan 18.89 10.44 8.293 + precipitation (y, x) float64 0.4376 0.8918 0.9637 ... 0.5684 0.01879 0.6176 >>> xr.combine_by_coords([x3, x1], join="override") Dimensions: (x: 3, y: 4) Coordinates: - * x (x) int64 10 20 30 - * y (y) int64 0 1 2 3 + * x (x) int64 10 20 30 + * y (y) int64 0 1 2 3 Data variables: - temperature (y, x) float64 1.654 10.63 7.015 2.543 ... 12.46 2.22 15.96 - precipitation (y, x) float64 0.2136 0.9974 0.7603 ... 0.6125 0.4654 0.5953 + temperature (y, x) float64 10.98 14.3 12.06 10.9 ... 18.89 10.44 8.293 + precipitation (y, x) float64 0.4376 0.8918 0.9637 ... 0.5684 0.01879 0.6176 >>> xr.combine_by_coords([x1, x2, x3]) Dimensions: (x: 6, y: 4) Coordinates: - * x (x) int64 10 20 30 40 50 60 - * y (y) int64 0 1 2 3 + * x (x) int64 10 20 30 40 50 60 + * y (y) int64 0 1 2 3 Data variables: - temperature (y, x) float64 1.654 10.63 7.015 nan ... 12.46 2.22 15.96 - precipitation (y, x) float64 0.2136 0.9974 0.7603 ... 0.6125 0.4654 0.5953 + temperature (y, x) float64 10.98 14.3 12.06 nan ... 
18.89 10.44 8.293 + precipitation (y, x) float64 0.4376 0.8918 0.9637 ... 0.5684 0.01879 0.6176 """ # Group by data vars @@ -762,272 +798,3 @@ def combine_by_coords( join=join, combine_attrs=combine_attrs, ) - - -# Everything beyond here is only needed until the deprecation cycle in #2616 -# is completed - - -_CONCAT_DIM_DEFAULT = "__infer_concat_dim__" - - -def auto_combine( - datasets, - concat_dim="_not_supplied", - compat="no_conflicts", - data_vars="all", - coords="different", - fill_value=dtypes.NA, - join="outer", - from_openmfds=False, -): - """ - Attempt to auto-magically combine the given datasets into one. - - This entire function is deprecated in favour of ``combine_nested`` and - ``combine_by_coords``. - - This method attempts to combine a list of datasets into a single entity by - inspecting metadata and using a combination of concat and merge. - It does not concatenate along more than one dimension or sort data under - any circumstances. It does align coordinates, but different variables on - datasets can cause it to fail under some scenarios. In complex cases, you - may need to clean up your data and use ``concat``/``merge`` explicitly. - ``auto_combine`` works well if you have N years of data and M data - variables, and each combination of a distinct time period and set of data - variables is saved its own dataset. - - Parameters - ---------- - datasets : sequence of xarray.Dataset - Dataset objects to merge. - concat_dim : str or DataArray or Index, optional - Dimension along which to concatenate variables, as used by - :py:func:`xarray.concat`. You only need to provide this argument if - the dimension along which you want to concatenate is not a dimension - in the original datasets, e.g., if you want to stack a collection of - 2D arrays along a third dimension. - By default, xarray attempts to infer this argument by examining - component files. Set ``concat_dim=None`` explicitly to disable - concatenation. - compat : {'identical', 'equals', 'broadcast_equals', - 'no_conflicts', 'override'}, optional - String indicating how to compare variables of the same name for - potential conflicts: - - - 'broadcast_equals': all values must be equal when variables are - broadcast against each other to ensure common dimensions. - - 'equals': all values and dimensions must be the same. - - 'identical': all values, dimensions and attributes must be the - same. - - 'no_conflicts': only values which are not null in both datasets - must be equal. The returned dataset then contains the combination - of all non-null values. - - 'override': skip comparing and pick variable from first dataset - data_vars : {'minimal', 'different', 'all' or list of str}, optional - Details are in the documentation of concat - coords : {'minimal', 'different', 'all' o list of str}, optional - Details are in the documentation of concat - fill_value : scalar, optional - Value to use for newly missing values - join : {'outer', 'inner', 'left', 'right', 'exact'}, optional - String indicating how to combine differing indexes - (excluding concat_dim) in objects - - - 'outer': use the union of object indexes - - 'inner': use the intersection of object indexes - - 'left': use indexes from the first object with each dimension - - 'right': use indexes from the last object with each dimension - - 'exact': instead of aligning, raise `ValueError` when indexes to be - aligned are not equal - - 'override': if indexes are of same size, rewrite indexes to be - those of the first object with that dimension. 
Indexes for the same - dimension must have the same size in all objects. - - Returns - ------- - combined : xarray.Dataset - - See also - -------- - concat - Dataset.merge - """ - - if not from_openmfds: - basic_msg = dedent( - """\ - In xarray version 0.15 `auto_combine` will be deprecated. See - http://xarray.pydata.org/en/stable/combining.html#combining-multi""" - ) - warnings.warn(basic_msg, FutureWarning, stacklevel=2) - - if concat_dim == "_not_supplied": - concat_dim = _CONCAT_DIM_DEFAULT - message = "" - else: - message = dedent( - """\ - Also `open_mfdataset` will no longer accept a `concat_dim` argument. - To get equivalent behaviour from now on please use the new - `combine_nested` function instead (or the `combine='nested'` option to - `open_mfdataset`).""" - ) - - if _dimension_coords_exist(datasets): - message += dedent( - """\ - The datasets supplied have global dimension coordinates. You may want - to use the new `combine_by_coords` function (or the - `combine='by_coords'` option to `open_mfdataset`) to order the datasets - before concatenation. Alternatively, to continue concatenating based - on the order the datasets are supplied in future, please use the new - `combine_nested` function (or the `combine='nested'` option to - open_mfdataset).""" - ) - else: - message += dedent( - """\ - The datasets supplied do not have global dimension coordinates. In - future, to continue concatenating without supplying dimension - coordinates, please use the new `combine_nested` function (or the - `combine='nested'` option to open_mfdataset.""" - ) - - if _requires_concat_and_merge(datasets): - manual_dims = [concat_dim].append(None) - message += dedent( - """\ - The datasets supplied require both concatenation and merging. From - xarray version 0.15 this will operation will require either using the - new `combine_nested` function (or the `combine='nested'` option to - open_mfdataset), with a nested list structure such that you can combine - along the dimensions {}. Alternatively if your datasets have global - dimension coordinates then you can use the new `combine_by_coords` - function.""".format( - manual_dims - ) - ) - - warnings.warn(message, FutureWarning, stacklevel=2) - - return _old_auto_combine( - datasets, - concat_dim=concat_dim, - compat=compat, - data_vars=data_vars, - coords=coords, - fill_value=fill_value, - join=join, - ) - - -def _dimension_coords_exist(datasets): - """ - Check if the datasets have consistent global dimension coordinates - which would in future be used by `auto_combine` for concatenation ordering. - """ - - # Group by data vars - sorted_datasets = sorted(datasets, key=vars_as_keys) - grouped_by_vars = itertools.groupby(sorted_datasets, key=vars_as_keys) - - # Simulates performing the multidimensional combine on each group of data - # variables before merging back together - try: - for vars, datasets_with_same_vars in grouped_by_vars: - _infer_concat_order_from_coords(list(datasets_with_same_vars)) - return True - except ValueError: - # ValueError means datasets don't have global dimension coordinates - # Or something else went wrong in trying to determine them - return False - - -def _requires_concat_and_merge(datasets): - """ - Check if the datasets require the use of both xarray.concat and - xarray.merge, which in future might require the user to use - `manual_combine` instead. 
- """ - # Group by data vars - sorted_datasets = sorted(datasets, key=vars_as_keys) - grouped_by_vars = itertools.groupby(sorted_datasets, key=vars_as_keys) - - return len(list(grouped_by_vars)) > 1 - - -def _old_auto_combine( - datasets, - concat_dim=_CONCAT_DIM_DEFAULT, - compat="no_conflicts", - data_vars="all", - coords="different", - fill_value=dtypes.NA, - join="outer", -): - if concat_dim is not None: - dim = None if concat_dim is _CONCAT_DIM_DEFAULT else concat_dim - - sorted_datasets = sorted(datasets, key=vars_as_keys) - grouped = itertools.groupby(sorted_datasets, key=vars_as_keys) - - concatenated = [ - _auto_concat( - list(datasets), - dim=dim, - data_vars=data_vars, - coords=coords, - compat=compat, - fill_value=fill_value, - join=join, - ) - for vars, datasets in grouped - ] - else: - concatenated = datasets - merged = merge(concatenated, compat=compat, fill_value=fill_value, join=join) - return merged - - -def _auto_concat( - datasets, - dim=None, - data_vars="all", - coords="different", - fill_value=dtypes.NA, - join="outer", - compat="no_conflicts", -): - if len(datasets) == 1 and dim is None: - # There is nothing more to combine, so kick out early. - return datasets[0] - else: - if dim is None: - ds0 = datasets[0] - ds1 = datasets[1] - concat_dims = set(ds0.dims) - if ds0.dims != ds1.dims: - dim_tuples = set(ds0.dims.items()) - set(ds1.dims.items()) - concat_dims = {i for i, _ in dim_tuples} - if len(concat_dims) > 1: - concat_dims = {d for d in concat_dims if not ds0[d].equals(ds1[d])} - if len(concat_dims) > 1: - raise ValueError( - "too many different dimensions to " "concatenate: %s" % concat_dims - ) - elif len(concat_dims) == 0: - raise ValueError( - "cannot infer dimension to concatenate: " - "supply the ``concat_dim`` argument " - "explicitly" - ) - (dim,) = concat_dims - return concat( - datasets, - dim=dim, - data_vars=data_vars, - coords=coords, - fill_value=fill_value, - compat=compat, - ) diff --git a/xarray/core/common.py b/xarray/core/common.py index 8f6d57e9f12..283114770cf 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -23,9 +23,9 @@ from .arithmetic import SupportsArithmetic from .npcompat import DTypeLike from .options import OPTIONS, _get_keep_attrs -from .pycompat import dask_array_type +from .pycompat import is_duck_dask_array from .rolling_exp import RollingExp -from .utils import Frozen, either_dict_or_kwargs +from .utils import Frozen, either_dict_or_kwargs, is_scalar # Used as a sentinel value to indicate a all dimensions ALL_DIMS = ... @@ -111,8 +111,7 @@ def wrapped_func(self, dim=None, **kwargs): # type: ignore class AbstractArray(ImplementsArrayReduce): - """Shared base class for DataArray and Variable. 
- """ + """Shared base class for DataArray and Variable.""" __slots__ = () @@ -188,8 +187,7 @@ def sizes(self: Any) -> Mapping[Hashable, int]: class AttrAccessMixin: - """Mixin class that allows getting keys with attribute access - """ + """Mixin class that allows getting keys with attribute access""" __slots__ = () @@ -211,16 +209,14 @@ def __init_subclass__(cls): ) @property - def _attr_sources(self) -> List[Mapping[Hashable, Any]]: - """List of places to look-up items for attribute-style access - """ - return [] + def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]: + """Places to look-up items for attribute-style access""" + yield from () @property - def _item_sources(self) -> List[Mapping[Hashable, Any]]: - """List of places to look-up items for key-autocompletion - """ - return [] + def _item_sources(self) -> Iterable[Mapping[Hashable, Any]]: + """Places to look-up items for key-autocompletion""" + yield from () def __getattr__(self, name: str) -> Any: if name not in {"__dict__", "__setstate__"}: @@ -239,8 +235,7 @@ def __getattr__(self, name: str) -> Any: # runtime before every single assignment. All of this is just temporary until the # FutureWarning can be changed into a hard crash. def _setattr_dict(self, name: str, value: Any) -> None: - """Deprecated third party subclass (see ``__init_subclass__`` above) - """ + """Deprecated third party subclass (see ``__init_subclass__`` above)""" object.__setattr__(self, name, value) if name in self.__dict__: # Custom, non-slotted attr, or improperly assigned variable? @@ -277,26 +272,26 @@ def __dir__(self) -> List[str]: """Provide method name lookup and completion. Only provide 'public' methods. """ - extra_attrs = [ + extra_attrs = set( item - for sublist in self._attr_sources - for item in sublist + for source in self._attr_sources + for item in source if isinstance(item, str) - ] - return sorted(set(dir(type(self)) + extra_attrs)) + ) + return sorted(set(dir(type(self))) | extra_attrs) def _ipython_key_completions_(self) -> List[str]: """Provide method for the key-autocompletions in IPython. See http://ipython.readthedocs.io/en/stable/config/integrating.html#tab-completion For the details. """ - item_lists = [ + items = set( item - for sublist in self._item_sources - for item in sublist + for source in self._item_sources + for item in source if isinstance(item, str) - ] - return list(set(item_lists)) + ) + return list(items) def get_squeeze_dims( @@ -304,8 +299,7 @@ def get_squeeze_dims( dim: Union[Hashable, Iterable[Hashable], None] = None, axis: Union[int, Iterable[int], None] = None, ) -> List[Hashable]: - """Get a list of dimensions to squeeze out. 
- """ + """Get a list of dimensions to squeeze out.""" if dim is not None and axis is not None: raise ValueError("cannot use both parameters `axis` and `dim`") if dim is None and axis is None: @@ -374,16 +368,14 @@ def squeeze( return self.isel(drop=drop, **{d: 0 for d in dims}) def get_index(self, key: Hashable) -> pd.Index: - """Get an index for a dimension, with fall-back to a default RangeIndex - """ + """Get an index for a dimension, with fall-back to a default RangeIndex""" if key not in self.dims: raise KeyError(key) try: return self.indexes[key] except KeyError: - # need to ensure dtype=int64 in case range is empty on Python 2 - return pd.Index(range(self.sizes[key]), name=key, dtype=np.int64) + return pd.Index(range(self.sizes[key]), name=key) def _calc_assign_results( self: C, kwargs: Mapping[Hashable, Union[T, Callable[[C], T]]] @@ -408,7 +400,7 @@ def assign_coords(self, coords=None, **coords_kwargs): the first element the dimension name and the second element the values for this new coordinate. - **coords_kwargs : keyword, value pairs, optional + **coords_kwargs : optional The keyword arguments form of ``coords``. One of ``coords`` or ``coords_kwargs`` must be provided. @@ -423,16 +415,18 @@ def assign_coords(self, coords=None, **coords_kwargs): Convert longitude coordinates from 0-359 to -180-179: >>> da = xr.DataArray( - ... np.random.rand(4), coords=[np.array([358, 359, 0, 1])], dims="lon", + ... np.random.rand(4), + ... coords=[np.array([358, 359, 0, 1])], + ... dims="lon", ... ) >>> da - array([0.28298 , 0.667347, 0.657938, 0.177683]) + array([0.5488135 , 0.71518937, 0.60276338, 0.54488318]) Coordinates: * lon (lon) int64 358 359 0 1 >>> da.assign_coords(lon=(((da.lon + 180) % 360) - 180)) - array([0.28298 , 0.667347, 0.657938, 0.177683]) + array([0.5488135 , 0.71518937, 0.60276338, 0.54488318]) Coordinates: * lon (lon) int64 -2 -1 0 1 @@ -440,23 +434,23 @@ def assign_coords(self, coords=None, **coords_kwargs): >>> da.assign_coords({"lon": (((da.lon + 180) % 360) - 180)}) - array([0.28298 , 0.667347, 0.657938, 0.177683]) + array([0.5488135 , 0.71518937, 0.60276338, 0.54488318]) Coordinates: * lon (lon) int64 -2 -1 0 1 New coordinate can also be attached to an existing dimension: >>> lon_2 = np.array([300, 289, 0, 1]) - >>> da.assign_coords(lon_2=('lon', lon_2)) + >>> da.assign_coords(lon_2=("lon", lon_2)) - array([0.28298 , 0.667347, 0.657938, 0.177683]) + array([0.5488135 , 0.71518937, 0.60276338, 0.54488318]) Coordinates: * lon (lon) int64 358 359 0 1 lon_2 (lon) int64 300 289 0 1 Note that the same result can also be obtained with a dict e.g. - >>> _ = da.assign_coords({"lon_2": ('lon', lon_2)}) + >>> _ = da.assign_coords({"lon_2": ("lon", lon_2)}) Notes ----- @@ -484,8 +478,10 @@ def assign_attrs(self, *args, **kwargs): Parameters ---------- - args : positional arguments passed into ``attrs.update``. - kwargs : keyword arguments passed into ``attrs.update``. + args + positional arguments passed into ``attrs.update``. + kwargs + keyword arguments passed into ``attrs.update``. Returns ------- @@ -513,18 +509,21 @@ def pipe( Parameters ---------- - func : function + func : callable function to apply to this xarray object (Dataset/DataArray). ``args``, and ``kwargs`` are passed into ``func``. Alternatively a ``(callable, data_keyword)`` tuple where ``data_keyword`` is a string indicating the keyword of ``callable`` that expects the xarray object. - args : positional arguments passed into ``func``. - kwargs : a dictionary of keyword arguments passed into ``func``. 
+ args + positional arguments passed into ``func``. + kwargs + a dictionary of keyword arguments passed into ``func``. Returns ------- - object : the return type of ``func``. + object : Any + the return type of ``func``. Notes ----- @@ -532,17 +531,23 @@ def pipe( Use ``.pipe`` when chaining together functions that expect xarray or pandas objects, e.g., instead of writing - >>> f(g(h(ds), arg1=a), arg2=b, arg3=c) + .. code:: python + + f(g(h(ds), arg1=a), arg2=b, arg3=c) You can write - >>> (ds.pipe(h).pipe(g, arg1=a).pipe(f, arg2=b, arg3=c)) + .. code:: python + + (ds.pipe(h).pipe(g, arg1=a).pipe(f, arg2=b, arg3=c)) If you have a function that takes the data as (say) the second argument, pass a tuple indicating which keyword expects the data. For example, suppose ``f`` takes its data as ``arg2``: - >>> (ds.pipe(h).pipe(g, arg1=a).pipe((f, "arg2"), arg1=a, arg3=c)) + .. code:: python + + (ds.pipe(h).pipe(g, arg1=a).pipe((f, "arg2"), arg1=a, arg3=c)) Examples -------- @@ -563,11 +568,11 @@ def pipe( Dimensions: (lat: 2, lon: 2) Coordinates: - * lat (lat) int64 10 20 - * lon (lon) int64 150 160 + * lat (lat) int64 10 20 + * lon (lon) int64 150 160 Data variables: - temperature_c (lat, lon) float64 14.53 11.85 19.27 16.37 - precipitation (lat, lon) float64 0.7315 0.7189 0.8481 0.4671 + temperature_c (lat, lon) float64 10.98 14.3 12.06 10.9 + precipitation (lat, lon) float64 0.4237 0.6459 0.4376 0.8918 >>> def adder(data, arg): ... return data + arg @@ -582,21 +587,21 @@ def pipe( Dimensions: (lat: 2, lon: 2) Coordinates: - * lon (lon) int64 150 160 - * lat (lat) int64 10 20 + * lat (lat) int64 10 20 + * lon (lon) int64 150 160 Data variables: - temperature_c (lat, lon) float64 16.53 13.85 21.27 18.37 - precipitation (lat, lon) float64 2.731 2.719 2.848 2.467 + temperature_c (lat, lon) float64 12.98 16.3 14.06 12.9 + precipitation (lat, lon) float64 2.424 2.646 2.438 2.892 >>> x.pipe(adder, arg=2) Dimensions: (lat: 2, lon: 2) Coordinates: - * lon (lon) int64 150 160 - * lat (lat) int64 10 20 + * lat (lat) int64 10 20 + * lon (lon) int64 150 160 Data variables: - temperature_c (lat, lon) float64 16.53 13.85 21.27 18.37 - precipitation (lat, lon) float64 2.731 2.719 2.848 2.467 + temperature_c (lat, lon) float64 12.98 16.3 14.06 12.9 + precipitation (lat, lon) float64 2.424 2.646 2.438 2.892 >>> ( ... x.pipe(adder, arg=2) @@ -606,11 +611,11 @@ def pipe( Dimensions: (lat: 2, lon: 2) Coordinates: - * lon (lon) int64 150 160 - * lat (lat) int64 10 20 + * lat (lat) int64 10 20 + * lon (lon) int64 150 160 Data variables: - temperature_c (lat, lon) float64 14.53 11.85 19.27 16.37 - precipitation (lat, lon) float64 0.7315 0.7189 0.8481 0.4671 + temperature_c (lat, lon) float64 10.98 14.3 12.06 10.9 + precipitation (lat, lon) float64 0.4237 0.6459 0.4376 0.8918 See Also -------- @@ -620,7 +625,7 @@ def pipe( func, target = func if target in kwargs: raise ValueError( - "%s is both the pipe target and a keyword " "argument" % target + "%s is both the pipe target and a keyword argument" % target ) kwargs[target] = self return func(*args, **kwargs) @@ -635,7 +640,7 @@ def groupby(self, group, squeeze: bool = True, restore_coord_dims: bool = None): group : str, DataArray or IndexVariable Array whose unique values should be used to group this array. If a string, must be the name of a variable contained in this dataset. 
- squeeze : boolean, optional + squeeze : bool, optional If "group" is a dimension of any arrays in this dataset, `squeeze` controls whether the subarrays have a dimension of length 1 along that dimension or if the dimension is squeezed out. @@ -645,7 +650,7 @@ def groupby(self, group, squeeze: bool = True, restore_coord_dims: bool = None): Returns ------- - grouped : GroupBy + grouped A `GroupBy` object patterned after `pandas.GroupBy` that can be iterated over in the form of `(unique_value, grouped_array)` pairs. @@ -660,15 +665,16 @@ def groupby(self, group, squeeze: bool = True, restore_coord_dims: bool = None): ... ) >>> da - array([0.000e+00, 1.000e+00, 2.000e+00, ..., 1.824e+03, 1.825e+03, 1.826e+03]) + array([0.000e+00, 1.000e+00, 2.000e+00, ..., 1.824e+03, 1.825e+03, + 1.826e+03]) Coordinates: - * time (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03 ... + * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2004-12-31 >>> da.groupby("time.dayofyear") - da.groupby("time.dayofyear").mean("time") array([-730.8, -730.8, -730.8, ..., 730.2, 730.2, 730.5]) Coordinates: - * time (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03 ... - dayofyear (time) int64 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 ... + * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2004-12-31 + dayofyear (time) int64 1 2 3 4 5 6 7 8 ... 359 360 361 362 363 364 365 366 See Also -------- @@ -711,17 +717,17 @@ def groupby_bins( group : str, DataArray or IndexVariable Array whose binned values should be used to group this array. If a string, must be the name of a variable contained in this dataset. - bins : int or array of scalars + bins : int or array-like If bins is an int, it defines the number of equal-width bins in the range of x. However, in this case, the range of x is extended by .1% on each side to include the min or max values of x. If bins is a sequence it defines the bin edges allowing for non-uniform bin width. No extension of the range of x is done in this case. - right : boolean, optional + right : bool, default: True Indicates whether the bins include the rightmost edge or not. If right == True (the default), then the bins [1,2,3,4] indicate (1,2], (2,3], (3,4]. - labels : array or boolean, default None + labels : array-like or bool, default: None Used as labels for the resulting bins. Must be of the same length as the resulting bins. If False, string bin labels are assigned by `pandas.cut`. @@ -729,7 +735,7 @@ def groupby_bins( The precision at which to store and display the bins labels. include_lowest : bool Whether the first interval should be left-inclusive or not. - squeeze : boolean, optional + squeeze : bool, default: True If "group" is a dimension of any arrays in this dataset, `squeeze` controls whether the subarrays have a dimension of length 1 along that dimension or if the dimension is squeezed out. @@ -739,7 +745,7 @@ def groupby_bins( Returns ------- - grouped : GroupBy + grouped A `GroupBy` object patterned after `pandas.GroupBy` that can be iterated over in the form of `(unique_value, grouped_array)` pairs. The name of the group has the added suffix `_bins` in order to @@ -786,7 +792,7 @@ def rolling( self, dim: Mapping[Hashable, int] = None, min_periods: int = None, - center: bool = False, + center: Union[bool, Mapping[Hashable, bool]] = False, keep_attrs: bool = None, **window_kwargs: int, ): @@ -798,24 +804,21 @@ def rolling( dim: dict, optional Mapping from the dimension name to create the rolling iterator along (e.g. `time`) to its moving window size. 
- min_periods : int, default None + min_periods : int, default: None Minimum number of observations in window required to have a value (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. - center : boolean, default False + center : bool or mapping, default: False Set the labels at the center of the window. - keep_attrs : bool, optional - If True, the object's attributes (`attrs`) will be copied from - the original object to the new one. If False (default), the new - object will be returned without attributes. **window_kwargs : optional The keyword arguments form of ``dim``. One of dim or window_kwargs must be provided. Returns ------- - Rolling object (core.rolling.DataArrayRolling for DataArray, - core.rolling.DatasetRolling for Dataset.) + core.rolling.DataArrayRolling or core.rolling.DatasetRolling + A rolling object (``DataArrayRolling`` for ``DataArray``, + ``DatasetRolling`` for ``Dataset``) Examples -------- @@ -825,21 +828,23 @@ def rolling( ... np.linspace(0, 11, num=12), ... coords=[ ... pd.date_range( - ... "15/12/1999", periods=12, freq=pd.DateOffset(months=1), + ... "15/12/1999", + ... periods=12, + ... freq=pd.DateOffset(months=1), ... ) ... ], ... dims="time", ... ) >>> da - array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.]) + array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.]) Coordinates: - * time (time) datetime64[ns] 1999-12-15 2000-01-15 2000-02-15 ... + * time (time) datetime64[ns] 1999-12-15 2000-01-15 ... 2000-11-15 >>> da.rolling(time=3, center=True).mean() array([nan, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., nan]) Coordinates: - * time (time) datetime64[ns] 1999-12-15 2000-01-15 2000-02-15 ... + * time (time) datetime64[ns] 1999-12-15 2000-01-15 ... 2000-11-15 Remove the NaNs using ``dropna()``: @@ -847,15 +852,13 @@ def rolling( array([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.]) Coordinates: - * time (time) datetime64[ns] 2000-01-15 2000-02-15 2000-03-15 ... + * time (time) datetime64[ns] 2000-01-15 2000-02-15 ... 2000-10-15 See Also -------- core.rolling.DataArrayRolling core.rolling.DatasetRolling """ - if keep_attrs is None: - keep_attrs = _get_keep_attrs(default=False) dim = either_dict_or_kwargs(dim, window_kwargs, "rolling") return self._rolling_cls( @@ -876,20 +879,13 @@ def rolling_exp( Parameters ---------- - window : A single mapping from a dimension name to window value, - optional - - dim : str - Name of the dimension to create the rolling exponential window - along (e.g., `time`). - window : int - Size of the moving window. The type of this is specified in - `window_type` - window_type : str, one of ['span', 'com', 'halflife', 'alpha'], - default 'span' + window : mapping of hashable to int, optional + A mapping from the name of the dimension to create the rolling + exponential window along (e.g. `time`) to the size of the moving window. + window_type : {"span", "com", "halflife", "alpha"}, default: "span" The format of the previously supplied window. Each is a simple numerical transformation of the others. Described in detail: - https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.ewm.html + https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.ewm.html **window_kwargs : optional The keyword arguments form of ``window``. One of window or window_kwargs must be provided. 
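A short sketch of the per-dimension `center` mapping that the new `rolling` signature above allows, assuming an xarray build with this patch (and its multi-dimensional rolling support); the array is arbitrary:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(12.0).reshape(3, 4), dims=("x", "y"))

# Roll over both dimensions, centring the window along "x" only.
rolled = da.rolling(x=3, y=2, center={"x": True, "y": False}).mean()
print(rolled.dims, rolled.shape)  # ('x', 'y') (3, 4)
```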
@@ -916,20 +912,15 @@ def coarsen( Parameters ---------- - dim: dict, optional + dim : mapping of hashable to int, optional Mapping from the dimension name to the window size. - - dim : str - Name of the dimension to create the rolling iterator - along (e.g., `time`). - window : int - Size of the moving window. - boundary : 'exact' | 'trim' | 'pad' + boundary : {"exact", "trim", "pad"}, default: "exact" If 'exact', a ValueError will be raised if dimension size is not a multiple of the window size. If 'trim', the excess entries are dropped. If 'pad', NA will be padded. - side : 'left' or 'right' or mapping from dimension to 'left' or 'right' - coord_func : function (name) that is applied to the coordinates, + side : {"left", "right"} or mapping of str to {"left", "right"} + coord_func : str or mapping of hashable to str, default: "mean" + function (name) that is applied to the coordinates, or a mapping from coordinate name to function (name). keep_attrs : bool, optional If True, the object's attributes (`attrs`) will be copied from @@ -938,8 +929,9 @@ def coarsen( Returns ------- - Coarsen object (core.rolling.DataArrayCoarsen for DataArray, - core.rolling.DatasetCoarsen for Dataset.) + core.rolling.DataArrayCoarsen or core.rolling.DatasetCoarsen + A coarsen object (``DataArrayCoarsen`` for ``DataArray``, + ``DatasetCoarsen`` for ``Dataset``) Examples -------- @@ -950,17 +942,24 @@ def coarsen( ... dims="time", ... coords={"time": pd.date_range("15/12/1999", periods=364)}, ... ) - >>> da + >>> da # +doctest: ELLIPSIS - array([ 0. , 1.002755, 2.00551 , ..., 361.99449 , 362.997245, - 364. ]) + array([ 0. , 1.00275482, 2.00550964, 3.00826446, + 4.01101928, 5.0137741 , 6.01652893, 7.01928375, + 8.02203857, 9.02479339, 10.02754821, 11.03030303, + ... + 356.98071625, 357.98347107, 358.9862259 , 359.98898072, + 360.99173554, 361.99449036, 362.99724518, 364. ]) Coordinates: * time (time) datetime64[ns] 1999-12-15 1999-12-16 ... 2000-12-12 - >>> - >>> da.coarsen(time=3, boundary="trim").mean() + >>> da.coarsen(time=3, boundary="trim").mean() # +doctest: ELLIPSIS - array([ 1.002755, 4.011019, 7.019284, ..., 358.986226, - 361.99449 ]) + array([ 1.00275482, 4.01101928, 7.01928375, 10.02754821, + 13.03581267, 16.04407713, 19.0523416 , 22.06060606, + 25.06887052, 28.07713499, 31.08539945, 34.09366391, + ... + 349.96143251, 352.96969697, 355.97796143, 358.9862259 , + 361.99449036]) Coordinates: * time (time) datetime64[ns] 1999-12-16 1999-12-19 ... 2000-12-10 >>> @@ -1009,13 +1008,13 @@ def resample( dimension must be datetime-like. skipna : bool, optional Whether to skip missing values when aggregating in downsampling. - closed : 'left' or 'right', optional + closed : {"left", "right"}, optional Side of each interval to treat as closed. - label : 'left or 'right', optional + label : {"left", "right"}, optional Side of each interval to use for labeling. base : int, optional For frequencies that evenly subdivide 1 day, the "origin" of the - aggregated intervals. For example, for '24H' frequency, base could + aggregated intervals. For example, for "24H" frequency, base could range from 0 through 23. loffset : timedelta or str, optional Offset used to adjust the resampled time labels. Some pandas date @@ -1044,16 +1043,18 @@ def resample( ... np.linspace(0, 11, num=12), ... coords=[ ... pd.date_range( - ... "15/12/1999", periods=12, freq=pd.DateOffset(months=1), + ... "15/12/1999", + ... periods=12, + ... freq=pd.DateOffset(months=1), ... ) ... ], ... dims="time", ... 
) >>> da - array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.]) + array([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.]) Coordinates: - * time (time) datetime64[ns] 1999-12-15 2000-01-15 2000-02-15 ... + * time (time) datetime64[ns] 1999-12-15 2000-01-15 ... 2000-11-15 >>> da.resample(time="QS-DEC").mean() array([ 1., 4., 7., 10.]) @@ -1062,11 +1063,16 @@ def resample( Upsample monthly time-series data to daily data: - >>> da.resample(time="1D").interpolate("linear") + >>> da.resample(time="1D").interpolate("linear") # +doctest: ELLIPSIS - array([ 0. , 0.032258, 0.064516, ..., 10.935484, 10.967742, 11. ]) + array([ 0. , 0.03225806, 0.06451613, 0.09677419, 0.12903226, + 0.16129032, 0.19354839, 0.22580645, 0.25806452, 0.29032258, + 0.32258065, 0.35483871, 0.38709677, 0.41935484, 0.4516129 , + ... + 10.80645161, 10.83870968, 10.87096774, 10.90322581, 10.93548387, + 10.96774194, 11. ]) Coordinates: - * time (time) datetime64[ns] 1999-12-15 1999-12-16 1999-12-17 ... + * time (time) datetime64[ns] 1999-12-15 1999-12-16 ... 2000-11-15 Limit scope of upsampling method @@ -1088,9 +1094,9 @@ def resample( """ # TODO support non-string indexer after removing the old API. + from ..coding.cftimeindex import CFTimeIndex from .dataarray import DataArray from .resample import RESAMPLE_DIM - from ..coding.cftimeindex import CFTimeIndex if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) @@ -1115,14 +1121,22 @@ def resample( dim_name = dim dim_coord = self[dim] - if isinstance(self.indexes[dim_name], CFTimeIndex): - from .resample_cftime import CFTimeGrouper - - grouper = CFTimeGrouper(freq, closed, label, base, loffset) - else: - grouper = pd.Grouper( - freq=freq, closed=closed, label=label, base=base, loffset=loffset + # TODO: remove once pandas=1.1 is the minimum required version + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + r"'(base|loffset)' in .resample\(\) and in Grouper\(\) is deprecated.", + category=FutureWarning, ) + + if isinstance(self.indexes[dim_name], CFTimeIndex): + from .resample_cftime import CFTimeGrouper + + grouper = CFTimeGrouper(freq, closed, label, base, loffset) + else: + grouper = pd.Grouper( + freq=freq, closed=closed, label=label, base=base, loffset=loffset + ) group = DataArray( dim_coord, coords=dim_coord.coords, dims=dim_coord.dims, name=RESAMPLE_DIM ) @@ -1145,19 +1159,21 @@ def where(self, cond, other=dtypes.NA, drop: bool = False): Parameters ---------- - cond : DataArray or Dataset with boolean dtype - Locations at which to preserve this object's values. + cond : DataArray, Dataset, or callable + Locations at which to preserve this object's values. dtype must be `bool`. + If a callable, it must expect this object as its only parameter. other : scalar, DataArray or Dataset, optional Value to use for locations in this object where ``cond`` is False. By default, these locations filled with NA. - drop : boolean, optional + drop : bool, optional If True, coordinate labels that only correspond to False values of the condition are dropped from the result. Mutually exclusive with ``other``. Returns ------- - Same xarray type as caller, with dtype float64. + DataArray or Dataset + Same xarray type as caller, with dtype float64. 
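The `warnings.catch_warnings` block above exists because pandas >= 1.1 deprecates `base` and `loffset` in `resample()`/`Grouper()`, while xarray still forwards them. A hedged sketch of user code that exercises both keywords (the times, frequency and offsets are illustrative only):

```python
import numpy as np
import pandas as pd
import xarray as xr

times = pd.date_range("2000-01-01", periods=48, freq="H")
da = xr.DataArray(np.arange(48.0), coords={"time": times}, dims="time")

# base shifts the origin of the 24H bins; loffset shifts the resulting
# labels. Both are forwarded to pandas with the FutureWarning silenced.
daily = da.resample(time="24H", base=6, loffset="6H").mean()
print(daily["time"].values)
```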
Examples -------- @@ -1167,19 +1183,19 @@ def where(self, cond, other=dtypes.NA, drop: bool = False): >>> a array([[ 0, 1, 2, 3, 4], - [ 5, 6, 7, 8, 9], - [10, 11, 12, 13, 14], - [15, 16, 17, 18, 19], - [20, 21, 22, 23, 24]]) + [ 5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24]]) Dimensions without coordinates: x, y >>> a.where(a.x + a.y < 4) - array([[ 0., 1., 2., 3., nan], - [ 5., 6., 7., nan, nan], - [ 10., 11., nan, nan, nan], - [ 15., nan, nan, nan, nan], - [ nan, nan, nan, nan, nan]]) + array([[ 0., 1., 2., 3., nan], + [ 5., 6., 7., nan, nan], + [10., 11., nan, nan, nan], + [15., nan, nan, nan, nan], + [nan, nan, nan, nan, nan]]) Dimensions without coordinates: x, y >>> a.where(a.x + a.y < 5, -1) @@ -1193,18 +1209,18 @@ def where(self, cond, other=dtypes.NA, drop: bool = False): >>> a.where(a.x + a.y < 4, drop=True) - array([[ 0., 1., 2., 3.], - [ 5., 6., 7., nan], - [ 10., 11., nan, nan], - [ 15., nan, nan, nan]]) + array([[ 0., 1., 2., 3.], + [ 5., 6., 7., nan], + [10., 11., nan, nan], + [15., nan, nan, nan]]) Dimensions without coordinates: x, y >>> a.where(lambda x: x.x + x.y < 4, drop=True) - array([[ 0., 1., 2., 3.], - [ 5., 6., 7., nan], - [ 10., 11., nan, nan], - [ 15., nan, nan, nan]]) + array([[ 0., 1., 2., 3.], + [ 5., 6., 7., nan], + [10., 11., nan, nan], + [15., nan, nan, nan]]) Dimensions without coordinates: x, y See also @@ -1248,12 +1264,83 @@ def where(self, cond, other=dtypes.NA, drop: bool = False): return ops.where_method(self, cond, other) def close(self: Any) -> None: - """Close any files linked to this object - """ + """Close any files linked to this object""" if self._file_obj is not None: self._file_obj.close() self._file_obj = None + def isnull(self, keep_attrs: bool = None): + """Test each value in the array for whether it is a missing value. + + Returns + ------- + isnull : DataArray or Dataset + Same type and shape as object, but the dtype of the data is bool. + + See Also + -------- + pandas.isnull + + Examples + -------- + >>> array = xr.DataArray([1, np.nan, 3], dims="x") + >>> array + + array([ 1., nan, 3.]) + Dimensions without coordinates: x + >>> array.isnull() + + array([False, True, False]) + Dimensions without coordinates: x + """ + from .computation import apply_ufunc + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + + return apply_ufunc( + duck_array_ops.isnull, + self, + dask="allowed", + keep_attrs=keep_attrs, + ) + + def notnull(self, keep_attrs: bool = None): + """Test each value in the array for whether it is not a missing value. + + Returns + ------- + notnull : DataArray or Dataset + Same type and shape as object, but the dtype of the data is bool. + + See Also + -------- + pandas.notnull + + Examples + -------- + >>> array = xr.DataArray([1, np.nan, 3], dims="x") + >>> array + + array([ 1., nan, 3.]) + Dimensions without coordinates: x + >>> array.notnull() + + array([ True, False, True]) + Dimensions without coordinates: x + """ + from .computation import apply_ufunc + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + + return apply_ufunc( + duck_array_ops.notnull, + self, + dask="allowed", + keep_attrs=keep_attrs, + ) + def isin(self, test_elements): """Tests each value in the array for whether it is in test elements. @@ -1266,8 +1353,8 @@ def isin(self, test_elements): Returns ------- - isin : same as object, bool - Has the same shape as this object. + isin : DataArray or Dataset + Has the same type and shape as this object, but with a bool dtype. 
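A small usage sketch of the `isnull`/`notnull` methods added in this hunk, mainly to show the new `keep_attrs` argument; the array and its attribute are made up for illustration:

```python
import numpy as np
import xarray as xr

da = xr.DataArray([1.0, np.nan, 3.0], dims="x", attrs={"units": "m"})

mask = da.isnull(keep_attrs=True)  # boolean mask, attrs carried over
print(mask.values, mask.attrs)     # [False  True False] {'units': 'm'}

n_valid = int(da.notnull().sum())  # count of non-missing values
print(n_valid)                     # 2
```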
Examples -------- @@ -1283,8 +1370,8 @@ def isin(self, test_elements): numpy.isin """ from .computation import apply_ufunc - from .dataset import Dataset from .dataarray import DataArray + from .dataset import Dataset from .variable import Variable if isinstance(test_elements, Dataset): @@ -1305,6 +1392,83 @@ def isin(self, test_elements): dask="allowed", ) + def astype( + self: T, + dtype, + *, + order=None, + casting=None, + subok=None, + copy=None, + keep_attrs=True, + ) -> T: + """ + Copy of the xarray object, with data cast to a specified type. + Leaves coordinate dtype unchanged. + + Parameters + ---------- + dtype : str or dtype + Typecode or data-type to which the array is cast. + order : {'C', 'F', 'A', 'K'}, optional + Controls the memory layout order of the result. ‘C’ means C order, + ‘F’ means Fortran order, ‘A’ means ‘F’ order if all the arrays are + Fortran contiguous, ‘C’ order otherwise, and ‘K’ means as close to + the order the array elements appear in memory as possible. + casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional + Controls what kind of data casting may occur. + + * 'no' means the data types should not be cast at all. + * 'equiv' means only byte-order changes are allowed. + * 'safe' means only casts which can preserve values are allowed. + * 'same_kind' means only safe casts or casts within a kind, + like float64 to float32, are allowed. + * 'unsafe' means any data conversions may be done. + + subok : bool, optional + If True, then sub-classes will be passed-through, otherwise the + returned array will be forced to be a base-class array. + copy : bool, optional + By default, astype always returns a newly allocated array. If this + is set to False and the `dtype` requirement is satisfied, the input + array is returned instead of a copy. + keep_attrs : bool, optional + By default, astype keeps attributes. Set to False to remove + attributes in the returned object. + + Returns + ------- + out : same as object + New object with data cast to the specified type. + + Notes + ----- + The ``order``, ``casting``, ``subok`` and ``copy`` arguments are only passed + through to the ``astype`` method of the underlying array when a value + different than ``None`` is supplied. + Make sure to only supply these arguments if the underlying array class + supports them. + + See also + -------- + numpy.ndarray.astype + dask.array.Array.astype + sparse.COO.astype + """ + from .computation import apply_ufunc + + kwargs = dict(order=order, casting=casting, subok=subok, copy=copy) + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + return apply_ufunc( + duck_array_ops.astype, + self, + dtype, + kwargs=kwargs, + keep_attrs=keep_attrs, + dask="allowed", + ) + def __enter__(self: T) -> T: return self @@ -1321,12 +1485,15 @@ def full_like(other, fill_value, dtype: DTypeLike = None): Parameters ---------- - other : DataArray, Dataset, or Variable + other : DataArray, Dataset or Variable The reference object in input - fill_value : scalar - Value to fill the new object with before returning it. - dtype : dtype, optional - dtype of the new array. If omitted, it defaults to other.dtype. + fill_value : scalar or dict-like + Value to fill the new object with before returning it. If + other is a Dataset, may also be a dict-like mapping data + variables to fill values. + dtype : dtype or dict-like of dtype, optional + dtype of the new array. If a dict-like, maps dtypes to + variables. If omitted, it defaults to other.dtype. 
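A brief sketch of the `astype` wrapper added above: it keeps attrs by default and only forwards `order`/`casting`/`subok`/`copy` to the underlying array when they are not None; the data and attribute here are illustrative:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.array([1, 2, 3]), dims="x", attrs={"units": "count"})

as_float = da.astype(float)                   # attrs kept by default
print(as_float.dtype, as_float.attrs)         # float64 {'units': 'count'}

bare = da.astype(float, keep_attrs=False)     # drop attrs explicitly
print(bare.attrs)                             # {}

safe = da.astype(np.float64, casting="safe")  # casting forwarded (not None)
```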
Returns ------- @@ -1351,40 +1518,68 @@ def full_like(other, fill_value, dtype: DTypeLike = None): array([[0, 1, 2], [3, 4, 5]]) Coordinates: - * lat (lat) int64 1 2 - * lon (lon) int64 0 1 2 + * lat (lat) int64 1 2 + * lon (lon) int64 0 1 2 >>> xr.full_like(x, 1) array([[1, 1, 1], [1, 1, 1]]) Coordinates: - * lat (lat) int64 1 2 - * lon (lon) int64 0 1 2 + * lat (lat) int64 1 2 + * lon (lon) int64 0 1 2 >>> xr.full_like(x, 0.5) array([[0, 0, 0], [0, 0, 0]]) Coordinates: - * lat (lat) int64 1 2 - * lon (lon) int64 0 1 2 + * lat (lat) int64 1 2 + * lon (lon) int64 0 1 2 >>> xr.full_like(x, 0.5, dtype=np.double) array([[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]) Coordinates: - * lat (lat) int64 1 2 - * lon (lon) int64 0 1 2 + * lat (lat) int64 1 2 + * lon (lon) int64 0 1 2 >>> xr.full_like(x, np.nan, dtype=np.double) array([[nan, nan, nan], [nan, nan, nan]]) Coordinates: - * lat (lat) int64 1 2 - * lon (lon) int64 0 1 2 + * lat (lat) int64 1 2 + * lon (lon) int64 0 1 2 + + >>> ds = xr.Dataset( + ... {"a": ("x", [3, 5, 2]), "b": ("x", [9, 1, 0])}, coords={"x": [2, 4, 6]} + ... ) + >>> ds + + Dimensions: (x: 3) + Coordinates: + * x (x) int64 2 4 6 + Data variables: + a (x) int64 3 5 2 + b (x) int64 9 1 0 + >>> xr.full_like(ds, fill_value={"a": 1, "b": 2}) + + Dimensions: (x: 3) + Coordinates: + * x (x) int64 2 4 6 + Data variables: + a (x) int64 1 1 1 + b (x) int64 2 2 2 + >>> xr.full_like(ds, fill_value={"a": 1, "b": 2}, dtype={"a": bool, "b": float}) + + Dimensions: (x: 3) + Coordinates: + * x (x) int64 2 4 6 + Data variables: + a (x) bool True True True + b (x) float64 2.0 2.0 2.0 See also -------- @@ -1397,9 +1592,22 @@ def full_like(other, fill_value, dtype: DTypeLike = None): from .dataset import Dataset from .variable import Variable + if not is_scalar(fill_value) and not ( + isinstance(other, Dataset) and isinstance(fill_value, dict) + ): + raise ValueError( + f"fill_value must be scalar or, for datasets, a dict-like. Received {fill_value} instead." + ) + if isinstance(other, Dataset): + if not isinstance(fill_value, dict): + fill_value = {k: fill_value for k in other.data_vars.keys()} + + if not isinstance(dtype, dict): + dtype = {k: dtype for k in other.data_vars.keys()} + data_vars = { - k: _full_like_variable(v, fill_value, dtype) + k: _full_like_variable(v, fill_value.get(k, dtypes.NA), dtype.get(k, None)) for k, v in other.data_vars.items() } return Dataset(data_vars, coords=other.coords, attrs=other.attrs) @@ -1418,11 +1626,13 @@ def full_like(other, fill_value, dtype: DTypeLike = None): def _full_like_variable(other, fill_value, dtype: DTypeLike = None): - """Inner function of full_like, where other must be a variable - """ + """Inner function of full_like, where other must be a variable""" from .variable import Variable - if isinstance(other.data, dask_array_type): + if fill_value is dtypes.NA: + fill_value = dtypes.get_fill_value(dtype if dtype is not None else other.dtype) + + if is_duck_dask_array(other.data): import dask.array if dtype is None: @@ -1431,7 +1641,7 @@ def _full_like_variable(other, fill_value, dtype: DTypeLike = None): other.shape, fill_value, dtype=dtype, chunks=other.data.chunks ) else: - data = np.full_like(other, fill_value, dtype=dtype) + data = np.full_like(other.data, fill_value, dtype=dtype) return Variable(dims=other.dims, data=data, attrs=other.attrs) @@ -1442,14 +1652,14 @@ def zeros_like(other, dtype: DTypeLike = None): Parameters ---------- - other : DataArray, Dataset, or Variable + other : DataArray, Dataset or Variable The reference object. 
The output will have the same dimensions and coordinates as this object. dtype : dtype, optional dtype of the new array. If omitted, it defaults to other.dtype. Returns ------- - out : same as object + out : DataArray, Dataset or Variable New object of zeros with the same shape and type as other. Examples @@ -1467,24 +1677,24 @@ def zeros_like(other, dtype: DTypeLike = None): array([[0, 1, 2], [3, 4, 5]]) Coordinates: - * lat (lat) int64 1 2 - * lon (lon) int64 0 1 2 + * lat (lat) int64 1 2 + * lon (lon) int64 0 1 2 >>> xr.zeros_like(x) array([[0, 0, 0], [0, 0, 0]]) Coordinates: - * lat (lat) int64 1 2 - * lon (lon) int64 0 1 2 + * lat (lat) int64 1 2 + * lon (lon) int64 0 1 2 - >>> xr.zeros_like(x, dtype=np.float) + >>> xr.zeros_like(x, dtype=float) array([[0., 0., 0.], [0., 0., 0.]]) Coordinates: - * lat (lat) int64 1 2 - * lon (lon) int64 0 1 2 + * lat (lat) int64 1 2 + * lon (lon) int64 0 1 2 See also -------- @@ -1527,16 +1737,16 @@ def ones_like(other, dtype: DTypeLike = None): array([[0, 1, 2], [3, 4, 5]]) Coordinates: - * lat (lat) int64 1 2 - * lon (lon) int64 0 1 2 + * lat (lat) int64 1 2 + * lon (lon) int64 0 1 2 >>> xr.ones_like(x) array([[1, 1, 1], [1, 1, 1]]) Coordinates: - * lat (lat) int64 1 2 - * lon (lon) int64 0 1 2 + * lat (lat) int64 1 2 + * lon (lon) int64 0 1 2 See also -------- @@ -1549,20 +1759,17 @@ def ones_like(other, dtype: DTypeLike = None): def is_np_datetime_like(dtype: DTypeLike) -> bool: - """Check if a dtype is a subclass of the numpy datetime types - """ + """Check if a dtype is a subclass of the numpy datetime types""" return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64) def is_np_timedelta_like(dtype: DTypeLike) -> bool: - """Check whether dtype is of the timedelta64 dtype. - """ + """Check whether dtype is of the timedelta64 dtype.""" return np.issubdtype(dtype, np.timedelta64) def _contains_cftime_datetimes(array) -> bool: - """Check if an array contains cftime.datetime objects - """ + """Check if an array contains cftime.datetime objects""" try: from cftime import datetime as cftime_datetime except ImportError: @@ -1570,7 +1777,7 @@ def _contains_cftime_datetimes(array) -> bool: else: if array.dtype == np.dtype("O") and array.size > 0: sample = array.ravel()[0] - if isinstance(sample, dask_array_type): + if is_duck_dask_array(sample): sample = sample.compute() if isinstance(sample, np.ndarray): sample = sample.item() @@ -1580,8 +1787,7 @@ def _contains_cftime_datetimes(array) -> bool: def contains_cftime_datetimes(var) -> bool: - """Check if an xarray.Variable contains cftime.datetime objects - """ + """Check if an xarray.Variable contains cftime.datetime objects""" return _contains_cftime_datetimes(var.data) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 6cf4178b5bf..e0d9ff4b218 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -4,7 +4,9 @@ import functools import itertools import operator +import warnings from collections import Counter +from distutils.version import LooseVersion from typing import ( TYPE_CHECKING, AbstractSet, @@ -24,11 +26,10 @@ import numpy as np from . 
import dtypes, duck_array_ops, utils -from .alignment import deep_align +from .alignment import align, deep_align from .merge import merge_coordinates_without_align -from .nanops import dask_array from .options import OPTIONS -from .pycompat import dask_array_type +from .pycompat import is_duck_dask_array from .utils import is_dict_like from .variable import Variable @@ -41,6 +42,14 @@ _JOINS_WITHOUT_FILL_VALUES = frozenset({"inner", "exact"}) +def _first_of_type(args, kind): + """ Return either first object of type 'kind' or raise if not found. """ + for arg in args: + if isinstance(arg, kind): + return arg + raise ValueError("This should be unreachable.") + + class _UFuncSignature: """Core dimensions signature for a given function. @@ -91,6 +100,12 @@ def all_core_dims(self): self._all_core_dims = self.all_input_core_dims | self.all_output_core_dims return self._all_core_dims + @property + def dims_map(self): + return { + core_dim: f"dim{n}" for n, core_dim in enumerate(sorted(self.all_core_dims)) + } + @property def num_inputs(self): return len(self.input_core_dims) @@ -113,7 +128,9 @@ def __ne__(self, other): def __repr__(self): return "{}({!r}, {!r})".format( - type(self).__name__, list(self.input_core_dims), list(self.output_core_dims) + type(self).__name__, + list(self.input_core_dims), + list(self.output_core_dims), ) def __str__(self): @@ -121,22 +138,41 @@ def __str__(self): rhs = ",".join("({})".format(",".join(dims)) for dims in self.output_core_dims) return f"{lhs}->{rhs}" - def to_gufunc_string(self): + def to_gufunc_string(self, exclude_dims=frozenset()): """Create an equivalent signature string for a NumPy gufunc. Unlike __str__, handles dimensions that don't map to Python identifiers. + + Also creates unique names for input_core_dims contained in exclude_dims. """ - all_dims = self.all_core_dims - dims_map = dict(zip(sorted(all_dims), range(len(all_dims)))) input_core_dims = [ - ["dim%d" % dims_map[dim] for dim in core_dims] + [self.dims_map[dim] for dim in core_dims] for core_dims in self.input_core_dims ] output_core_dims = [ - ["dim%d" % dims_map[dim] for dim in core_dims] + [self.dims_map[dim] for dim in core_dims] for core_dims in self.output_core_dims ] + + # enumerate input_core_dims contained in exclude_dims to make them unique + if exclude_dims: + + exclude_dims = [self.dims_map[dim] for dim in exclude_dims] + + counter = Counter() + + def _enumerate(dim): + if dim in exclude_dims: + n = counter[dim] + counter.update([dim]) + dim = f"{dim}_{n}" + return dim + + input_core_dims = [ + [_enumerate(dim) for dim in arg] for arg in input_core_dims + ] + alt_signature = type(self)(input_core_dims, output_core_dims) return str(alt_signature) @@ -177,7 +213,7 @@ def build_output_coords( are OK, e.g., scalars, Variable, DataArray, Dataset. signature : _UfuncSignature Core dimensions signature for the operation. - exclude_dims : optional set + exclude_dims : set, optional Dimensions excluded from the operation. Coordinates along these dimensions are dropped. 
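The ``to_gufunc_string`` change above can be illustrated with the internal ``_UFuncSignature`` helper (internal API, shown only as a sketch of how excluded core dimensions are renamed):

from xarray.core.computation import _UFuncSignature

sig = _UFuncSignature([["x"], ["x"]], [[]])
sig.to_gufunc_string()                    # "(dim0),(dim0)->()"
# excluded core dims get unique names, so apply_gufunc will not try to
# match their (possibly different) sizes across arguments
sig.to_gufunc_string(exclude_dims={"x"})  # "(dim0_0),(dim0_1)->()"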
@@ -224,8 +260,9 @@ def apply_dataarray_vfunc( args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False ) - if keep_attrs and hasattr(args[0], "name"): - name = args[0].name + if keep_attrs: + first_obj = _first_of_type(args, DataArray) + name = first_obj.name else: name = result_name(args) result_coords = build_output_coords(args, signature, exclude_dims) @@ -242,6 +279,14 @@ def apply_dataarray_vfunc( (coords,) = result_coords out = DataArray(result_var, coords, name=name, fastpath=True) + if keep_attrs: + if isinstance(out, tuple): + for da in out: + # This is adding attrs in place + da._copy_attrs_from(first_obj) + else: + out._copy_attrs_from(first_obj) + return out @@ -362,8 +407,6 @@ def apply_dataset_vfunc( """ from .dataset import Dataset - first_obj = args[0] # we'll copy attrs from this in case keep_attrs=True - if dataset_join not in _JOINS_WITHOUT_FILL_VALUES and fill_value is _NO_FILL_VALUE: raise TypeError( "to apply an operation to datasets with different " @@ -371,6 +414,9 @@ def apply_dataset_vfunc( "dataset_fill_value argument." ) + if keep_attrs: + first_obj = _first_of_type(args, Dataset) + if len(args) > 1: args = deep_align( args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False @@ -389,9 +435,11 @@ def apply_dataset_vfunc( (coord_vars,) = list_of_coords out = _fast_dataset(result_vars, coord_vars) - if keep_attrs and isinstance(first_obj, Dataset): + if keep_attrs: if isinstance(out, tuple): - out = tuple(ds._copy_attrs_from(first_obj) for ds in out) + for ds in out: + # This is adding attrs in place + ds._copy_attrs_from(first_obj) else: out._copy_attrs_from(first_obj) return out @@ -425,7 +473,7 @@ def apply_groupby_func(func, *args): if any(not first_groupby._group.equals(gb._group) for gb in groupbys[1:]): raise ValueError( "apply_ufunc can only perform operations over " - "multiple GroupBy objets at once if they are all " + "multiple GroupBy objects at once if they are all " "grouped the same way" ) @@ -540,6 +588,19 @@ def broadcast_compat_data( return data +def _vectorize(func, signature, output_dtypes, exclude_dims): + if signature.all_core_dims: + func = np.vectorize( + func, + otypes=output_dtypes, + signature=signature.to_gufunc_string(exclude_dims), + ) + else: + func = np.vectorize(func, otypes=output_dtypes) + + return func + + def apply_variable_ufunc( func, *args, @@ -547,14 +608,15 @@ def apply_variable_ufunc( exclude_dims=frozenset(), dask="forbidden", output_dtypes=None, - output_sizes=None, + vectorize=False, keep_attrs=False, - meta=None, + dask_gufunc_kwargs=None, ): - """Apply a ndarray level function over Variable and/or ndarray objects. 
- """ + """Apply a ndarray level function over Variable and/or ndarray objects.""" from .variable import Variable, as_compatible_data + first_obj = _first_of_type(args, Variable) + dim_sizes = unified_dim_sizes( (a for a in args if hasattr(a, "dims")), exclude_dims=exclude_dims ) @@ -570,7 +632,7 @@ def apply_variable_ufunc( for arg, core_dims in zip(args, signature.input_core_dims) ] - if any(isinstance(array, dask_array_type) for array in input_data): + if any(is_duck_dask_array(array) for array in input_data): if dask == "forbidden": raise ValueError( "apply_ufunc encountered a dask array on an " @@ -580,21 +642,72 @@ def apply_variable_ufunc( "``.load()`` or ``.compute()``" ) elif dask == "parallelized": - input_dims = [broadcast_dims + dims for dims in signature.input_core_dims] numpy_func = func + if dask_gufunc_kwargs is None: + dask_gufunc_kwargs = {} + else: + dask_gufunc_kwargs = dask_gufunc_kwargs.copy() + + allow_rechunk = dask_gufunc_kwargs.get("allow_rechunk", None) + if allow_rechunk is None: + for n, (data, core_dims) in enumerate( + zip(input_data, signature.input_core_dims) + ): + if is_duck_dask_array(data): + # core dimensions cannot span multiple chunks + for axis, dim in enumerate(core_dims, start=-len(core_dims)): + if len(data.chunks[axis]) != 1: + raise ValueError( + f"dimension {dim} on {n}th function argument to " + "apply_ufunc with dask='parallelized' consists of " + "multiple chunks, but is also a core dimension. To " + "fix, either rechunk into a single dask array chunk along " + f"this dimension, i.e., ``.chunk({dim}: -1)``, or " + "pass ``allow_rechunk=True`` in ``dask_gufunc_kwargs`` " + "but beware that this may significantly increase memory usage." + ) + dask_gufunc_kwargs["allow_rechunk"] = True + + output_sizes = dask_gufunc_kwargs.pop("output_sizes", {}) + if output_sizes: + output_sizes_renamed = {} + for key, value in output_sizes.items(): + if key not in signature.all_output_core_dims: + raise ValueError( + f"dimension '{key}' in 'output_sizes' must correspond to output_core_dims" + ) + output_sizes_renamed[signature.dims_map[key]] = value + dask_gufunc_kwargs["output_sizes"] = output_sizes_renamed + + for key in signature.all_output_core_dims: + if key not in signature.all_input_core_dims and key not in output_sizes: + raise ValueError( + f"dimension '{key}' in 'output_core_dims' needs corresponding (dim, size) in 'output_sizes'" + ) + def func(*arrays): - return _apply_blockwise( + import dask.array as da + + res = da.apply_gufunc( numpy_func, - arrays, - input_dims, - output_dims, - signature, - output_dtypes, - output_sizes, - meta, + signature.to_gufunc_string(exclude_dims), + *arrays, + vectorize=vectorize, + output_dtypes=output_dtypes, + **dask_gufunc_kwargs, ) + # todo: covers for https://github.com/dask/dask/pull/6207 + # remove when minimal dask version >= 2.17.0 + from dask import __version__ as dask_version + + if LooseVersion(dask_version) < LooseVersion("2.17.0"): + if signature.num_outputs > 1: + res = tuple(res) + + return res + elif dask == "allowed": pass else: @@ -602,6 +715,12 @@ def func(*arrays): "unknown setting for dask array handling in " "apply_ufunc: {}".format(dask) ) + else: + if vectorize: + func = _vectorize( + func, signature, output_dtypes=output_dtypes, exclude_dims=exclude_dims + ) + result_data = func(*input_data) if signature.num_outputs == 1: @@ -623,9 +742,8 @@ def func(*arrays): if data.ndim != len(dims): raise ValueError( "applied function returned data with unexpected " - "number of dimensions: {} vs {}, 
for dimensions {}".format( - data.ndim, len(dims), dims - ) + f"number of dimensions. Received {data.ndim} dimension(s) but " + f"expected {len(dims)} dimensions with names: {dims!r}" ) var = Variable(dims, data, fastpath=True) @@ -640,8 +758,8 @@ def func(*arrays): ) ) - if keep_attrs and isinstance(args[0], Variable): - var.attrs.update(args[0].attrs) + if keep_attrs: + var.attrs.update(first_obj.attrs) output.append(var) if signature.num_outputs == 1: @@ -650,93 +768,9 @@ def func(*arrays): return tuple(output) -def _apply_blockwise( - func, - args, - input_dims, - output_dims, - signature, - output_dtypes, - output_sizes=None, - meta=None, -): - import dask.array - - if signature.num_outputs > 1: - raise NotImplementedError( - "multiple outputs from apply_ufunc not yet " - "supported with dask='parallelized'" - ) - - if output_dtypes is None: - raise ValueError( - "output dtypes (output_dtypes) must be supplied to " - "apply_func when using dask='parallelized'" - ) - if not isinstance(output_dtypes, list): - raise TypeError( - "output_dtypes must be a list of objects coercible to " - "numpy dtypes, got {}".format(output_dtypes) - ) - if len(output_dtypes) != signature.num_outputs: - raise ValueError( - "apply_ufunc arguments output_dtypes and " - "output_core_dims must have the same length: {} vs {}".format( - len(output_dtypes), signature.num_outputs - ) - ) - (dtype,) = output_dtypes - - if output_sizes is None: - output_sizes = {} - - new_dims = signature.all_output_core_dims - signature.all_input_core_dims - if any(dim not in output_sizes for dim in new_dims): - raise ValueError( - "when using dask='parallelized' with apply_ufunc, " - "output core dimensions not found on inputs must " - "have explicitly set sizes with ``output_sizes``: {}".format(new_dims) - ) - - for n, (data, core_dims) in enumerate(zip(args, signature.input_core_dims)): - if isinstance(data, dask_array_type): - # core dimensions cannot span multiple chunks - for axis, dim in enumerate(core_dims, start=-len(core_dims)): - if len(data.chunks[axis]) != 1: - raise ValueError( - "dimension {!r} on {}th function argument to " - "apply_ufunc with dask='parallelized' consists of " - "multiple chunks, but is also a core dimension. To " - "fix, rechunk into a single dask array chunk along " - "this dimension, i.e., ``.chunk({})``, but beware " - "that this may significantly increase memory usage.".format( - dim, n, {dim: -1} - ) - ) - - (out_ind,) = output_dims - - blockwise_args = [] - for arg, dims in zip(args, input_dims): - # skip leading dimensions that are implicitly added by broadcasting - ndim = getattr(arg, "ndim", 0) - trimmed_dims = dims[-ndim:] if ndim else () - blockwise_args.extend([arg, trimmed_dims]) - - return dask.array.blockwise( - func, - out_ind, - *blockwise_args, - dtype=dtype, - concatenate=True, - new_axes=output_sizes, - meta=meta, - ) - - def apply_array_ufunc(func, *args, dask="forbidden"): """Apply a ndarray level function over ndarray objects.""" - if any(isinstance(arg, dask_array_type) for arg in args): + if any(is_duck_dask_array(arg) for arg in args): if dask == "forbidden": raise ValueError( "apply_ufunc encountered a dask array on an " @@ -773,6 +807,7 @@ def apply_ufunc( output_dtypes: Sequence = None, output_sizes: Mapping[Any, int] = None, meta: Any = None, + dask_gufunc_kwargs: Dict[str, Any] = None, ) -> Any: """Apply a vectorized function for unlabeled arrays on xarray objects. 
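A hedged sketch of the new ``dask='parallelized'`` path (requires dask; the array and function names here are illustrative), with ``output_sizes`` now passed through ``dask_gufunc_kwargs``:

import numpy as np
import xarray as xr

da = xr.DataArray(np.random.rand(4, 24), dims=("x", "time")).chunk({"x": 2})

def resample_time(arr):
    # operates on 1-D numpy slices along "time" and returns 12 interpolated points
    return np.interp(np.linspace(0, 1, 12), np.linspace(0, 1, arr.shape[-1]), arr)

out = xr.apply_ufunc(
    resample_time,
    da,
    input_core_dims=[["time"]],
    output_core_dims=[["new_time"]],
    vectorize=True,                # vectorized inside dask.array.apply_gufunc
    dask="parallelized",           # dispatches to dask.array.apply_gufunc
    output_dtypes=[da.dtype],
    # new output core dimensions need explicit sizes
    dask_gufunc_kwargs={"output_sizes": {"new_time": 12}},
)
# "time" stays in a single chunk here, so allow_rechunk=True is not needed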
@@ -791,9 +826,9 @@ def apply_ufunc( the style of NumPy universal functions [1]_ (if this is not the case, set ``vectorize=True``). If this function returns multiple outputs, you must set ``output_core_dims`` as well. - *args : Dataset, DataArray, GroupBy, Variable, numpy/dask arrays or scalars + *args : Dataset, DataArray, GroupBy, Variable, numpy.ndarray, dask.array.Array or scalar Mix of labeled and/or unlabeled arrays to which to apply the function. - input_core_dims : Sequence[Sequence], optional + input_core_dims : sequence of sequence, optional List of the same length as ``args`` giving the list of core dimensions on each input argument that should not be broadcast. By default, we assume there are no core dimensions on any input arguments. @@ -805,7 +840,7 @@ def apply_ufunc( Core dimensions are automatically moved to the last axes of input variables before applying ``func``, which facilitates using NumPy style generalized ufuncs [2]_. - output_core_dims : List[tuple], optional + output_core_dims : list of tuple, optional List of the same length as the number of output arguments from ``func``, giving the list of core dimensions on each output that were not broadcast on the inputs. By default, we assume that ``func`` @@ -826,7 +861,7 @@ def apply_ufunc( :py:func:`numpy.vectorize`. This option exists for convenience, but is almost always slower than supplying a pre-vectorized function. Using this option requires NumPy version 1.12 or newer. - join : {'outer', 'inner', 'left', 'right', 'exact'}, optional + join : {"outer", "inner", "left", "right", "exact"}, default: "exact" Method for joining the indexes of the passed objects along each dimension, and the variables of Dataset objects with mismatched data variables: @@ -837,7 +872,7 @@ def apply_ufunc( - 'right': use indexes from the last object with each dimension - 'exact': raise `ValueError` instead of aligning when indexes to be aligned are not equal - dataset_join : {'outer', 'inner', 'left', 'right', 'exact'}, optional + dataset_join : {"outer", "inner", "left", "right", "exact"}, default: "exact" Method for joining variables of Dataset objects with mismatched data variables. @@ -850,28 +885,38 @@ def apply_ufunc( Value used in place of missing variables on Dataset inputs when the datasets do not share the exact same ``data_vars``. Required if ``dataset_join not in {'inner', 'exact'}``, otherwise ignored. - keep_attrs: boolean, Optional + keep_attrs: bool, optional Whether to copy attributes from the first argument to the output. kwargs: dict, optional Optional keyword arguments passed directly on to call ``func``. - dask: 'forbidden', 'allowed' or 'parallelized', optional + dask: {"forbidden", "allowed", "parallelized"}, default: "forbidden" How to handle applying to objects containing lazy data in the form of dask arrays: - 'forbidden' (default): raise an error if a dask array is encountered. - - 'allowed': pass dask arrays directly on to ``func``. + - 'allowed': pass dask arrays directly on to ``func``. Prefer this option if + ``func`` natively supports dask arrays. - 'parallelized': automatically parallelize ``func`` if any of the - inputs are a dask array. If used, the ``output_dtypes`` argument must - also be provided. Multiple output arguments are not yet supported. - output_dtypes : list of dtypes, optional - Optional list of output dtypes. Only used if dask='parallelized'. + inputs are a dask array by using `dask.array.apply_gufunc`. Multiple output + arguments are supported. 
Only use this option if ``func`` does not natively + support dask arrays (e.g. converts them to numpy arrays). + dask_gufunc_kwargs : dict, optional + Optional keyword arguments passed to ``dask.array.apply_gufunc`` if + dask='parallelized'. Possible keywords are ``output_sizes``, ``allow_rechunk`` + and ``meta``. + output_dtypes : list of dtype, optional + Optional list of output dtypes. Only used if ``dask='parallelized'`` or + vectorize=True. output_sizes : dict, optional Optional mapping from dimension names to sizes for outputs. Only used if dask='parallelized' and new dimensions (not found on inputs) appear - on outputs. + on outputs. ``output_sizes`` should be given in the ``dask_gufunc_kwargs`` + parameter. It will be removed as a direct parameter in a future version. meta : optional Size-0 object representing the type of array wrapped by dask array. Passed on to - ``dask.array.blockwise``. + ``dask.array.apply_gufunc``. ``meta`` should be given in the + ``dask_gufunc_kwargs`` parameter. It will be removed as a direct parameter + in a future version. Returns ------- @@ -886,6 +931,7 @@ def apply_ufunc( >>> def magnitude(a, b): ... func = lambda x, y: np.sqrt(x ** 2 + y ** 2) ... return xr.apply_ufunc(func, a, b) + ... You can now apply ``magnitude()`` to ``xr.DataArray`` and ``xr.Dataset`` objects, with automatically preserved dimensions and coordinates, e.g., @@ -893,7 +939,7 @@ def apply_ufunc( >>> array = xr.DataArray([1, 2, 3], coords=[("x", [0.1, 0.2, 0.3])]) >>> magnitude(array, -array) - array([1.414214, 2.828427, 4.242641]) + array([1.41421356, 2.82842712, 4.24264069]) Coordinates: * x (x) float64 0.1 0.2 0.3 @@ -977,17 +1023,18 @@ def earth_mover_distance(first_samples, .. [2] http://docs.scipy.org/doc/numpy/reference/c-api.generalized-ufuncs.html .. [3] http://xarray.pydata.org/en/stable/computation.html#wrapping-custom-computation """ - from .groupby import GroupBy from .dataarray import DataArray + from .groupby import GroupBy from .variable import Variable if input_core_dims is None: input_core_dims = ((),) * (len(args)) elif len(input_core_dims) != len(args): raise ValueError( - "input_core_dims must be None or a tuple with the length same to " - "the number of arguments. Given input_core_dims: {}, " - "number of args: {}.".format(input_core_dims, len(args)) + f"input_core_dims must be None or a tuple with the same length as " + f"the number of arguments. " + f"Given {len(input_core_dims)} input_core_dims: {input_core_dims}, " + f"but the number of args is {len(args)}." ) if kwargs is None: @@ -995,28 +1042,46 @@ def earth_mover_distance(first_samples, signature = _UFuncSignature(input_core_dims, output_core_dims) - if exclude_dims and not exclude_dims <= signature.all_core_dims: - raise ValueError( - "each dimension in `exclude_dims` must also be a " - "core dimension in the function signature" - ) + if exclude_dims: + if not isinstance(exclude_dims, set): + raise TypeError( + f"Expected exclude_dims to be a 'set'. Received '{type(exclude_dims).__name__}' instead." + ) + if not exclude_dims <= signature.all_core_dims: + raise ValueError( + f"each dimension in `exclude_dims` must also be a " + f"core dimension in the function signature. 
" + f"Please make {(exclude_dims - signature.all_core_dims)} a core dimension" + ) + + # handle dask_gufunc_kwargs + if dask == "parallelized": + if dask_gufunc_kwargs is None: + dask_gufunc_kwargs = {} + else: + dask_gufunc_kwargs = dask_gufunc_kwargs.copy() + # todo: remove warnings after deprecation cycle + if meta is not None: + warnings.warn( + "``meta`` should be given in the ``dask_gufunc_kwargs`` parameter." + " It will be removed as direct parameter in a future version.", + FutureWarning, + stacklevel=2, + ) + dask_gufunc_kwargs.setdefault("meta", meta) + if output_sizes is not None: + warnings.warn( + "``output_sizes`` should be given in the ``dask_gufunc_kwargs`` " + "parameter. It will be removed as direct parameter in a future " + "version.", + FutureWarning, + stacklevel=2, + ) + dask_gufunc_kwargs.setdefault("output_sizes", output_sizes) if kwargs: func = functools.partial(func, **kwargs) - if vectorize: - if meta is None: - # set meta=np.ndarray by default for numpy vectorized functions - # work around dask bug computing meta with vectorized functions: GH5642 - meta = np.ndarray - - if signature.all_core_dims: - func = np.vectorize( - func, otypes=output_dtypes, signature=signature.to_gufunc_string() - ) - else: - func = np.vectorize(func, otypes=output_dtypes) - variables_vfunc = functools.partial( apply_variable_ufunc, func, @@ -1024,11 +1089,12 @@ def earth_mover_distance(first_samples, exclude_dims=exclude_dims, keep_attrs=keep_attrs, dask=dask, + vectorize=vectorize, output_dtypes=output_dtypes, - output_sizes=output_sizes, - meta=meta, + dask_gufunc_kwargs=dask_gufunc_kwargs, ) + # feed groupby-apply_ufunc through apply_groupby_func if any(isinstance(a, GroupBy) for a in args): this_apply = functools.partial( apply_ufunc, @@ -1041,9 +1107,12 @@ def earth_mover_distance(first_samples, dataset_fill_value=dataset_fill_value, keep_attrs=keep_attrs, dask=dask, - meta=meta, + vectorize=vectorize, + output_dtypes=output_dtypes, + dask_gufunc_kwargs=dask_gufunc_kwargs, ) return apply_groupby_func(this_apply, *args) + # feed datasets apply_variable_ufunc through apply_dataset_vfunc elif any(is_dict_like(a) for a in args): return apply_dataset_vfunc( variables_vfunc, @@ -1055,6 +1124,7 @@ def earth_mover_distance(first_samples, fill_value=dataset_fill_value, keep_attrs=keep_attrs, ) + # feed DataArray apply_variable_ufunc through apply_dataarray_vfunc elif any(isinstance(a, DataArray) for a in args): return apply_dataarray_vfunc( variables_vfunc, @@ -1064,30 +1134,230 @@ def earth_mover_distance(first_samples, exclude_dims=exclude_dims, keep_attrs=keep_attrs, ) + # feed Variables directly through apply_variable_ufunc elif any(isinstance(a, Variable) for a in args): return variables_vfunc(*args) else: + # feed anything else through apply_array_ufunc return apply_array_ufunc(func, *args, dask=dask) +def cov(da_a, da_b, dim=None, ddof=1): + """ + Compute covariance between two DataArray objects along a shared dimension. + + Parameters + ---------- + da_a : DataArray + Array to compute. + da_b : DataArray + Array to compute. + dim : str, optional + The dimension along which the covariance will be computed + ddof : int, optional + If ddof=1, covariance is normalized by N-1, giving an unbiased estimate, + else normalization is by N. 
+ + Returns + ------- + covariance : DataArray + + See also + -------- + pandas.Series.cov : corresponding pandas function + xarray.corr: respective function to calculate correlation + + Examples + -------- + >>> from xarray import DataArray + >>> da_a = DataArray( + ... np.array([[1, 2, 3], [0.1, 0.2, 0.3], [3.2, 0.6, 1.8]]), + ... dims=("space", "time"), + ... coords=[ + ... ("space", ["IA", "IL", "IN"]), + ... ("time", pd.date_range("2000-01-01", freq="1D", periods=3)), + ... ], + ... ) + >>> da_a + + array([[1. , 2. , 3. ], + [0.1, 0.2, 0.3], + [3.2, 0.6, 1.8]]) + Coordinates: + * space (space) >> da_b = DataArray( + ... np.array([[0.2, 0.4, 0.6], [15, 10, 5], [3.2, 0.6, 1.8]]), + ... dims=("space", "time"), + ... coords=[ + ... ("space", ["IA", "IL", "IN"]), + ... ("time", pd.date_range("2000-01-01", freq="1D", periods=3)), + ... ], + ... ) + >>> da_b + + array([[ 0.2, 0.4, 0.6], + [15. , 10. , 5. ], + [ 3.2, 0.6, 1.8]]) + Coordinates: + * space (space) >> xr.cov(da_a, da_b) + + array(-3.53055556) + >>> xr.cov(da_a, da_b, dim="time") + + array([ 0.2 , -0.5 , 1.69333333]) + Coordinates: + * space (space) >> from xarray import DataArray + >>> da_a = DataArray( + ... np.array([[1, 2, 3], [0.1, 0.2, 0.3], [3.2, 0.6, 1.8]]), + ... dims=("space", "time"), + ... coords=[ + ... ("space", ["IA", "IL", "IN"]), + ... ("time", pd.date_range("2000-01-01", freq="1D", periods=3)), + ... ], + ... ) + >>> da_a + + array([[1. , 2. , 3. ], + [0.1, 0.2, 0.3], + [3.2, 0.6, 1.8]]) + Coordinates: + * space (space) >> da_b = DataArray( + ... np.array([[0.2, 0.4, 0.6], [15, 10, 5], [3.2, 0.6, 1.8]]), + ... dims=("space", "time"), + ... coords=[ + ... ("space", ["IA", "IL", "IN"]), + ... ("time", pd.date_range("2000-01-01", freq="1D", periods=3)), + ... ], + ... ) + >>> da_b + + array([[ 0.2, 0.4, 0.6], + [15. , 10. , 5. ], + [ 3.2, 0.6, 1.8]]) + Coordinates: + * space (space) >> xr.corr(da_a, da_b) + + array(-0.57087777) + >>> xr.corr(da_a, da_b, dim="time") + + array([ 1., -1., 1.]) + Coordinates: + * space (space) array([[[ 0, 1], [ 2, 3]], + [[ 4, 5], [ 6, 7]], + [[ 8, 9], [10, 11]]]) Dimensions without coordinates: a, b, c @@ -1192,10 +1464,10 @@ def dot(*arrays, dims=None, **kwargs): # construct einsum subscripts, such as '...abc,...ab->...c' # Note: input_core_dims are always moved to the last position subscripts_list = [ - "..." + "".join([dim_map[d] for d in ds]) for ds in input_core_dims + "..." + "".join(dim_map[d] for d in ds) for ds in input_core_dims ] subscripts = ",".join(subscripts_list) - subscripts += "->..." + "".join([dim_map[d] for d in output_core_dims[0]]) + subscripts += "->..." + "".join(dim_map[d] for d in output_core_dims[0]) join = OPTIONS["arithmetic_join"] # using "inner" emulates `(a * b).sum()` for all joins (except "exact") @@ -1221,22 +1493,24 @@ def where(cond, x, y): Performs xarray-like broadcasting across input arguments. + All dimension coordinates on `x` and `y` must be aligned with each + other and with `cond`. + + Parameters ---------- - cond : scalar, array, Variable, DataArray or Dataset with boolean dtype + cond : scalar, array, Variable, DataArray or Dataset When True, return values from `x`, otherwise returns values from `y`. x : scalar, array, Variable, DataArray or Dataset values to choose from where `cond` is True y : scalar, array, Variable, DataArray or Dataset values to choose from where `cond` is False - All dimension coordinates on these objects must be aligned with each - other and with `cond`. 
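For orientation only, the ``xr.cov``/``xr.corr`` helpers added above reduce to the textbook definitions. A rough pure-xarray equivalent, assuming aligned inputs with no missing values (the real helper also handles NaNs and alignment), is:

def naive_cov(da_a, da_b, dim, ddof=1):
    # sum of products of the demeaned arrays over `dim`, normalized by N - ddof
    demeaned = (da_a - da_a.mean(dim)) * (da_b - da_b.mean(dim))
    return demeaned.sum(dim) / (da_a.count(dim) - ddof)

def naive_corr(da_a, da_b, dim):
    # Pearson correlation; the ddof choice cancels as long as it is consistent
    return naive_cov(da_a, da_b, dim) / (da_a.std(dim, ddof=1) * da_b.std(dim, ddof=1))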
- Returns ------- - In priority order: Dataset, DataArray, Variable or array, whichever - type appears as an input argument. + Dataset, DataArray, Variable or array + In priority order: Dataset, DataArray, Variable or array, whichever + type appears as an input argument. Examples -------- @@ -1252,13 +1526,13 @@ def where(cond, x, y): array([0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]) Coordinates: - * lat (lat) int64 0 1 2 3 4 5 6 7 8 9 + * lat (lat) int64 0 1 2 3 4 5 6 7 8 9 >>> xr.where(x < 0.5, x, x * 100) array([ 0. , 0.1, 0.2, 0.3, 0.4, 50. , 60. , 70. , 80. , 90. ]) Coordinates: - * lat (lat) int64 0 1 2 3 4 5 6 7 8 9 + * lat (lat) int64 0 1 2 3 4 5 6 7 8 9 >>> y = xr.DataArray( ... 0.1 * np.arange(9).reshape(3, 3), @@ -1272,8 +1546,8 @@ def where(cond, x, y): [0.3, 0.4, 0.5], [0.6, 0.7, 0.8]]) Coordinates: - * lat (lat) int64 0 1 2 - * lon (lon) int64 10 11 12 + * lat (lat) int64 0 1 2 + * lon (lon) int64 10 11 12 >>> xr.where(y.lat < 1, y, -1) @@ -1281,8 +1555,8 @@ def where(cond, x, y): [-1. , -1. , -1. ], [-1. , -1. , -1. ]]) Coordinates: - * lat (lat) int64 0 1 2 - * lon (lon) int64 10 11 12 + * lat (lat) int64 0 1 2 + * lon (lon) int64 10 11 12 >>> cond = xr.DataArray([True, False], dims=["x"]) >>> x = xr.DataArray([1, 2], dims=["y"]) @@ -1318,7 +1592,7 @@ def polyval(coord, coeffs, degree_dim="degree"): The 1D coordinate along which to evaluate the polynomial. coeffs : DataArray Coefficients of the polynomials. - degree_dim : str, default "degree" + degree_dim : str, default: "degree" Name of the polynomial degree dimension in `coeffs`. See also @@ -1329,7 +1603,7 @@ def polyval(coord, coeffs, degree_dim="degree"): from .dataarray import DataArray from .missing import get_clean_interp_index - x = get_clean_interp_index(coord, coord.name) + x = get_clean_interp_index(coord, coord.name, strict=False) deg_coord = coeffs[degree_dim] @@ -1380,24 +1654,24 @@ def _calc_idxminmax( # This will run argmin or argmax. indx = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna) - # Get the coordinate we want. - coordarray = array[dim] - # Handle dask arrays. - if isinstance(array, dask_array_type): - res = dask_array.map_blocks(coordarray, indx, dtype=indx.dtype) + if is_duck_dask_array(array.data): + import dask.array + + chunks = dict(zip(array.dims, array.chunks)) + dask_coord = dask.array.from_array(array[dim].data, chunks=chunks[dim]) + res = indx.copy(data=dask_coord[indx.data.ravel()].reshape(indx.shape)) + # we need to attach back the dim name + res.name = dim else: - res = coordarray[ - indx, - ] + res = array[dim][(indx,)] + # The dim is gone but we need to remove the corresponding coordinate. + del res.coords[dim] if skipna or (skipna is None and array.dtype.kind in na_dtypes): # Put the NaN values back in after removing them res = res.where(~allna, fill_value) - # The dim is gone but we need to remove the corresponding coordinate. - del res.coords[dim] - # Copy attributes from argmin/argmax, if any res.attrs = indx.attrs diff --git a/xarray/core/concat.py b/xarray/core/concat.py index 7741cbb826b..5cda5aa903c 100644 --- a/xarray/core/concat.py +++ b/xarray/core/concat.py @@ -1,3 +1,16 @@ +from typing import ( + TYPE_CHECKING, + Dict, + Hashable, + Iterable, + List, + Optional, + Set, + Tuple, + Union, + overload, +) + import pandas as pd from . 
import dtypes, utils @@ -7,6 +20,40 @@ from .variable import IndexVariable, Variable, as_variable from .variable import concat as concat_vars +if TYPE_CHECKING: + from .dataarray import DataArray + from .dataset import Dataset + + +@overload +def concat( + objs: Iterable["Dataset"], + dim: Union[str, "DataArray", pd.Index], + data_vars: Union[str, List[str]] = "all", + coords: Union[str, List[str]] = "different", + compat: str = "equals", + positions: Optional[Iterable[int]] = None, + fill_value: object = dtypes.NA, + join: str = "outer", + combine_attrs: str = "override", +) -> "Dataset": + ... + + +@overload +def concat( + objs: Iterable["DataArray"], + dim: Union[str, "DataArray", pd.Index], + data_vars: Union[str, List[str]] = "all", + coords: Union[str, List[str]] = "different", + compat: str = "equals", + positions: Optional[Iterable[int]] = None, + fill_value: object = dtypes.NA, + join: str = "outer", + combine_attrs: str = "override", +) -> "DataArray": + ... + def concat( objs, @@ -23,7 +70,7 @@ def concat( Parameters ---------- - objs : sequence of Dataset and DataArray objects + objs : sequence of Dataset and DataArray xarray objects to concatenate together. Each object is expected to consist of variables and coordinates with matching shapes except for along the concatenated dimension. @@ -34,74 +81,76 @@ def concat( unchanged. If dimension is provided as a DataArray or Index, its name is used as the dimension to concatenate along and the values are added as a coordinate. - data_vars : {'minimal', 'different', 'all' or list of str}, optional + data_vars : {"minimal", "different", "all"} or list of str, optional These data variables will be concatenated together: - * 'minimal': Only data variables in which the dimension already + * "minimal": Only data variables in which the dimension already appears are included. - * 'different': Data variables which are not equal (ignoring + * "different": Data variables which are not equal (ignoring attributes) across all datasets are also concatenated (as well as all for which dimension already appears). Beware: this option may load the data payload of data variables into memory if they are not already loaded. - * 'all': All data variables will be concatenated. + * "all": All data variables will be concatenated. * list of str: The listed data variables will be concatenated, in - addition to the 'minimal' data variables. + addition to the "minimal" data variables. - If objects are DataArrays, data_vars must be 'all'. - coords : {'minimal', 'different', 'all' or list of str}, optional + If objects are DataArrays, data_vars must be "all". + coords : {"minimal", "different", "all"} or list of str, optional These coordinate variables will be concatenated together: - * 'minimal': Only coordinates in which the dimension already appears + * "minimal": Only coordinates in which the dimension already appears are included. - * 'different': Coordinates which are not equal (ignoring attributes) + * "different": Coordinates which are not equal (ignoring attributes) across all datasets are also concatenated (as well as all for which dimension already appears). Beware: this option may load the data payload of coordinate variables into memory if they are not already loaded. - * 'all': All coordinate variables will be concatenated, except + * "all": All coordinate variables will be concatenated, except those corresponding to other dimensions. * list of str: The listed coordinate variables will be concatenated, - in addition to the 'minimal' coordinates. 
- compat : {'identical', 'equals', 'broadcast_equals', 'no_conflicts', 'override'}, optional + in addition to the "minimal" coordinates. + compat : {"identical", "equals", "broadcast_equals", "no_conflicts", "override"}, optional String indicating how to compare non-concatenated variables of the same name for potential conflicts. This is passed down to merge. - - 'broadcast_equals': all values must be equal when variables are + - "broadcast_equals": all values must be equal when variables are broadcast against each other to ensure common dimensions. - - 'equals': all values and dimensions must be the same. - - 'identical': all values, dimensions and attributes must be the + - "equals": all values and dimensions must be the same. + - "identical": all values, dimensions and attributes must be the same. - - 'no_conflicts': only values which are not null in both datasets + - "no_conflicts": only values which are not null in both datasets must be equal. The returned dataset then contains the combination of all non-null values. - - 'override': skip comparing and pick variable from first dataset + - "override": skip comparing and pick variable from first dataset positions : None or list of integer arrays, optional List of integer arrays which specifies the integer positions to which to assign each dataset along the concatenated dimension. If not supplied, objects are concatenated in the provided order. - fill_value : scalar, optional - Value to use for newly missing values - join : {'outer', 'inner', 'left', 'right', 'exact'}, optional + fill_value : scalar or dict-like, optional + Value to use for newly missing values. If a dict-like, maps + variable names to fill values. Use a data array's name to + refer to its values. + join : {"outer", "inner", "left", "right", "exact"}, optional String indicating how to combine differing indexes (excluding dim) in objects - - 'outer': use the union of object indexes - - 'inner': use the intersection of object indexes - - 'left': use indexes from the first object with each dimension - - 'right': use indexes from the last object with each dimension - - 'exact': instead of aligning, raise `ValueError` when indexes to be + - "outer": use the union of object indexes + - "inner": use the intersection of object indexes + - "left": use indexes from the first object with each dimension + - "right": use indexes from the last object with each dimension + - "exact": instead of aligning, raise `ValueError` when indexes to be aligned are not equal - - 'override': if indexes are of same size, rewrite indexes to be + - "override": if indexes are of same size, rewrite indexes to be those of the first object with that dimension. Indexes for the same dimension must have the same size in all objects. - combine_attrs : {'drop', 'identical', 'no_conflicts', 'override'}, - default 'override + combine_attrs : {"drop", "identical", "no_conflicts", "override"}, \ + default: "override" String indicating how to combine attrs of the objects being merged: - - 'drop': empty attrs on returned Dataset. - - 'identical': all attrs must be the same on every object. - - 'no_conflicts': attrs from all objects are combined, any that have + - "drop": empty attrs on returned Dataset. + - "identical": all attrs must be the same on every object. + - "no_conflicts": attrs from all objects are combined, any that have the same name must also have the same value. 
- - 'override': skip comparing and copy attrs from the first dataset to + - "override": skip comparing and copy attrs from the first dataset to the result. Returns @@ -111,13 +160,59 @@ def concat( See also -------- merge - auto_combine + + Examples + -------- + >>> da = xr.DataArray( + ... np.arange(6).reshape(2, 3), [("x", ["a", "b"]), ("y", [10, 20, 30])] + ... ) + >>> da + + array([[0, 1, 2], + [3, 4, 5]]) + Coordinates: + * x (x) >> xr.concat([da.isel(y=slice(0, 1)), da.isel(y=slice(1, None))], dim="y") + + array([[0, 1, 2], + [3, 4, 5]]) + Coordinates: + * x (x) >> xr.concat([da.isel(x=0), da.isel(x=1)], "x") + + array([[0, 1, 2], + [3, 4, 5]]) + Coordinates: + * x (x) >> xr.concat([da.isel(x=0), da.isel(x=1)], "new_dim") + + array([[0, 1, 2], + [3, 4, 5]]) + Coordinates: + x (new_dim) >> xr.concat([da.isel(x=0), da.isel(x=1)], pd.Index([-90, -100], name="new_dim")) + + array([[0, 1, 2], + [3, 4, 5]]) + Coordinates: + x (new_dim) Tuple[Dict[Hashable, Variable], Dict[Hashable, int], Set[Hashable], Set[Hashable]]: - dims = set() - all_coord_names = set() - data_vars = set() # list of data_vars - dim_coords = {} # maps dim name to variable - dims_sizes = {} # shared dimension sizes to expand variables + dims: Set[Hashable] = set() + all_coord_names: Set[Hashable] = set() + data_vars: Set[Hashable] = set() # list of data_vars + dim_coords: Dict[Hashable, Variable] = {} # maps dim name to variable + dims_sizes: Dict[Hashable, int] = {} # shared dimension sizes to expand variables for ds in datasets: dims_sizes.update(ds.dims) all_coord_names.update(ds.coords) data_vars.update(ds.data_vars) - for dim in set(ds.dims) - dims: + # preserves ordering of dimensions + for dim in ds.dims: + if dim in dims: + continue + if dim not in dim_coords: dim_coords[dim] = ds.coords[dim].variable dims = dims | set(ds.dims) @@ -307,16 +408,16 @@ def _parse_datasets(datasets): def _dataset_concat( - datasets, - dim, - data_vars, - coords, - compat, - positions, - fill_value=dtypes.NA, - join="outer", - combine_attrs="override", -): + datasets: List["Dataset"], + dim: Union[str, "DataArray", pd.Index], + data_vars: Union[str, List[str]], + coords: Union[str, List[str]], + compat: str, + positions: Optional[Iterable[int]], + fill_value: object = dtypes.NA, + join: str = "outer", + combine_attrs: str = "override", +) -> "Dataset": """ Concatenate a sequence of datasets along a new or existing dimension """ @@ -325,8 +426,8 @@ def _dataset_concat( dim, coord = _calc_concat_dim_coord(dim) # Make sure we're working on a copy (we'll be loading variables) datasets = [ds.copy() for ds in datasets] - datasets = align( - *datasets, join=join, copy=False, exclude=[dim], fill_value=fill_value + datasets = list( + align(*datasets, join=join, copy=False, exclude=[dim], fill_value=fill_value) ) dim_coords, dims_sizes, coord_names, data_names = _parse_datasets(datasets) @@ -356,7 +457,9 @@ def _dataset_concat( result_vars = {} if variables_to_merge: - to_merge = {var: [] for var in variables_to_merge} + to_merge: Dict[Hashable, List[Variable]] = { + var: [] for var in variables_to_merge + } for ds in datasets: for var in variables_to_merge: @@ -400,12 +503,15 @@ def ensure_common_dims(vars): for k in datasets[0].variables: if k in concat_over: try: - vars = ensure_common_dims([ds.variables[k] for ds in datasets]) + vars = ensure_common_dims([ds[k].variable for ds in datasets]) except KeyError: raise ValueError("%r is not present in all datasets." 
% k) combined = concat_vars(vars, dim, positions) assert isinstance(combined, Variable) result_vars[k] = combined + elif k in result_vars: + # preserves original variable order + result_vars[k] = result_vars.pop(k) result = Dataset(result_vars, attrs=result_attrs) absent_coord_names = coord_names - set(result.variables) @@ -427,16 +533,16 @@ def ensure_common_dims(vars): def _dataarray_concat( - arrays, - dim, - data_vars, - coords, - compat, - positions, - fill_value=dtypes.NA, - join="outer", - combine_attrs="override", -): + arrays: Iterable["DataArray"], + dim: Union[str, "DataArray", pd.Index], + data_vars: Union[str, List[str]], + coords: Union[str, List[str]], + compat: str, + positions: Optional[Iterable[int]], + fill_value: object = dtypes.NA, + join: str = "outer", + combine_attrs: str = "override", +) -> "DataArray": arrays = list(arrays) if data_vars != "all": diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 83c4d2a8636..37c462f79f4 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -214,9 +214,10 @@ def __getitem__(self, key: Hashable) -> "DataArray": return cast("DataArray", self._data[key]) def to_dataset(self) -> "Dataset": - """Convert these coordinates into a new Dataset - """ - return self._data._copy_listed(self._names) + """Convert these coordinates into a new Dataset""" + + names = [name for name in self._data._variables if name in self._names] + return self._data._copy_listed(names) def _update_coords( self, coords: Dict[Hashable, Variable], indexes: Mapping[Hashable, pd.Index] @@ -324,29 +325,6 @@ def _ipython_key_completions_(self): return self._data._ipython_key_completions_() -class LevelCoordinatesSource(Mapping[Hashable, Any]): - """Iterator for MultiIndex level coordinates. - - Used for attribute style lookup with AttrAccessMixin. Not returned directly - by any public methods. - """ - - __slots__ = ("_data",) - - def __init__(self, data_object: "Union[DataArray, Dataset]"): - self._data = data_object - - def __getitem__(self, key): - # not necessary -- everything here can already be found in coords. - raise KeyError() - - def __iter__(self) -> Iterator[Hashable]: - return iter(self._data._level_coords) - - def __len__(self) -> int: - return len(self._data._level_coords) - - def assert_coordinate_consistent( obj: Union["DataArray", "Dataset"], coords: Mapping[Hashable, Variable] ) -> None: diff --git a/xarray/core/dask_array_compat.py b/xarray/core/dask_array_compat.py index 94c50d90e84..ce15e01fb12 100644 --- a/xarray/core/dask_array_compat.py +++ b/xarray/core/dask_array_compat.py @@ -4,8 +4,6 @@ import numpy as np -from .pycompat import dask_array_type - try: import dask.array as da from dask import __version__ as dask_version @@ -13,95 +11,9 @@ dask_version = "0.0.0" da = None -if LooseVersion(dask_version) >= LooseVersion("2.0.0"): - meta_from_array = da.utils.meta_from_array -else: - # Copied from dask v2.4.0 - # Used under the terms of Dask's license, see licenses/DASK_LICENSE. 
- import numbers - - def meta_from_array(x, ndim=None, dtype=None): - """ Normalize an array to appropriate meta object - - Parameters - ---------- - x: array-like, callable - Either an object that looks sufficiently like a Numpy array, - or a callable that accepts shape and dtype keywords - ndim: int - Number of dimensions of the array - dtype: Numpy dtype - A valid input for ``np.dtype`` - - Returns - ------- - array-like with zero elements of the correct dtype - """ - # If using x._meta, x must be a Dask Array, some libraries (e.g. zarr) - # implement a _meta attribute that are incompatible with Dask Array._meta - if hasattr(x, "_meta") and isinstance(x, dask_array_type): - x = x._meta - - if dtype is None and x is None: - raise ValueError("You must specify the meta or dtype of the array") - - if np.isscalar(x): - x = np.array(x) - - if x is None: - x = np.ndarray - - if isinstance(x, type): - x = x(shape=(0,) * (ndim or 0), dtype=dtype) - - if ( - not hasattr(x, "shape") - or not hasattr(x, "dtype") - or not isinstance(x.shape, tuple) - ): - return x - - if isinstance(x, list) or isinstance(x, tuple): - ndims = [ - 0 - if isinstance(a, numbers.Number) - else a.ndim - if hasattr(a, "ndim") - else len(a) - for a in x - ] - a = [a if nd == 0 else meta_from_array(a, nd) for a, nd in zip(x, ndims)] - return a if isinstance(x, list) else tuple(x) - - if ndim is None: - ndim = x.ndim - - try: - meta = x[tuple(slice(0, 0, None) for _ in range(x.ndim))] - if meta.ndim != ndim: - if ndim > x.ndim: - meta = meta[ - (Ellipsis,) + tuple(None for _ in range(ndim - meta.ndim)) - ] - meta = meta[tuple(slice(0, 0, None) for _ in range(meta.ndim))] - elif ndim == 0: - meta = meta.sum() - else: - meta = meta.reshape((0,) * ndim) - except Exception: - meta = np.empty((0,) * ndim, dtype=dtype or x.dtype) - - if np.isscalar(meta): - meta = np.array(meta) - - if dtype and meta.dtype != dtype: - meta = meta.astype(dtype) - - return meta - def _validate_pad_output_shape(input_shape, pad_width, output_shape): - """ Validates the output shape of dask.array.pad, raising a RuntimeError if they do not match. + """Validates the output shape of dask.array.pad, raising a RuntimeError if they do not match. In the current versions of dask (2.2/2.4), dask.array.pad with mode='reflect' sometimes returns an invalid shape. """ @@ -114,7 +26,7 @@ def _validate_pad_output_shape(input_shape, pad_width, output_shape): elif ( len(pad_width) == len(input_shape) and all(map(lambda x: len(x) == 2, pad_width)) - and all((isint(i) for p in pad_width for i in p)) + and all(isint(i) for p in pad_width for i in p) ): pad_width = np.sum(pad_width, axis=1) else: @@ -146,43 +58,6 @@ def pad(array, pad_width, mode="constant", **kwargs): return padded -if LooseVersion(dask_version) >= LooseVersion("2.8.1"): - median = da.median -else: - # Copied from dask v2.8.1 - # Used under the terms of Dask's license, see licenses/DASK_LICENSE. - def median(a, axis=None, keepdims=False): - """ - This works by automatically chunking the reduced axes to a single chunk - and then calling ``numpy.median`` function across the remaining dimensions - """ - - if axis is None: - raise NotImplementedError( - "The da.median function only works along an axis. 
" - "The full algorithm is difficult to do in parallel" - ) - - if not isinstance(axis, Iterable): - axis = (axis,) - - axis = [ax + a.ndim if ax < 0 else ax for ax in axis] - - a = a.rechunk({ax: -1 if ax in axis else "auto" for ax in range(a.ndim)}) - - result = a.map_blocks( - np.median, - axis=axis, - keepdims=keepdims, - drop_axis=axis if not keepdims else None, - chunks=[1 if ax in axis else c for ax, c in enumerate(a.chunks)] - if keepdims - else None, - ) - - return result - - if LooseVersion(dask_version) > LooseVersion("2.9.0"): nanmedian = da.nanmedian else: @@ -206,8 +81,9 @@ def nanmedian(a, axis=None, keepdims=False): a = a.rechunk({ax: -1 if ax in axis else "auto" for ax in range(a.ndim)}) - result = a.map_blocks( + result = da.map_blocks( np.nanmedian, + a, axis=axis, keepdims=keepdims, drop_axis=axis if not keepdims else None, diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index 87f646352eb..15641506e4e 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -4,8 +4,7 @@ def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1): - """Wrapper to apply bottleneck moving window funcs on dask arrays - """ + """Wrapper to apply bottleneck moving window funcs on dask arrays""" import dask.array as da dtype, fill_value = dtypes.maybe_promote(a.dtype) @@ -19,8 +18,8 @@ def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1): # Create overlap array. ag = da.overlap.overlap(a, depth=depth, boundary=boundary) # apply rolling func - out = ag.map_blocks( - moving_func, window, min_count=min_count, axis=axis, dtype=a.dtype + out = da.map_blocks( + moving_func, ag, window, min_count=min_count, axis=axis, dtype=a.dtype ) # trim array result = da.overlap.trim_internal(out, depth) @@ -28,73 +27,89 @@ def dask_rolling_wrapper(moving_func, a, window, min_count=None, axis=-1): def rolling_window(a, axis, window, center, fill_value): - """Dask's equivalence to np.utils.rolling_window - """ + """Dask's equivalence to np.utils.rolling_window""" import dask.array as da + if not hasattr(axis, "__len__"): + axis = [axis] + window = [window] + center = [center] + orig_shape = a.shape - if axis < 0: - axis = a.ndim + axis depth = {d: 0 for d in range(a.ndim)} - depth[axis] = int(window / 2) - # For evenly sized window, we need to crop the first point of each block. - offset = 1 if window % 2 == 0 else 0 - - if depth[axis] > min(a.chunks[axis]): - raise ValueError( - "For window size %d, every chunk should be larger than %d, " - "but the smallest chunk size is %d. Rechunk your array\n" - "with a larger chunk size or a chunk size that\n" - "more evenly divides the shape of your array." - % (window, depth[axis], min(a.chunks[axis])) - ) - - # Although da.overlap pads values to boundaries of the array, - # the size of the generated array is smaller than what we want - # if center == False. - if center: - start = int(window / 2) # 10 -> 5, 9 -> 4 - end = window - 1 - start - else: - start, end = window - 1, 0 - pad_size = max(start, end) + offset - depth[axis] - drop_size = 0 - # pad_size becomes more than 0 when the overlapped array is smaller than - # needed. In this case, we need to enlarge the original array by padding - # before overlapping. - if pad_size > 0: - if pad_size < depth[axis]: - # overlapping requires each chunk larger than depth. If pad_size is - # smaller than the depth, we enlarge this and truncate it later. 
- drop_size = depth[axis] - pad_size - pad_size = depth[axis] - shape = list(a.shape) - shape[axis] = pad_size - chunks = list(a.chunks) - chunks[axis] = (pad_size,) - fill_array = da.full(shape, fill_value, dtype=a.dtype, chunks=chunks) - a = da.concatenate([fill_array, a], axis=axis) - + offset = [0] * a.ndim + drop_size = [0] * a.ndim + pad_size = [0] * a.ndim + for ax, win, cent in zip(axis, window, center): + if ax < 0: + ax = a.ndim + ax + depth[ax] = int(win / 2) + # For evenly sized window, we need to crop the first point of each block. + offset[ax] = 1 if win % 2 == 0 else 0 + + if depth[ax] > min(a.chunks[ax]): + raise ValueError( + "For window size %d, every chunk should be larger than %d, " + "but the smallest chunk size is %d. Rechunk your array\n" + "with a larger chunk size or a chunk size that\n" + "more evenly divides the shape of your array." + % (win, depth[ax], min(a.chunks[ax])) + ) + + # Although da.overlap pads values to boundaries of the array, + # the size of the generated array is smaller than what we want + # if center == False. + if cent: + start = int(win / 2) # 10 -> 5, 9 -> 4 + end = win - 1 - start + else: + start, end = win - 1, 0 + pad_size[ax] = max(start, end) + offset[ax] - depth[ax] + drop_size[ax] = 0 + # pad_size becomes more than 0 when the overlapped array is smaller than + # needed. In this case, we need to enlarge the original array by padding + # before overlapping. + if pad_size[ax] > 0: + if pad_size[ax] < depth[ax]: + # overlapping requires each chunk larger than depth. If pad_size is + # smaller than the depth, we enlarge this and truncate it later. + drop_size[ax] = depth[ax] - pad_size[ax] + pad_size[ax] = depth[ax] + + # TODO maybe following two lines can be summarized. + a = da.pad( + a, [(p, 0) for p in pad_size], mode="constant", constant_values=fill_value + ) boundary = {d: fill_value for d in range(a.ndim)} # create overlap arrays ag = da.overlap.overlap(a, depth=depth, boundary=boundary) - # apply rolling func - def func(x, window, axis=-1): + def func(x, window, axis): x = np.asarray(x) - rolling = nputils._rolling_window(x, window, axis) - return rolling[(slice(None),) * axis + (slice(offset, None),)] - - chunks = list(a.chunks) - chunks.append(window) - out = ag.map_blocks( - func, dtype=a.dtype, new_axis=a.ndim, chunks=chunks, window=window, axis=axis + index = [slice(None)] * x.ndim + for ax, win in zip(axis, window): + x = nputils._rolling_window(x, win, ax) + index[ax] = slice(offset[ax], None) + return x[tuple(index)] + + chunks = list(a.chunks) + window + new_axis = [a.ndim + i for i in range(len(axis))] + out = da.map_blocks( + func, + ag, + dtype=a.dtype, + new_axis=new_axis, + chunks=chunks, + window=window, + axis=axis, ) # crop boundary. 
- index = (slice(None),) * axis + (slice(drop_size, drop_size + orig_shape[axis]),) - return out[index] + index = [slice(None)] * a.ndim + for ax in axis: + index[ax] = slice(drop_size[ax], drop_size[ax] + orig_shape[ax]) + return out[tuple(index)] def least_squares(lhs, rhs, rcond=None, skipna=False): @@ -120,5 +135,7 @@ def least_squares(lhs, rhs, rcond=None, skipna=False): coeffs = coeffs.reshape(coeffs.shape[0]) residuals = residuals.reshape(residuals.shape[0]) else: + # Residuals here are (1, 1) but should be (K,) as rhs is (N, K) + # See issue dask/dask#6516 coeffs, residuals, _, _ = da.linalg.lstsq(lhs_da, rhs) return coeffs, residuals diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index c9e93dec26e..db01121b9da 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -46,7 +46,6 @@ from .common import AbstractArray, DataWithCoords from .coordinates import ( DataArrayCoordinates, - LevelCoordinatesSource, assert_coordinate_consistent, remap_label_indexers, ) @@ -54,9 +53,15 @@ from .formatting import format_item from .indexes import Indexes, default_indexes, propagate_indexes from .indexing import is_fancy_indexer -from .merge import PANDAS_TYPES, _extract_indexes_from_coords -from .options import OPTIONS -from .utils import Default, ReprObject, _check_inplace, _default, either_dict_or_kwargs +from .merge import PANDAS_TYPES, MergeError, _extract_indexes_from_coords +from .options import OPTIONS, _get_keep_attrs +from .utils import ( + Default, + HybridMappingProxy, + ReprObject, + _default, + either_dict_or_kwargs, +) from .variable import ( IndexVariable, Variable, @@ -196,7 +201,7 @@ def __getitem__(self, key) -> "DataArray": # expand the indexer so we can handle Ellipsis labels = indexing.expanded_indexer(key, self.data_array.ndim) key = dict(zip(self.data_array.dims, labels)) - return self.data_array.sel(**key) + return self.data_array.sel(key) def __setitem__(self, key, value) -> None: if not utils.is_dict_like(key): @@ -216,27 +221,125 @@ def __setitem__(self, key, value) -> None: class DataArray(AbstractArray, DataWithCoords): """N-dimensional array with labeled coordinates and dimensions. - DataArray provides a wrapper around numpy ndarrays that uses labeled - dimensions and coordinates to support metadata aware operations. The API is - similar to that for the pandas Series or DataFrame, but DataArray objects - can have any number of dimensions, and their contents have fixed data - types. + DataArray provides a wrapper around numpy ndarrays that uses + labeled dimensions and coordinates to support metadata aware + operations. The API is similar to that for the pandas Series or + DataFrame, but DataArray objects can have any number of dimensions, + and their contents have fixed data types. Additional features over raw numpy arrays: - Apply operations over dimensions by name: ``x.sum('time')``. - - Select or assign values by integer location (like numpy): ``x[:10]`` - or by label (like pandas): ``x.loc['2014-01-01']`` or + - Select or assign values by integer location (like numpy): + ``x[:10]`` or by label (like pandas): ``x.loc['2014-01-01']`` or ``x.sel(time='2014-01-01')``. - - Mathematical operations (e.g., ``x - y``) vectorize across multiple - dimensions (known in numpy as "broadcasting") based on dimension names, - regardless of their original order. 
- - Keep track of arbitrary metadata in the form of a Python dictionary: - ``x.attrs`` + - Mathematical operations (e.g., ``x - y``) vectorize across + multiple dimensions (known in numpy as "broadcasting") based on + dimension names, regardless of their original order. + - Keep track of arbitrary metadata in the form of a Python + dictionary: ``x.attrs`` - Convert to a pandas Series: ``x.to_series()``. - Getting items from or doing mathematical operations with a DataArray - always returns another DataArray. + Getting items from or doing mathematical operations with a + DataArray always returns another DataArray. + + Parameters + ---------- + data : array_like + Values for this array. Must be an ``numpy.ndarray``, ndarray + like, or castable to an ``ndarray``. If a self-described xarray + or pandas object, attempts are made to use this array's + metadata to fill in other unspecified arguments. A view of the + array's data is used instead of a copy if possible. + coords : sequence or dict of array_like, optional + Coordinates (tick labels) to use for indexing along each + dimension. The following notations are accepted: + + - mapping {dimension name: array-like} + - sequence of tuples that are valid arguments for + ``xarray.Variable()`` + - (dims, data) + - (dims, data, attrs) + - (dims, data, attrs, encoding) + + Additionally, it is possible to define a coord whose name + does not match the dimension name, or a coord based on multiple + dimensions, with one of the following notations: + + - mapping {coord name: DataArray} + - mapping {coord name: Variable} + - mapping {coord name: (dimension name, array-like)} + - mapping {coord name: (tuple of dimension names, array-like)} + + dims : hashable or sequence of hashable, optional + Name(s) of the data dimension(s). Must be either a hashable + (only for 1D data) or a sequence of hashables with length equal + to the number of dimensions. If this argument is omitted, + dimension names default to ``['dim_0', ... 'dim_n']``. + name : str or None, optional + Name of this array. + attrs : dict_like or None, optional + Attributes to assign to the new instance. By default, an empty + attribute dictionary is initialized. + + Examples + -------- + Create data: + + >>> np.random.seed(0) + >>> temperature = 15 + 8 * np.random.randn(2, 2, 3) + >>> precipitation = 10 * np.random.rand(2, 2, 3) + >>> lon = [[-99.83, -99.32], [-99.79, -99.23]] + >>> lat = [[42.25, 42.21], [42.63, 42.59]] + >>> time = pd.date_range("2014-09-06", periods=3) + >>> reference_time = pd.Timestamp("2014-09-05") + + Initialize a dataarray with multiple dimensions: + + >>> da = xr.DataArray( + ... data=temperature, + ... dims=["x", "y", "time"], + ... coords=dict( + ... lon=(["x", "y"], lon), + ... lat=(["x", "y"], lat), + ... time=time, + ... reference_time=reference_time, + ... ), + ... attrs=dict( + ... description="Ambient temperature.", + ... units="degC", + ... ), + ... ) + >>> da + + array([[[29.11241877, 18.20125767, 22.82990387], + [32.92714559, 29.94046392, 7.18177696]], + + [[22.60070734, 13.78914233, 14.17424919], + [18.28478802, 16.15234857, 26.63418806]]]) + Coordinates: + lon (x, y) float64 -99.83 -99.32 -99.79 -99.23 + lat (x, y) float64 42.25 42.21 42.63 42.59 + * time (time) datetime64[ns] 2014-09-06 2014-09-07 2014-09-08 + reference_time datetime64[ns] 2014-09-05 + Dimensions without coordinates: x, y + Attributes: + description: Ambient temperature. 
+ units: degC + + Find out where the coldest temperature was: + + >>> da.isel(da.argmin(...)) + + array(7.18177696) + Coordinates: + lon float64 -99.32 + lat float64 42.21 + time datetime64[ns] 2014-09-08 + reference_time datetime64[ns] 2014-09-05 + Attributes: + description: Ambient temperature. + units: degC """ _cache: Dict[str, Any] @@ -261,7 +364,7 @@ class DataArray(AbstractArray, DataWithCoords): _resample_cls = resample.DataArrayResample _weighted_cls = weighted.DataArrayWeighted - dt = property(CombinedDatetimelikeAccessor) + dt = utils.UncachedAccessor(CombinedDatetimelikeAccessor) def __init__( self, @@ -274,45 +377,6 @@ def __init__( indexes: Dict[Hashable, pd.Index] = None, fastpath: bool = False, ): - """ - Parameters - ---------- - data : array_like - Values for this array. Must be an ``numpy.ndarray``, ndarray like, - or castable to an ``ndarray``. If a self-described xarray or pandas - object, attempts are made to use this array's metadata to fill in - other unspecified arguments. A view of the array's data is used - instead of a copy if possible. - coords : sequence or dict of array_like objects, optional - Coordinates (tick labels) to use for indexing along each dimension. - The following notations are accepted: - - - mapping {dimension name: array-like} - - sequence of tuples that are valid arguments for xarray.Variable() - - (dims, data) - - (dims, data, attrs) - - (dims, data, attrs, encoding) - - Additionally, it is possible to define a coord whose name - does not match the dimension name, or a coord based on multiple - dimensions, with one of the following notations: - - - mapping {coord name: DataArray} - - mapping {coord name: Variable} - - mapping {coord name: (dimension name, array-like)} - - mapping {coord name: (tuple of dimension names, array-like)} - - dims : hashable or sequence of hashable, optional - Name(s) of the data dimension(s). Must be either a hashable (only - for 1D data) or a sequence of hashables with length equal to the - number of dimensions. If this argument is omitted, dimension names - default to ``['dim_0', ... 'dim_n']``. - name : str or None, optional - Name of this array. - attrs : dict_like or None, optional - Attributes to assign to the new instance. By default, an empty - attribute dictionary is initialized. - """ if fastpath: variable = data assert dims is None @@ -423,7 +487,7 @@ def _to_temp_dataset(self) -> Dataset: return self._to_dataset_whole(name=_THIS_ARRAY, shallow_copy=False) def _from_temp_dataset( - self, dataset: Dataset, name: Hashable = _default + self, dataset: Dataset, name: Union[Hashable, None, Default] = _default ) -> "DataArray": variable = dataset._variables.pop(_THIS_ARRAY) coords = dataset._variables @@ -441,7 +505,7 @@ def subset(dim, label): variables = {label: subset(dim, label) for label in self.get_index(dim)} variables.update({k: v for k, v in self._coords.items() if k != dim}) indexes = propagate_indexes(self._indexes, exclude=dim) - coord_names = set(self._coords) - set([dim]) + coord_names = set(self._coords) - {dim} dataset = Dataset._construct_direct( variables, coord_names, indexes=indexes, attrs=self.attrs ) @@ -493,7 +557,7 @@ def to_dataset( name : hashable, optional Name to substitute for this array's name. Only valid if ``dim`` is not provided. - promote_attrs : bool, default False + promote_attrs : bool, default: False Set to True to shallow copy attrs of DataArray to returned Dataset. 
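A small sketch of the ``promote_attrs`` flag described just above; the array and attribute names are invented for illustration.

import xarray as xr

da = xr.DataArray([1, 2], dims="x", name="a", attrs={"units": "m"})

# By default only the variable keeps its attrs; the Dataset's own attrs stay empty.
print(da.to_dataset().attrs)                    # expected: {}
# promote_attrs=True shallow-copies the DataArray attrs onto the Dataset.
print(da.to_dataset(promote_attrs=True).attrs)  # expected: {'units': 'm'}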
Returns @@ -519,8 +583,7 @@ def to_dataset( @property def name(self) -> Optional[Hashable]: - """The name of this array. - """ + """The name of this array.""" return self._name @name.setter @@ -557,8 +620,7 @@ def __len__(self) -> int: @property def data(self) -> Any: - """The array's data as a dask or numpy array - """ + """The array's data as a dask or numpy array""" return self.variable.data @data.setter @@ -664,28 +726,27 @@ def __delitem__(self, key: Any) -> None: del self.coords[key] @property - def _attr_sources(self) -> List[Mapping[Hashable, Any]]: - """List of places to look-up items for attribute-style access - """ - return self._item_sources + [self.attrs] + def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]: + """Places to look-up items for attribute-style access""" + yield from self._item_sources + yield self.attrs @property - def _item_sources(self) -> List[Mapping[Hashable, Any]]: - """List of places to look-up items for key-completion - """ - return [ - self.coords, - {d: self.coords[d] for d in self.dims}, - LevelCoordinatesSource(self), - ] + def _item_sources(self) -> Iterable[Mapping[Hashable, Any]]: + """Places to look-up items for key-completion""" + yield HybridMappingProxy(keys=self._coords, mapping=self.coords) + + # virtual coordinates + # uses empty dict -- everything here can already be found in self.coords. + yield HybridMappingProxy(keys=self.dims, mapping={}) + yield HybridMappingProxy(keys=self._level_coords, mapping={}) def __contains__(self, key: Any) -> bool: return key in self.data @property def loc(self) -> _LocIndexer: - """Attribute for location based indexing like pandas. - """ + """Attribute for location based indexing like pandas.""" return _LocIndexer(self) @property @@ -710,29 +771,26 @@ def encoding(self, value: Mapping[Hashable, Any]) -> None: @property def indexes(self) -> Indexes: - """Mapping of pandas.Index objects used for label based indexing - """ + """Mapping of pandas.Index objects used for label based indexing""" if self._indexes is None: self._indexes = default_indexes(self._coords, self.dims) return Indexes(self._indexes) @property def coords(self) -> DataArrayCoordinates: - """Dictionary-like container of coordinate arrays. - """ + """Dictionary-like container of coordinate arrays.""" return DataArrayCoordinates(self) def reset_coords( self, names: Union[Iterable[Hashable], Hashable, None] = None, drop: bool = False, - inplace: bool = None, ) -> Union[None, "DataArray", Dataset]: """Given names of coordinates, reset them to become variables. Parameters ---------- - names : hashable or iterable of hashables, optional + names : hashable or iterable of hashable, optional Name(s) of non-index coordinates in this dataset to reset into variables. By default, all non-index coordinates are reset. drop : bool, optional @@ -743,7 +801,6 @@ def reset_coords( ------- Dataset, or DataArray if ``drop == True`` """ - _check_inplace(inplace) if names is None: names = set(self.coords) - set(self.dims) dataset = self.coords.to_dataset().reset_coords(names, drop) @@ -806,11 +863,11 @@ def load(self, **kwargs) -> "DataArray": Parameters ---------- **kwargs : dict - Additional keyword arguments passed on to ``dask.array.compute``. + Additional keyword arguments passed on to ``dask.compute``. 
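To make the ``reset_coords`` behaviour above concrete, a quick sketch follows; the coordinate name ``scale`` is made up for illustration.

import xarray as xr

da = xr.DataArray(
    [1, 2],
    dims="x",
    coords={"x": [10, 20], "scale": 2.5},
    name="a",
)

# Resetting a non-index coordinate returns a Dataset with it as a data variable ...
ds = da.reset_coords("scale")
print(sorted(ds.data_vars))                                   # expected: ['a', 'scale']

# ... while drop=True keeps a DataArray and simply removes the coordinate.
print("scale" in da.reset_coords("scale", drop=True).coords)  # expected: False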
See Also -------- - dask.array.compute + dask.compute """ ds = self._to_temp_dataset().load(**kwargs) new = self._from_temp_dataset(ds) @@ -831,17 +888,17 @@ def compute(self, **kwargs) -> "DataArray": Parameters ---------- **kwargs : dict - Additional keyword arguments passed on to ``dask.array.compute``. + Additional keyword arguments passed on to ``dask.compute``. See Also -------- - dask.array.compute + dask.compute """ new = self.copy(deep=False) return new.load(**kwargs) def persist(self, **kwargs) -> "DataArray": - """ Trigger computation in constituent dask arrays + """Trigger computation in constituent dask arrays This keeps them as dask arrays but encourages them to keep data in memory. This is particularly useful when on a distributed machine. @@ -863,8 +920,8 @@ def copy(self, deep: bool = True, data: Any = None) -> "DataArray": """Returns a copy of this array. If `deep=True`, a deep copy is made of the data array. - Otherwise, a shallow copy is made, so each variable in the new - array's dataset is also a variable in this array's dataset. + Otherwise, a shallow copy is made, and the returned data array's + values are a new view of this data array's values. Use `data` to create a new object with the same structure as original but entirely new data. @@ -895,19 +952,19 @@ def copy(self, deep: bool = True, data: Any = None) -> "DataArray": array([1, 2, 3]) Coordinates: - * x (x) >> array_0 = array.copy(deep=False) >>> array_0[0] = 7 >>> array_0 array([7, 2, 3]) Coordinates: - * x (x) >> array array([7, 2, 3]) Coordinates: - * x (x) "DataArray": >>> array.copy(data=[0.1, 0.2, 0.3]) - array([ 0.1, 0.2, 0.3]) + array([0.1, 0.2, 0.3]) Coordinates: - * x (x) >> array - array([1, 2, 3]) + array([7, 2, 3]) Coordinates: - * x (x) Optional[Tuple[Tuple[int, ...], ...]]: def chunk( self, chunks: Union[ - None, Number, Tuple[Number, ...], Tuple[Tuple[Number, ...], ...], Mapping[Hashable, Union[None, Number, Tuple[Number, ...]]], - ] = None, + ] = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) name_prefix: str = "xarray-", token: str = None, lock: bool = False, @@ -980,7 +1036,7 @@ def chunk( Parameters ---------- - chunks : int, tuple or mapping, optional + chunks : int, tuple of int or mapping of hashable to int, optional Chunk sizes along each dimension, e.g., ``5``, ``(5, 5)`` or ``{'x': 5, 'y': 5}``. name_prefix : str, optional @@ -1025,10 +1081,10 @@ def isel( drop : bool, optional If ``drop=True``, drop coordinates variables indexed by integers instead of making them scalar. - missing_dims : {"raise", "warn", "ignore"}, default "raise" + missing_dims : {"raise", "warn", "ignore"}, default: "raise" What to do if dimensions that should be selected from are not present in the DataArray: - - "exception": raise an exception + - "raise": raise an exception - "warning": raise a warning, and ignore the missing dimensions - "ignore": ignore the missing dimensions **indexers_kwargs : {dim: indexer, ...}, optional @@ -1077,6 +1133,19 @@ def sel( """Return a new DataArray whose data is given by selecting index labels along the specified dimension(s). + In contrast to `DataArray.isel`, indexers for this method should use + labels instead of integers. + + Under the hood, this method is powered by using pandas's powerful Index + objects. This makes label based indexing essentially just as fast as + using integer indexing. + + It also means this method uses pandas's (well documented) logic for + indexing. 
This means you can use string shortcuts for datetime indexes + (e.g., '2000-01' to select all values in January 2000). It also means + that slices are treated as inclusive of both the start and stop values, + unlike normal Python indexing. + .. warning:: Do not try to assign values when using any of the indexing methods @@ -1089,6 +1158,45 @@ def sel( Assigning values with the chained indexing using ``.sel`` or ``.isel`` fails silently. + Parameters + ---------- + indexers : dict, optional + A dict with keys matching dimensions and values given + by scalars, slices or arrays of tick labels. For dimensions with + multi-index, the indexer may also be a dict-like object with keys + matching index level names. + If DataArrays are passed as indexers, xarray-style indexing will be + carried out. See :ref:`indexing` for the details. + One of indexers or indexers_kwargs must be provided. + method : {None, "nearest", "pad", "ffill", "backfill", "bfill"}, optional + Method to use for inexact matches: + + * None (default): only exact matches + * pad / ffill: propagate last valid index value forward + * backfill / bfill: propagate next valid index value backward + * nearest: use nearest valid index value + tolerance : optional + Maximum distance between original and new labels for inexact + matches. The values of the index at the matching locations must + satisfy the equation ``abs(index[indexer] - target) <= tolerance``. + drop : bool, optional + If ``drop=True``, drop coordinates variables in `indexers` instead + of making them scalar. + **indexers_kwargs : {dim: indexer, ...}, optional + The keyword arguments form of ``indexers``. + One of indexers or indexers_kwargs must be provided. + + Returns + ------- + obj : DataArray + A new DataArray with the same contents as this DataArray, except the + data and each dimension is indexed by the appropriate indexers. + If indexer DataArrays have coordinates that do not conflict with + this object, then these coordinates will be attached. + In general, each array's data will be a view of the array's data + in this DataArray, unless vectorized indexing was triggered by using + an array indexer, in which case the data will be a copy. + See Also -------- Dataset.sel @@ -1180,34 +1288,45 @@ def broadcast_like( Returns ------- - new_da: xr.DataArray + new_da : DataArray + The caller broadcasted against ``other``. Examples -------- + >>> arr1 = xr.DataArray( + ... np.random.randn(2, 3), + ... dims=("x", "y"), + ... coords={"x": ["a", "b"], "y": ["a", "b", "c"]}, + ... ) + >>> arr2 = xr.DataArray( + ... np.random.randn(3, 2), + ... dims=("x", "y"), + ... coords={"x": ["a", "b", "c"], "y": ["a", "b"]}, + ... ) >>> arr1 - array([[0.840235, 0.215216, 0.77917 ], - [0.726351, 0.543824, 0.875115]]) + array([[ 1.76405235, 0.40015721, 0.97873798], + [ 2.2408932 , 1.86755799, -0.97727788]]) Coordinates: * x (x) >> arr2 - array([[0.612611, 0.125753], - [0.853181, 0.948818], - [0.180885, 0.33363 ]]) + array([[ 0.95008842, -0.15135721], + [-0.10321885, 0.4105985 ], + [ 0.14404357, 1.45427351]]) Coordinates: * x (x) >> arr1.broadcast_like(arr2) - array([[0.840235, 0.215216, 0.77917 ], - [0.726351, 0.543824, 0.875115], - [ nan, nan, nan]]) + array([[ 1.76405235, 0.40015721, 0.97873798], + [ 2.2408932 , 1.86755799, -0.97727788], + [ nan, nan, nan]]) Coordinates: - * x (x) object 'a' 'b' 'c' - * y (y) object 'a' 'b' 'c' + * x (x) "DataArray": - """ Multidimensional interpolation of variables. + """Multidimensional interpolation of variables. 
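The pandas-backed label semantics spelled out in the ``sel`` docstring above (partial datetime strings, endpoint-inclusive slices) can be sketched as follows; the data is invented for illustration.

import numpy as np
import pandas as pd
import xarray as xr

da = xr.DataArray(
    np.arange(366),
    dims="time",
    coords={"time": pd.date_range("2000-01-01", periods=366)},
)

# Partial datetime strings select whole periods ...
print(da.sel(time="2000-01").sizes["time"])                          # expected: 31
# ... and label slices include both endpoints, unlike integer slicing.
print(da.sel(time=slice("2000-01-01", "2000-01-07")).sizes["time"])  # expected: 7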
+ Parameters + ---------- coords : dict, optional Mapping from dimension names to the new coordinates. - new coordinate can be an scalar, array-like or DataArray. - If DataArrays are passed as new coordates, their dimensions are - used for the broadcasting. - method: {'linear', 'nearest'} for multidimensional array, - {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} - for 1-dimensional array. - assume_sorted: boolean, optional + New coordinate can be an scalar, array-like or DataArray. + If DataArrays are passed as new coordinates, their dimensions are + used for the broadcasting. Missing values are skipped. + method : str, default: "linear" + The method used to interpolate. Choose from + + - {"linear", "nearest"} for multidimensional array, + - {"linear", "nearest", "zero", "slinear", "quadratic", "cubic"} for 1-dimensional array. + assume_sorted : bool, optional If False, values of x can be in any order and they are sorted first. If True, x has to be an array of monotonically increasing values. - kwargs: dictionary + kwargs : dict Additional keyword arguments passed to scipy's interpolator. Valid options and their behavior depend on if 1-dimensional or multi-dimensional interpolation is used. - ``**coords_kwargs`` : {dim: coordinate, ...}, optional + **coords_kwargs : {dim: coordinate, ...}, optional The keyword arguments form of ``coords``. One of coords or coords_kwargs must be provided. Returns ------- - interpolated: xr.DataArray + interpolated : DataArray New dataarray on the new coordinates. Notes @@ -1389,12 +1523,71 @@ def interp( Examples -------- - >>> da = xr.DataArray([1, 3], [("x", np.arange(2))]) - >>> da.interp(x=0.5) - - array(2.0) + >>> da = xr.DataArray( + ... data=[[1, 4, 2, 9], [2, 7, 6, np.nan], [6, np.nan, 5, 8]], + ... dims=("x", "y"), + ... coords={"x": [0, 1, 2], "y": [10, 12, 14, 16]}, + ... ) + >>> da + + array([[ 1., 4., 2., 9.], + [ 2., 7., 6., nan], + [ 6., nan, 5., 8.]]) + Coordinates: + * x (x) int64 0 1 2 + * y (y) int64 10 12 14 16 + + 1D linear interpolation (the default): + + >>> da.interp(x=[0, 0.75, 1.25, 1.75]) + + array([[1. , 4. , 2. , nan], + [1.75, 6.25, 5. , nan], + [3. , nan, 5.75, nan], + [5. , nan, 5.25, nan]]) + Coordinates: + * y (y) int64 10 12 14 16 + * x (x) float64 0.0 0.75 1.25 1.75 + + 1D nearest interpolation: + + >>> da.interp(x=[0, 0.75, 1.25, 1.75], method="nearest") + + array([[ 1., 4., 2., 9.], + [ 2., 7., 6., nan], + [ 2., 7., 6., nan], + [ 6., nan, 5., 8.]]) + Coordinates: + * y (y) int64 10 12 14 16 + * x (x) float64 0.0 0.75 1.25 1.75 + + 1D linear extrapolation: + + >>> da.interp( + ... x=[1, 1.5, 2.5, 3.5], + ... method="linear", + ... kwargs={"fill_value": "extrapolate"}, + ... ) + + array([[ 2. , 7. , 6. , nan], + [ 4. , nan, 5.5, nan], + [ 8. , nan, 4.5, nan], + [12. , nan, 3.5, nan]]) Coordinates: - x float64 0.5 + * y (y) int64 10 12 14 16 + * x (x) float64 1.0 1.5 2.5 3.5 + + 2D linear interpolation: + + >>> da.interp(x=[0, 0.75, 1.25, 1.75], y=[11, 13, 15], method="linear") + + array([[2.5 , 3. , nan], + [4. , 5.625, nan], + [ nan, nan, nan], + [ nan, nan, nan]]) + Coordinates: + * x (x) float64 0.0 0.75 1.25 1.75 + * y (y) int64 11 13 15 """ if self.dtype.kind not in "uifc": raise TypeError( @@ -1425,22 +1618,23 @@ def interp_like( other : Dataset or DataArray Object with an 'indexes' attribute giving a mapping from dimension names to an 1d array-like, which provides coordinates upon - which to index the variables in this dataset. - method: string, optional. 
- {'linear', 'nearest'} for multidimensional array, - {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} - for 1-dimensional array. 'linear' is used by default. - assume_sorted: boolean, optional + which to index the variables in this dataset. Missing values are skipped. + method : str, default: "linear" + The method used to interpolate. Choose from + + - {"linear", "nearest"} for multidimensional array, + - {"linear", "nearest", "zero", "slinear", "quadratic", "cubic"} for 1-dimensional array. + assume_sorted : bool, optional If False, values of coordinates that are interpolated over can be in any order and they are sorted first. If True, interpolated coordinates are assumed to be an array of monotonically increasing values. - kwargs: dictionary, optional + kwargs : dict, optional Additional keyword passed to scipy's interpolator. Returns ------- - interpolated: xr.DataArray + interpolated : DataArray Another dataarray by interpolating this dataarray's data along the coordinates of the other object. @@ -1478,7 +1672,7 @@ def rename( If the argument is dict-like, it used as a mapping from old names to new names for coordinates. Otherwise, use the argument as the new name for this array. - **names: hashable, optional + **names : hashable, optional The keyword arguments form of a mapping from old names to new names for coordinates. One of new_name_or_name_dict or names must be provided. @@ -1522,7 +1716,9 @@ def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "DataArray": -------- >>> arr = xr.DataArray( - ... data=[0, 1], dims="x", coords={"x": ["a", "b"], "y": ("x", [0, 1])}, + ... data=[0, 1], + ... dims="x", + ... coords={"x": ["a", "b"], "y": ("x", [0, 1])}, ... ) >>> arr @@ -1571,20 +1767,20 @@ def expand_dims( Parameters ---------- - dim : hashable, sequence of hashable, dict, or None + dim : hashable, sequence of hashable, dict, or None, optional Dimensions to include on the new variable. If provided as str or sequence of str, then dimensions are inserted with length 1. If provided as a dict, then the keys are the new dimensions and the values are either integers (giving the length of the new dimensions) or sequence/ndarray (giving the coordinates of the new dimensions). - axis : integer, list (or tuple) of integers, or None + axis : int, list of int or tuple of int, or None, default: None Axis position(s) where new axis is to be inserted (position(s) on the result array). If a list (or tuple) of integers is passed, multiple axes are inserted. In this case, dim arguments should be same length list. If axis=None is passed, all the axes will be inserted to the start of the result array. - **dim_kwargs : int or sequence/ndarray + **dim_kwargs : int or sequence or ndarray The keywords are arbitrary dimensions being inserted and the values are either the lengths of the new dims (if int is given), or their coordinates. Note, this is an alternative to passing a dict to the @@ -1612,7 +1808,6 @@ def set_index( self, indexes: Mapping[Hashable, Union[Hashable, Sequence[Hashable]]] = None, append: bool = False, - inplace: bool = None, **indexes_kwargs: Union[Hashable, Sequence[Hashable]], ) -> Optional["DataArray"]: """Set DataArray (multi-)indexes using one or more existing @@ -1627,7 +1822,7 @@ def set_index( append : bool, optional If True, append the supplied index(es) to the existing index(es). Otherwise replace the existing index(es) (default). - **indexes_kwargs: optional + **indexes_kwargs : optional The keyword arguments form of ``indexes``. 
One of indexes or indexes_kwargs must be provided. @@ -1663,22 +1858,19 @@ def set_index( -------- DataArray.reset_index """ - ds = self._to_temp_dataset().set_index( - indexes, append=append, inplace=inplace, **indexes_kwargs - ) + ds = self._to_temp_dataset().set_index(indexes, append=append, **indexes_kwargs) return self._from_temp_dataset(ds) def reset_index( self, dims_or_levels: Union[Hashable, Sequence[Hashable]], drop: bool = False, - inplace: bool = None, ) -> Optional["DataArray"]: """Reset the specified index(es) or multi-index level(s). Parameters ---------- - dims_or_levels : hashable or sequence of hashables + dims_or_levels : hashable or sequence of hashable Name(s) of the dimension(s) and/or multi-index level(s) that will be reset. drop : bool, optional @@ -1695,7 +1887,6 @@ def reset_index( -------- DataArray.set_index """ - _check_inplace(inplace) coords, _ = split_indexes( dims_or_levels, self._coords, set(), self._level_coords, drop=drop ) @@ -1704,7 +1895,6 @@ def reset_index( def reorder_levels( self, dim_order: Mapping[Hashable, Sequence[int]] = None, - inplace: bool = None, **dim_order_kwargs: Sequence[int], ) -> "DataArray": """Rearrange index levels using input order. @@ -1715,7 +1905,7 @@ def reorder_levels( Mapping from names matching dimensions and values given by lists representing new level orders. Every given dimension must have a multi-index. - **dim_order_kwargs: optional + **dim_order_kwargs : optional The keyword arguments form of ``dim_order``. One of dim_order or dim_order_kwargs must be provided. @@ -1725,7 +1915,6 @@ def reorder_levels( Another dataarray, with this dataarray's data but replaced coordinates. """ - _check_inplace(inplace) dim_order = either_dict_or_kwargs(dim_order, dim_order_kwargs, "reorder_levels") replace_coords = {} for dim, order in dim_order.items(): @@ -1751,12 +1940,13 @@ def stack( Parameters ---------- - dimensions : Mapping of the form new_name=(dim1, dim2, ...) + dimensions : mapping of hashable to sequence of hashable + Mapping of the form `new_name=(dim1, dim2, ...)`. Names of new dimensions, and the existing dimensions that they replace. An ellipsis (`...`) will be replaced by all unlisted dimensions. Passing a list containing an ellipsis (`stacked_dim=[...]`) will stack over all dimensions. - **dimensions_kwargs: + **dimensions_kwargs The keyword arguments form of ``dimensions``. One of dimensions or dimensions_kwargs must be provided. @@ -1777,12 +1967,16 @@ def stack( array([[0, 1, 2], [3, 4, 5]]) Coordinates: - * x (x) |S1 'a' 'b' + * x (x) >> stacked = arr.stack(z=("x", "y")) >>> stacked.indexes["z"] - MultiIndex(levels=[['a', 'b'], [0, 1, 2]], - codes=[[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]], + MultiIndex([('a', 0), + ('a', 1), + ('a', 2), + ('b', 0), + ('b', 1), + ('b', 2)], names=['x', 'y']) See Also @@ -1809,8 +2003,13 @@ def unstack( dim : hashable or sequence of hashable, optional Dimension(s) over which to unstack. By default unstacks all MultiIndexes. - fill_value: value to be filled. By default, np.nan - sparse: use sparse-array if True + fill_value : scalar or dict-like, default: nan + value to be filled. If a dict-like, maps variable names to + fill values. Use the data array's name to refer to its + name. If not provided or if the dict-like does not contain + all variables, the dtype's NA value will be used. 
+ sparse : bool, default: False + use sparse-array if True Returns ------- @@ -1829,12 +2028,16 @@ def unstack( array([[0, 1, 2], [3, 4, 5]]) Coordinates: - * x (x) |S1 'a' 'b' + * x (x) >> stacked = arr.stack(z=("x", "y")) >>> stacked.indexes["z"] - MultiIndex(levels=[['a', 'b'], [0, 1, 2]], - codes=[[0, 0, 0, 1, 1, 1], [0, 1, 2, 0, 1, 2]], + MultiIndex([('a', 0), + ('a', 1), + ('a', 2), + ('b', 0), + ('b', 1), + ('b', 2)], names=['x', 'y']) >>> roundtripped = stacked.unstack() >>> arr.identical(roundtripped) @@ -1860,7 +2063,7 @@ def to_unstacked_dataset(self, dim, level=0): level : int or str The MultiIndex level to expand to a dataset along. Can either be the integer index of the level or its name. - label : int, default 0 + label : int, default: 0 Label of the level to expand dataset along. Overrides the label argument if given. @@ -1885,11 +2088,13 @@ def to_unstacked_dataset(self, dim, level=0): Data variables: a (x, y) int64 0 1 2 3 4 5 b (x) int64 0 3 - >>> stacked = data.to_stacked_array("z", ["y"]) + >>> stacked = data.to_stacked_array("z", ["x"]) >>> stacked.indexes["z"] - MultiIndex(levels=[['a', 'b'], [0, 1, 2]], - labels=[[0, 0, 0, 1], [0, 1, 2, -1]], - names=['variable', 'y']) + MultiIndex([('a', 0.0), + ('a', 1.0), + ('a', 2.0), + ('b', nan)], + names=['variable', 'y']) >>> roundtripped = stacked.to_unstacked_dataset(dim="z") >>> data.identical(roundtripped) True @@ -1910,12 +2115,17 @@ def to_unstacked_dataset(self, dim, level=0): # pull variables out of datarray data_dict = {} for k in variables: - data_dict[k] = self.sel({variable_dim: k}).squeeze(drop=True) + data_dict[k] = self.sel({variable_dim: k}, drop=True).squeeze(drop=True) # unstacked dataset return Dataset(data_dict) - def transpose(self, *dims: Hashable, transpose_coords: bool = None) -> "DataArray": + def transpose( + self, + *dims: Hashable, + transpose_coords: bool = True, + missing_dims: str = "raise", + ) -> "DataArray": """Return a new DataArray object with transposed dimensions. Parameters @@ -1923,8 +2133,14 @@ def transpose(self, *dims: Hashable, transpose_coords: bool = None) -> "DataArra *dims : hashable, optional By default, reverse the dimensions. Otherwise, reorder the dimensions to this order. - transpose_coords : boolean, optional + transpose_coords : bool, default: True If True, also transpose the coordinates of this DataArray. + missing_dims : {"raise", "warn", "ignore"}, default: "raise" + What to do if dimensions that should be selected from are not present in the + DataArray: + - "raise": raise an exception + - "warning": raise a warning, and ignore the missing dimensions + - "ignore": ignore the missing dimensions Returns ------- @@ -1943,7 +2159,7 @@ def transpose(self, *dims: Hashable, transpose_coords: bool = None) -> "DataArra Dataset.transpose """ if dims: - dims = tuple(utils.infix_dims(dims, self.dims)) + dims = tuple(utils.infix_dims(dims, self.dims, missing_dims)) variable = self.variable.transpose(*dims) if transpose_coords: coords: Dict[Hashable, Variable] = {} @@ -1952,15 +2168,6 @@ def transpose(self, *dims: Hashable, transpose_coords: bool = None) -> "DataArra coords[name] = coord.variable.transpose(*coord_dims) return self._replace(variable, coords) else: - if transpose_coords is None and any(self[c].ndim > 1 for c in self.coords): - warnings.warn( - "This DataArray contains multi-dimensional " - "coordinates. 
In the future, these coordinates " - "will be transposed as well unless you specify " - "transpose_coords=False.", - FutureWarning, - stacklevel=2, - ) return self._replace(variable) @property @@ -1970,13 +2177,13 @@ def T(self) -> "DataArray": def drop_vars( self, names: Union[Hashable, Iterable[Hashable]], *, errors: str = "raise" ) -> "DataArray": - """Drop variables from this DataArray. + """Returns an array with dropped variables. Parameters ---------- - names : hashable or iterable of hashables + names : hashable or iterable of hashable Name(s) of variables to drop. - errors: {'raise', 'ignore'}, optional + errors: {"raise", "ignore"}, optional If 'raise' (default), raises a ValueError error if any of the variable passed are not in the dataset. If 'ignore', any given names that are in the DataArray are dropped and no error is raised. @@ -1984,7 +2191,7 @@ def drop_vars( Returns ------- dropped : Dataset - + New Dataset copied from `self` with variables removed. """ ds = self._to_temp_dataset().drop_vars(names, errors=errors) return self._from_temp_dataset(ds) @@ -2020,9 +2227,9 @@ def drop_sel( Parameters ---------- - labels : Mapping[Hashable, Any] + labels : mapping of hashable to Any Index labels to drop - errors: {'raise', 'ignore'}, optional + errors : {"raise", "ignore"}, optional If 'raise' (default), raises a ValueError error if any of the index labels passed are not in the dataset. If 'ignore', any given labels that are in the @@ -2051,10 +2258,10 @@ def dropna( dim : hashable Dimension along which to drop missing values. Dropping along multiple dimensions simultaneously is not yet supported. - how : {'any', 'all'}, optional + how : {"any", "all"}, optional * any : if any NA values are present, drop that label * all : if all values are NA, drop that label - thresh : int, default None + thresh : int, default: None If supplied, require this many non-NA values. Returns @@ -2121,18 +2328,18 @@ def interpolate_na( - 'barycentric', 'krog', 'pchip', 'spline', 'akima': use their respective :py:class:`scipy.interpolate` classes. - use_coordinate : bool, str, default True + use_coordinate : bool or str, default: True Specifies which index to use as the x values in the interpolation formulated as `y = f(x)`. If False, values are treated as if eqaully-spaced along ``dim``. If True, the IndexVariable `dim` is used. If ``use_coordinate`` is a string, it specifies the name of a coordinate variariable to use as the index. - limit : int, default None + limit : int, default: None Maximum number of consecutive NaNs to fill. Must be greater than 0 or None for no limit. This filling is done regardless of the size of the gap in the data. To only interpolate over gaps less than a given length, see ``max_gap``. - max_gap: int, float, str, pandas.Timedelta, numpy.timedelta64, datetime.timedelta, default None. + max_gap: int, float, str, pandas.Timedelta, numpy.timedelta64, datetime.timedelta, default: None Maximum size of gap, a continuous sequence of NaNs, that will be filled. Use None for no limit. When interpolating along a datetime64 dimension and ``use_coordinate=True``, ``max_gap`` can be one of the following: @@ -2155,7 +2362,7 @@ def interpolate_na( * x (x) int64 0 1 2 3 4 5 6 7 8 The gap lengths are 3-0 = 3; 6-3 = 3; and 8-6 = 2 respectively - keep_attrs : bool, default True + keep_attrs : bool, default: True If True, the dataarray's attributes (`attrs`) will be copied from the original object to the new one. If False, the new object will be returned without attributes. 
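As a concrete illustration of the ``how`` and ``thresh`` options documented for ``dropna`` above (array values are invented):

import numpy as np
import xarray as xr

da = xr.DataArray(
    [[0.0, np.nan, np.nan],
     [1.0, 2.0, np.nan]],
    dims=("x", "y"),
)

print(da.dropna("y", how="any").sizes["y"])   # expected: 1, only the first column has no NaN
print(da.dropna("y", how="all").sizes["y"])   # expected: 2, only the all-NaN column is dropped
print(da.dropna("y", thresh=1).sizes["y"])    # expected: 2, keep labels with at least one non-NaN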
@@ -2171,6 +2378,29 @@ def interpolate_na( -------- numpy.interp scipy.interpolate + + Examples + -------- + >>> da = xr.DataArray( + ... [np.nan, 2, 3, np.nan, 0], dims="x", coords={"x": [0, 1, 2, 3, 4]} + ... ) + >>> da + + array([nan, 2., 3., nan, 0.]) + Coordinates: + * x (x) int64 0 1 2 3 4 + + >>> da.interpolate_na(dim="x", method="linear") + + array([nan, 2. , 3. , 1.5, 0. ]) + Coordinates: + * x (x) int64 0 1 2 3 4 + + >>> da.interpolate_na(dim="x", method="linear", fill_value="extrapolate") + + array([1. , 2. , 3. , 1.5, 0. ]) + Coordinates: + * x (x) int64 0 1 2 3 4 """ from .missing import interp_na @@ -2195,7 +2425,7 @@ def ffill(self, dim: Hashable, limit: int = None) -> "DataArray": dim : hashable Specifies the dimension along which to propagate values when filling. - limit : int, default None + limit : int, default: None The maximum number of consecutive NaN values to forward fill. In other words, if there is a gap with more than this number of consecutive NaNs, it will only be partially filled. Must be greater @@ -2219,7 +2449,7 @@ def bfill(self, dim: Hashable, limit: int = None) -> "DataArray": dim : str Specifies the dimension along which to propagate values when filling. - limit : int, default None + limit : int, default: None The maximum number of consecutive NaN values to backward fill. In other words, if there is a gap with more than this number of consecutive NaNs, it will only be partially filled. Must be greater @@ -2264,11 +2494,11 @@ def reduce( Parameters ---------- - func : function + func : callable Function which can be called in the form `f(x, axis=axis, **kwargs)` to return the result of reducing an np.ndarray over an integer valued axis. - dim : hashable or sequence of hashables, optional + dim : hashable or sequence of hashable, optional Dimension(s) over which to apply `func`. axis : int or sequence of int, optional Axis(es) over which to repeatedly apply `func`. Only one of the @@ -2279,7 +2509,7 @@ def reduce( If True, the variable's attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new object will be returned without attributes. - keepdims : bool, default False + keepdims : bool, default: False If True, the dimensions which are reduced are left in the result as dimensions of size one. Coordinates that use these dimensions are removed. @@ -2323,13 +2553,36 @@ def to_pandas(self) -> Union["DataArray", pd.Series, pd.DataFrame]: indexes = [self.get_index(dim) for dim in self.dims] return constructor(self.values, *indexes) - def to_dataframe(self, name: Hashable = None) -> pd.DataFrame: + def to_dataframe( + self, name: Hashable = None, dim_order: List[Hashable] = None + ) -> pd.DataFrame: """Convert this array and its coordinates into a tidy pandas.DataFrame. The DataFrame is indexed by the Cartesian product of index coordinates (in the form of a :py:class:`pandas.MultiIndex`). Other coordinates are included as columns in the DataFrame. + + Parameters + ---------- + name + Name to give to this array (required if unnamed). + dim_order + Hierarchical dimension order for the resulting dataframe. + Array content is transposed to this order and then written out as flat + vectors in contiguous order, so the last dimension in this list + will be contiguous in the resulting DataFrame. This has a major + influence on which operations are efficient on the resulting + dataframe. + + If provided, must include all dimensions of this DataArray. 
By default, + dimensions are sorted according to the DataArray dimensions order. + + Returns + ------- + result + DataArray as a pandas DataFrame. + """ if name is None: name = self.name @@ -2338,8 +2591,9 @@ def to_dataframe(self, name: Hashable = None) -> pd.DataFrame: "cannot convert an unnamed DataArray to a " "DataFrame: use the ``name`` parameter" ) + if self.ndim == 0: + raise ValueError("cannot convert a scalar to a DataFrame") - dims = dict(zip(self.dims, self.shape)) # By using a unique name, we can convert a DataArray into a DataFrame # even if it shares a name with one of its coordinates. # I would normally use unique_name = object() but that results in a @@ -2347,7 +2601,13 @@ def to_dataframe(self, name: Hashable = None) -> pd.DataFrame: # been able to debug (possibly a pandas bug?). unique_name = "__unique_name_identifier_z98xfz98xugfg73ho__" ds = self._to_dataset_whole(name=unique_name) - df = ds._to_dataframe(dims) + + if dim_order is None: + ordered_dims = dict(zip(self.dims, self.shape)) + else: + ordered_dims = ds._normalize_dim_order(dim_order=dim_order) + + df = ds._to_dataframe(ordered_dims) df.columns = [name if c == unique_name else c for c in df.columns] return df @@ -2365,8 +2625,8 @@ def to_masked_array(self, copy: bool = True) -> np.ma.MaskedArray: Parameters ---------- - copy : bool - If True (default) make a copy of the array in the result. If False, + copy : bool, default: True + If True make a copy of the array in the result. If False, a MaskedArray view of DataArray.values is returned. Returns @@ -2381,15 +2641,19 @@ def to_masked_array(self, copy: bool = True) -> np.ma.MaskedArray: def to_netcdf(self, *args, **kwargs) -> Union[bytes, "Delayed", None]: """Write DataArray contents to a netCDF file. - All parameters are passed directly to `xarray.Dataset.to_netcdf`. + All parameters are passed directly to :py:meth:`xarray.Dataset.to_netcdf`. Notes ----- Only xarray.Dataset objects can be written to netCDF files, so the xarray.DataArray is converted to a xarray.Dataset object containing a single variable. If the DataArray has no name, or if the - name is the same as a co-ordinate name, then it is given the name - '__xarray_dataarray_variable__'. + name is the same as a coordinate name, then it is given the name + ``"__xarray_dataarray_variable__"``. + + See Also + -------- + Dataset.to_netcdf """ from ..backends.api import DATAARRAY_NAME, DATAARRAY_VARIABLE @@ -2414,7 +2678,7 @@ def to_dict(self, data: bool = True) -> dict: Converts all variables and attributes to native Python objects. Useful for converting to json. To avoid datetime incompatibility - use decode_times=False kwarg in xarrray.open_dataset. + use decode_times=False kwarg in xarray.open_dataset. Parameters ---------- @@ -2437,23 +2701,27 @@ def from_dict(cls, d: dict) -> "DataArray": """ Convert a dictionary into an xarray.DataArray - Input dict can take several forms:: + Input dict can take several forms: - d = {'dims': ('t'), 'data': x} + .. code:: python - d = {'coords': {'t': {'dims': 't', 'data': t, - 'attrs': {'units':'s'}}}, - 'attrs': {'title': 'air temperature'}, - 'dims': 't', - 'data': x, - 'name': 'a'} + d = {"dims": ("t"), "data": x} - where 't' is the name of the dimesion, 'a' is the name of the array, - and x and t are lists, numpy.arrays, or pandas objects. 
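A short sketch of the new ``dim_order`` argument described above; the variable and dimension names are chosen for illustration.

import numpy as np
import xarray as xr

da = xr.DataArray(
    np.arange(6).reshape(2, 3),
    dims=("x", "y"),
    coords={"x": ["a", "b"], "y": [10, 20, 30]},
    name="v",
)

df = da.to_dataframe(dim_order=["y", "x"])
# "y" becomes the outer index level and "x" the inner (contiguous) one, so the
# flattened column interleaves the two rows of the original array.
print(list(df.index.names))   # expected: ['y', 'x']
print(df["v"].to_list())      # expected: [0, 3, 1, 4, 2, 5]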
+ d = { + "coords": {"t": {"dims": "t", "data": t, "attrs": {"units": "s"}}}, + "attrs": {"title": "air temperature"}, + "dims": "t", + "data": x, + "name": "a", + } + + where "t" is the name of the dimesion, "a" is the name of the array, + and x and t are lists, numpy.arrays, or pandas objects. Parameters ---------- - d : dict, with a minimum structure of {'dims': [..], 'data': [..]} + d : dict + Mapping with a minimum structure of {"dims": [...], "data": [...]} Returns ------- @@ -2508,38 +2776,33 @@ def from_series(cls, series: pd.Series, sparse: bool = False) -> "DataArray": return result def to_cdms2(self) -> "cdms2_Variable": - """Convert this array into a cdms2.Variable - """ + """Convert this array into a cdms2.Variable""" from ..convert import to_cdms2 return to_cdms2(self) @classmethod def from_cdms2(cls, variable: "cdms2_Variable") -> "DataArray": - """Convert a cdms2.Variable into an xarray.DataArray - """ + """Convert a cdms2.Variable into an xarray.DataArray""" from ..convert import from_cdms2 return from_cdms2(variable) def to_iris(self) -> "iris_Cube": - """Convert this array into a iris.cube.Cube - """ + """Convert this array into a iris.cube.Cube""" from ..convert import to_iris return to_iris(self) @classmethod def from_iris(cls, cube: "iris_Cube") -> "DataArray": - """Convert a iris.cube.Cube into an xarray.DataArray - """ + """Convert a iris.cube.Cube into an xarray.DataArray""" from ..convert import from_iris return from_iris(cube) def _all_compat(self, other: "DataArray", compat_str: str) -> bool: - """Helper function for equals, broadcast_equals, and identical - """ + """Helper function for equals, broadcast_equals, and identical""" def compat(x, y): return getattr(x.variable, compat_str)(y.variable) @@ -2590,7 +2853,7 @@ def identical(self, other: "DataArray") -> bool: See Also -------- DataArray.broadcast_equals - DataArray.equal + DataArray.equals """ try: return self.name == other.name and self._all_compat(other, "identical") @@ -2622,8 +2885,19 @@ def __rmatmul__(self, other): def _unary_op(f: Callable[..., Any]) -> Callable[..., "DataArray"]: @functools.wraps(f) def func(self, *args, **kwargs): - with np.errstate(all="ignore"): - return self.__array_wrap__(f(self.variable.data, *args, **kwargs)) + keep_attrs = kwargs.pop("keep_attrs", None) + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered") + warnings.filterwarnings( + "ignore", r"Mean of empty slice", category=RuntimeWarning + ) + with np.errstate(all="ignore"): + da = self.__array_wrap__(f(self.variable.data, *args, **kwargs)) + if keep_attrs: + da.attrs = self.attrs + return da return func @@ -2671,8 +2945,15 @@ def func(self, other): # don't support automatic alignment with in-place arithmetic. other_coords = getattr(other, "coords", None) other_variable = getattr(other, "variable", other) - with self.coords._merge_inplace(other_coords): - f(self.variable, other_variable) + try: + with self.coords._merge_inplace(other_coords): + f(self.variable, other_variable) + except MergeError as exc: + raise MergeError( + "Automatic alignment is not supported for in-place operations.\n" + "Consider aligning the indices manually or using a not-in-place operation.\n" + "See https://github.com/pydata/xarray/issues/3910 for more explanations." 
+ ) from exc return self return func @@ -2680,24 +2961,7 @@ def func(self, other): def _copy_attrs_from(self, other: Union["DataArray", Dataset, Variable]) -> None: self.attrs = other.attrs - @property - def plot(self) -> _PlotMethods: - """ - Access plotting functions for DataArray's - - >>> d = xr.DataArray([[1, 2], [3, 4]]) - - For convenience just call this directly - - >>> d.plot() - - Or use it as a namespace to use xarray.plot functions as - DataArray methods - - >>> d.plot.imshow() # equivalent to xarray.plot.imshow(d) - - """ - return _PlotMethods(self) + plot = utils.UncachedAccessor(_PlotMethods) def _title_for_slice(self, truncate: int = 50) -> str: """ @@ -2706,7 +2970,7 @@ def _title_for_slice(self, truncate: int = 50) -> str: Parameters ---------- - truncate : integer + truncate : int, default: 50 maximum number of characters for title Returns @@ -2748,10 +3012,10 @@ def diff(self, dim: Hashable, n: int = 1, label: Hashable = "upper") -> "DataArr difference : same type as caller The n-th order finite difference of this object. - .. note:: - - `n` matches numpy's behavior and is different from pandas' first - argument named `periods`. + Notes + ----- + `n` matches numpy's behavior and is different from pandas' first argument named + `periods`. Examples @@ -2761,12 +3025,12 @@ def diff(self, dim: Hashable, n: int = 1, label: Hashable = "upper") -> "DataArr array([0, 1, 0]) Coordinates: - * x (x) int64 2 3 4 + * x (x) int64 2 3 4 >>> arr.diff("x", 2) array([ 1, -1]) Coordinates: - * x (x) int64 3 4 + * x (x) int64 3 4 See Also -------- @@ -2789,13 +3053,13 @@ def shift( Parameters ---------- - shifts : Mapping with the form of {dim: offset} + shifts : mapping of hashable to int, optional Integer offset to shift along each of the given dimensions. Positive offsets shift to the right; negative offsets shift to the left. fill_value: scalar, optional Value to use for newly missing values - **shifts_kwargs: + **shifts_kwargs The keyword arguments form of ``shifts``. One of shifts or shifts_kwargs must be provided. @@ -2815,9 +3079,8 @@ def shift( >>> arr = xr.DataArray([5, 6, 7], dims="x") >>> arr.shift(x=1) - array([ nan, 5., 6.]) - Coordinates: - * x (x) int64 0 1 2 + array([nan, 5., 6.]) + Dimensions without coordinates: x """ variable = self.variable.shift( shifts=shifts, fill_value=fill_value, **shifts_kwargs @@ -2838,16 +3101,17 @@ def roll( Parameters ---------- - shifts : Mapping with the form of {dim: offset} + shifts : mapping of hashable to int, optional Integer offset to rotate each of the given dimensions. Positive offsets roll to the right; negative offsets roll to the left. roll_coords : bool - Indicates whether to roll the coordinates by the offset + Indicates whether to roll the coordinates by the offset The current default of roll_coords (None, equivalent to True) is deprecated and will change to False in a future version. Explicitly pass roll_coords to silence the warning. - **shifts_kwargs : The keyword arguments form of ``shifts``. + **shifts_kwargs + The keyword arguments form of ``shifts``. One of shifts or shifts_kwargs must be provided. Returns @@ -2866,8 +3130,7 @@ def roll( >>> arr.roll(x=1) array([7, 5, 6]) - Coordinates: - * x (x) int64 2 0 1 + Dimensions without coordinates: x """ ds = self._to_temp_dataset().roll( shifts=shifts, roll_coords=roll_coords, **shifts_kwargs @@ -2893,8 +3156,8 @@ def dot( ---------- other : DataArray The other array with which the dot product is performed. 
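The reworked unary-operation wrapper above resolves ``keep_attrs`` through ``_get_keep_attrs(default=True)``, so attributes should now survive unary operations unless the global option disables them. A minimal sketch of the user-visible effect:

import xarray as xr

da = xr.DataArray([1.0, -2.0], dims="x", attrs={"units": "m"})

# Unary operations keep the original attributes by default ...
print((-da).attrs)      # expected: {'units': 'm'}
print(abs(da).attrs)    # expected: {'units': 'm'}

# ... and respect the global keep_attrs option.
with xr.set_options(keep_attrs=False):
    print((-da).attrs)  # expected: {}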
- dims: '...', hashable or sequence of hashables, optional - Which dimensions to sum over. Ellipsis ('...') sums over all dimensions. + dims : ..., hashable or sequence of hashable, optional + Which dimensions to sum over. Ellipsis (`...`) sums over all dimensions. If not specified, then all the common dimensions are summed over. Returns @@ -2916,7 +3179,7 @@ def dot( >>> dm = xr.DataArray(dm_vals, dims=["z"]) >>> dm.dims - ('z') + ('z',) >>> da.dims ('x', 'y', 'z') @@ -2958,15 +3221,15 @@ def sortby( Parameters ---------- - variables: hashable, DataArray, or sequence of either + variables : hashable, DataArray, or sequence of hashable or DataArray 1D DataArray objects or name(s) of 1D variable(s) in coords whose values are used to sort this array. - ascending: boolean, optional + ascending : bool, optional Whether to sort by ascending or descending order. Returns ------- - sorted: DataArray + sorted : DataArray A new dataarray where all the specified dims are sorted by dim labels. @@ -2980,15 +3243,15 @@ def sortby( ... ) >>> da - array([ 0.965471, 0.615637, 0.26532 , 0.270962, 0.552878]) + array([0.5488135 , 0.71518937, 0.60276338, 0.54488318, 0.4236548 ]) Coordinates: - * time (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03 ... + * time (time) datetime64[ns] 2000-01-01 2000-01-02 ... 2000-01-05 >>> da.sortby(da) - array([ 0.26532 , 0.270962, 0.552878, 0.615637, 0.965471]) + array([0.4236548 , 0.54488318, 0.5488135 , 0.60276338, 0.71518937]) Coordinates: - * time (time) datetime64[ns] 2000-01-03 2000-01-04 2000-01-05 ... + * time (time) datetime64[ns] 2000-01-05 2000-01-04 ... 2000-01-02 """ ds = self._to_temp_dataset().sortby(variables, ascending=ascending) return self._from_temp_dataset(ds) @@ -3007,11 +3270,11 @@ def quantile( Parameters ---------- - q : float in range of [0,1] or array-like of floats + q : float or array-like of float Quantile to compute, which must be between 0 and 1 inclusive. dim : hashable or sequence of hashable, optional Dimension(s) over which to apply quantile. - interpolation : {'linear', 'lower', 'higher', 'midpoint', 'nearest'} + interpolation : {"linear", "lower", "higher", "midpoint", "nearest"}, default: "linear" This optional parameter specifies the interpolation method to use when the desired quantile lies between two data points ``i < j``: @@ -3121,7 +3384,7 @@ def rank( >>> arr = xr.DataArray([5, 6, 7], dims="x") >>> arr.rank("x") - array([ 1., 2., 3.]) + array([1., 2., 3.]) Dimensions without coordinates: x """ @@ -3140,12 +3403,12 @@ def differentiate( Parameters ---------- - coord: hashable + coord : hashable The coordinate to be used to compute the gradient. - edge_order: 1 or 2. Default 1 + edge_order : {1, 2}, default: 1 N-th order accurate differences at the boundaries. - datetime_unit: None or any of {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', - 'us', 'ns', 'ps', 'fs', 'as'} + datetime_unit : {"Y", "M", "W", "D", "h", "m", "s", "ms", \ + "us", "ns", "ps", "fs", "as"} or None, optional Unit to compute gradient. Only valid for datetime coordinate. Returns @@ -3176,10 +3439,10 @@ def differentiate( >>> >>> da.differentiate("x") - array([[30. , 30. , 30. ], - [27.545455, 27.545455, 27.545455], - [27.545455, 27.545455, 27.545455], - [30. , 30. , 30. ]]) + array([[30. , 30. , 30. ], + [27.54545455, 27.54545455, 27.54545455], + [27.54545455, 27.54545455, 27.54545455], + [30. , 30. , 30. 
]]) Coordinates: * x (x) float64 0.0 0.1 1.1 1.2 Dimensions without coordinates: y @@ -3193,7 +3456,7 @@ def integrate( dim: Union[Hashable, Sequence[Hashable]] = None, datetime_unit: str = None, ) -> "DataArray": - """ Integrate along the given coordinate using the trapezoidal rule. + """Integrate along the given coordinate using the trapezoidal rule. .. note:: This feature is limited to simple cartesian geometry, i.e. coord @@ -3203,10 +3466,12 @@ def integrate( ---------- coord: hashable, or a sequence of hashable Coordinate(s) used for the integration. + dim : hashable, or sequence of hashable + Coordinate(s) used for the integration. datetime_unit: str, optional - Can be used to specify the unit if datetime coordinate is used. - One of {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', 'ps', - 'fs', 'as'} + Can be specify the unit if datetime coordinate is used. One of + {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', 'ps', 'fs', + 'as'} Returns ------- @@ -3254,7 +3519,7 @@ def integrate( return self._from_temp_dataset(ds) def unify_chunks(self) -> "DataArray": - """ Unify chunk size along all chunked dimensions of this DataArray. + """Unify chunk size along all chunked dimensions of this DataArray. Returns ------- @@ -3274,57 +3539,105 @@ def map_blocks( func: "Callable[..., T_DSorDA]", args: Sequence[Any] = (), kwargs: Mapping[str, Any] = None, + template: Union["DataArray", "Dataset"] = None, ) -> "T_DSorDA": """ - Apply a function to each chunk of this DataArray. This method is experimental - and its signature may change. + Apply a function to each block of this DataArray. + + .. warning:: + This method is experimental and its signature may change. Parameters ---------- - func: callable - User-provided function that accepts a DataArray as its first parameter. The - function will receive a subset of this DataArray, corresponding to one chunk - along each chunked dimension. ``func`` will be executed as - ``func(obj_subset, *args, **kwargs)``. - - The function will be first run on mocked-up data, that looks like this array - but has sizes 0, to determine properties of the returned object such as - dtype, variable names, new dimensions and new indexes (if any). + func : callable + User-provided function that accepts a DataArray as its first + parameter. The function will receive a subset or 'block' of this DataArray (see below), + corresponding to one chunk along each chunked dimension. ``func`` will be + executed as ``func(subset_dataarray, *subset_args, **kwargs)``. This function must return either a single DataArray or a single Dataset. - This function cannot change size of existing dimensions, or add new chunked - dimensions. - args: Sequence - Passed verbatim to func after unpacking, after the sliced DataArray. xarray - objects, if any, will not be split by chunks. Passing dask collections is - not allowed. - kwargs: Mapping + This function cannot add a new chunked dimension. + args : sequence + Passed to func after unpacking and subsetting any xarray objects by blocks. + xarray objects in args must be aligned with this object, otherwise an error is raised. + kwargs : mapping Passed verbatim to func after unpacking. xarray objects, if any, will not be - split by chunks. Passing dask collections is not allowed. + subset to blocks. Passing dask collections in kwargs is not allowed. + template : DataArray or Dataset, optional + xarray object representing the final result after compute is called. 
If not provided, + the function will be first run on mocked-up data, that looks like this object but + has sizes 0, to determine properties of the returned object such as dtype, + variable names, attributes, new dimensions and new indexes (if any). + ``template`` must be provided if the function changes the size of existing dimensions. + When provided, ``attrs`` on variables in `template` are copied over to the result. Any + ``attrs`` set by ``func`` will be ignored. Returns ------- - A single DataArray or Dataset with dask backend, reassembled from the outputs of - the function. + A single DataArray or Dataset with dask backend, reassembled from the outputs of the + function. Notes ----- - This method is designed for when one needs to manipulate a whole xarray object - within each chunk. In the more common case where one can work on numpy arrays, - it is recommended to use apply_ufunc. + This function is designed for when ``func`` needs to manipulate a whole xarray object + subset to each block. In the more common case where ``func`` can work on numpy arrays, it is + recommended to use ``apply_ufunc``. - If none of the variables in this DataArray is backed by dask, calling this - method is equivalent to calling ``func(self, *args, **kwargs)``. + If none of the variables in this object is backed by dask arrays, calling this function is + equivalent to calling ``func(obj, *args, **kwargs)``. See Also -------- - dask.array.map_blocks, xarray.apply_ufunc, xarray.map_blocks, - xarray.Dataset.map_blocks + dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks, + xarray.DataArray.map_blocks + + Examples + -------- + + Calculate an anomaly from climatology using ``.groupby()``. Using + ``xr.map_blocks()`` allows for parallel operations with knowledge of ``xarray``, + its indices, and its methods like ``.groupby()``. + + >>> def calculate_anomaly(da, groupby_type="time.month"): + ... gb = da.groupby(groupby_type) + ... clim = gb.mean(dim="time") + ... return gb - clim + ... + >>> time = xr.cftime_range("1990-01", "1992-01", freq="M") + >>> month = xr.DataArray(time.month, coords={"time": time}, dims=["time"]) + >>> np.random.seed(123) + >>> array = xr.DataArray( + ... np.random.rand(len(time)), + ... dims=["time"], + ... coords={"time": time, "month": month}, + ... ).chunk() + >>> array.map_blocks(calculate_anomaly, template=array).compute() + + array([ 0.12894847, 0.11323072, -0.0855964 , -0.09334032, 0.26848862, + 0.12382735, 0.22460641, 0.07650108, -0.07673453, -0.22865714, + -0.19063865, 0.0590131 , -0.12894847, -0.11323072, 0.0855964 , + 0.09334032, -0.26848862, -0.12382735, -0.22460641, -0.07650108, + 0.07673453, 0.22865714, 0.19063865, -0.0590131 ]) + Coordinates: + * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 + month (time) int64 1 2 3 4 5 6 7 8 9 10 11 12 1 2 3 4 5 6 7 8 9 10 11 12 + + Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments + to the function being applied in ``xr.map_blocks()``: + + >>> array.map_blocks( + ... calculate_anomaly, kwargs={"groupby_type": "time.year"}, template=array + ... ) # doctest: +ELLIPSIS + + dask.array + Coordinates: + * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 + month (time) int64 dask.array """ from .parallel import map_blocks - return map_blocks(func, self, args, kwargs) + return map_blocks(func, self, args, kwargs, template) def polyfit( self, @@ -3354,13 +3667,13 @@ def polyfit( invalid values, False otherwise. 
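# Illustrative sketch, not part of this patch: fitting along a dimension with
# ``polyfit`` and evaluating the coefficients with ``xr.polyval``. Variable and
# coordinate names are invented for illustration.
import numpy as np
import xarray as xr

x = np.arange(10.0)
da = xr.DataArray(3.0 * x + 1.0, dims="x", coords={"x": x})

fit = da.polyfit(dim="x", deg=1)  # Dataset with a "polyfit_coefficients" variable
line = xr.polyval(da["x"], fit["polyfit_coefficients"])  # evaluate the fit at da["x"]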
rcond : float, optional Relative condition number to the fit. - w : Union[Hashable, Any], optional + w : hashable or array-like, optional Weights to apply to the y-coordinate of the sample points. Can be an array-like object or the name of a coordinate in the dataset. full : bool, optional Whether to return the residuals, matrix rank and singular values in addition to the coefficients. - cov : Union[bool, str], optional + cov : bool or str, optional Whether to return to the covariance matrix in addition to the coefficients. The matrix is not scaled if `cov='unscaled'`. @@ -3372,7 +3685,8 @@ def polyfit( polyfit_coefficients The coefficients of the best fit. polyfit_residuals - The residuals of the least-square computation (only included if `full=True`) + The residuals of the least-square computation (only included if `full=True`). + When the matrix rank is deficient, np.nan is returned. [dim]_matrix_rank The effective rank of the scaled Vandermonde coefficient matrix (only included if `full=True`) [dim]_singular_value @@ -3416,10 +3730,11 @@ def pad( Parameters ---------- - pad_width : Mapping with the form of {dim: (pad_before, pad_after)} - Number of values padded along each dimension. + pad_width : mapping of hashable to tuple of int + Mapping with the form of {dim: (pad_before, pad_after)} + describing the number of values padded along each dimension. {dim: pad} is a shortcut for pad_before = pad_after = pad - mode : str + mode : str, default: "constant" One of the following string values (taken from numpy docs) 'constant' (default) @@ -3452,7 +3767,7 @@ def pad( Pads with the wrap of the vector along the axis. The first values are used to pad the end and the end values are used to pad the beginning. - stat_length : int, tuple or mapping of the form {dim: tuple} + stat_length : int, tuple or mapping of hashable to tuple, default: None Used in 'maximum', 'mean', 'median', and 'minimum'. Number of values at edge of each axis used to calculate the statistic value. {dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)} unique @@ -3462,7 +3777,7 @@ def pad( (stat_length,) or int is a shortcut for before = after = statistic length for all axes. Default is ``None``, to use the entire axis. - constant_values : scalar, tuple or mapping of the form {dim: tuple} + constant_values : scalar, tuple or mapping of hashable to tuple, default: 0 Used in 'constant'. The values to set the padded values for each axis. ``{dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)}`` unique @@ -3472,7 +3787,7 @@ def pad( ``(constant,)`` or ``constant`` is a shortcut for ``before = after = constant`` for all dimensions. Default is 0. - end_values : scalar, tuple or mapping of the form {dim: tuple} + end_values : scalar, tuple or mapping of hashable to tuple, default: 0 Used in 'linear_ramp'. The values used for the ending value of the linear_ramp and that will form the edge of the padded array. ``{dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)}`` unique @@ -3482,12 +3797,12 @@ def pad( ``(constant,)`` or ``constant`` is a shortcut for ``before = after = constant`` for all axes. Default is 0. - reflect_type : {'even', 'odd'}, optional - Used in 'reflect', and 'symmetric'. The 'even' style is the + reflect_type : {"even", "odd"}, optional + Used in "reflect", and "symmetric". The "even" style is the default with an unaltered reflection around the edge value. 
For - the 'odd' style, the extended part of the array is created by + the "odd" style, the extended part of the array is created by subtracting the reflected values from two times the edge value. - **pad_width_kwargs: + **pad_width_kwargs The keyword arguments form of ``pad_width``. One of ``pad_width`` or ``pad_width_kwargs`` must be provided. @@ -3509,17 +3824,18 @@ def pad( Examples -------- - >>> arr = xr.DataArray([5, 6, 7], coords=[("x", [0,1,2])]) - >>> arr.pad(x=(1,2), constant_values=0) + >>> arr = xr.DataArray([5, 6, 7], coords=[("x", [0, 1, 2])]) + >>> arr.pad(x=(1, 2), constant_values=0) array([0, 5, 6, 7, 0, 0]) Coordinates: * x (x) float64 nan 0.0 1.0 2.0 nan nan - >>> da = xr.DataArray([[0,1,2,3], [10,11,12,13]], - dims=["x", "y"], - coords={"x": [0,1], "y": [10, 20 ,30, 40], "z": ("x", [100, 200])} - ) + >>> da = xr.DataArray( + ... [[0, 1, 2, 3], [10, 11, 12, 13]], + ... dims=["x", "y"], + ... coords={"x": [0, 1], "y": [10, 20, 30, 40], "z": ("x", [100, 200])}, + ... ) >>> da.pad(x=1) array([[nan, nan, nan, nan], @@ -3577,18 +3893,18 @@ def idxmin( dim : str, optional Dimension over which to apply `idxmin`. This is optional for 1D arrays, but required for arrays with 2 or more dimensions. - skipna : bool or None, default None + skipna : bool or None, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for ``float``, ``complex``, and ``object`` dtypes; other dtypes either do not have a sentinel missing value (``int``) or ``skipna=True`` has not been implemented (``datetime64`` or ``timedelta64``). - fill_value : Any, default NaN + fill_value : Any, default: NaN Value to be filled in case all of the values along a dimension are null. By default this is NaN. The fill value and result are automatically converted to a compatible dtype if possible. Ignored if ``skipna`` is False. - keep_attrs : bool, default False + keep_attrs : bool, default: False If True, the attributes (``attrs``) will be copied from the original object to the new one. If False (default), the new object will be returned without attributes. @@ -3606,8 +3922,9 @@ def idxmin( Examples -------- - >>> array = xr.DataArray([0, 2, 1, 0, -2], dims="x", - ... coords={"x": ['a', 'b', 'c', 'd', 'e']}) + >>> array = xr.DataArray( + ... [0, 2, 1, 0, -2], dims="x", coords={"x": ["a", "b", "c", "d", "e"]} + ... ) >>> array.min() array(-2) @@ -3618,13 +3935,15 @@ def idxmin( array('e', dtype='>> array = xr.DataArray([[2.0, 1.0, 2.0, 0.0, -2.0], - ... [-4.0, np.NaN, 2.0, np.NaN, -2.0], - ... [np.NaN, np.NaN, 1., np.NaN, np.NaN]], - ... dims=["y", "x"], - ... coords={"y": [-1, 0, 1], - ... "x": np.arange(5.)**2} - ... ) + >>> array = xr.DataArray( + ... [ + ... [2.0, 1.0, 2.0, 0.0, -2.0], + ... [-4.0, np.NaN, 2.0, np.NaN, -2.0], + ... [np.NaN, np.NaN, 1.0, np.NaN, np.NaN], + ... ], + ... dims=["y", "x"], + ... coords={"y": [-1, 0, 1], "x": np.arange(5.0) ** 2}, + ... ) >>> array.min(dim="x") array([-2., -4., 1.]) @@ -3668,21 +3987,21 @@ def idxmax( Parameters ---------- - dim : str, optional + dim : hashable, optional Dimension over which to apply `idxmax`. This is optional for 1D arrays, but required for arrays with 2 or more dimensions. - skipna : bool or None, default None + skipna : bool or None, default: None If True, skip missing values (as marked by NaN). 
By default, only skips missing values for ``float``, ``complex``, and ``object`` dtypes; other dtypes either do not have a sentinel missing value (``int``) or ``skipna=True`` has not been implemented (``datetime64`` or ``timedelta64``). - fill_value : Any, default NaN + fill_value : Any, default: NaN Value to be filled in case all of the values along a dimension are null. By default this is NaN. The fill value and result are automatically converted to a compatible dtype if possible. Ignored if ``skipna`` is False. - keep_attrs : bool, default False + keep_attrs : bool, default: False If True, the attributes (``attrs``) will be copied from the original object to the new one. If False (default), the new object will be returned without attributes. @@ -3700,8 +4019,9 @@ def idxmax( Examples -------- - >>> array = xr.DataArray([0, 2, 1, 0, -2], dims="x", - ... coords={"x": ['a', 'b', 'c', 'd', 'e']}) + >>> array = xr.DataArray( + ... [0, 2, 1, 0, -2], dims="x", coords={"x": ["a", "b", "c", "d", "e"]} + ... ) >>> array.max() array(2) @@ -3712,13 +4032,15 @@ def idxmax( array('b', dtype='>> array = xr.DataArray([[2.0, 1.0, 2.0, 0.0, -2.0], - ... [-4.0, np.NaN, 2.0, np.NaN, -2.0], - ... [np.NaN, np.NaN, 1., np.NaN, np.NaN]], - ... dims=["y", "x"], - ... coords={"y": [-1, 0, 1], - ... "x": np.arange(5.)**2} - ... ) + >>> array = xr.DataArray( + ... [ + ... [2.0, 1.0, 2.0, 0.0, -2.0], + ... [-4.0, np.NaN, 2.0, np.NaN, -2.0], + ... [np.NaN, np.NaN, 1.0, np.NaN, np.NaN], + ... ], + ... dims=["y", "x"], + ... coords={"y": [-1, 0, 1], "x": np.arange(5.0) ** 2}, + ... ) >>> array.max(dim="x") array([2., 2., 1.]) @@ -3744,9 +4066,215 @@ def idxmax( keep_attrs=keep_attrs, ) + def argmin( + self, + dim: Union[Hashable, Sequence[Hashable]] = None, + axis: int = None, + keep_attrs: bool = None, + skipna: bool = None, + ) -> Union["DataArray", Dict[Hashable, "DataArray"]]: + """Index or indices of the minimum of the DataArray over one or more dimensions. + + If a sequence is passed to 'dim', then result returned as dict of DataArrays, + which can be passed directly to isel(). If a single str is passed to 'dim' then + returns a DataArray with dtype int. + + If there are multiple minima, the indices of the first one found will be + returned. + + Parameters + ---------- + dim : hashable, sequence of hashable or ..., optional + The dimensions over which to find the minimum. By default, finds minimum over + all dimensions - for now returning an int for backward compatibility, but + this is deprecated, in future will return a dict with indices for all + dimensions; to return a dict with all dimensions now, pass '...'. + axis : int, optional + Axis over which to apply `argmin`. Only one of the 'dim' and 'axis' arguments + can be supplied. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + + Returns + ------- + result : DataArray or dict of DataArray + + See also + -------- + Variable.argmin, DataArray.idxmin + + Examples + -------- + >>> array = xr.DataArray([0, 2, -1, 3], dims="x") + >>> array.min() + + array(-1) + >>> array.argmin() + + array(2) + >>> array.argmin(...) 
+ {'x': + array(2)} + >>> array.isel(array.argmin(...)) + + array(-1) + + >>> array = xr.DataArray( + ... [[[3, 2, 1], [3, 1, 2], [2, 1, 3]], [[1, 3, 2], [2, -5, 1], [2, 3, 1]]], + ... dims=("x", "y", "z"), + ... ) + >>> array.min(dim="x") + + array([[ 1, 2, 1], + [ 2, -5, 1], + [ 2, 1, 1]]) + Dimensions without coordinates: y, z + >>> array.argmin(dim="x") + + array([[1, 0, 0], + [1, 1, 1], + [0, 0, 1]]) + Dimensions without coordinates: y, z + >>> array.argmin(dim=["x"]) + {'x': + array([[1, 0, 0], + [1, 1, 1], + [0, 0, 1]]) + Dimensions without coordinates: y, z} + >>> array.min(dim=("x", "z")) + + array([ 1, -5, 1]) + Dimensions without coordinates: y + >>> array.argmin(dim=["x", "z"]) + {'x': + array([0, 1, 0]) + Dimensions without coordinates: y, 'z': + array([2, 1, 1]) + Dimensions without coordinates: y} + >>> array.isel(array.argmin(dim=["x", "z"])) + + array([ 1, -5, 1]) + Dimensions without coordinates: y + """ + result = self.variable.argmin(dim, axis, keep_attrs, skipna) + if isinstance(result, dict): + return {k: self._replace_maybe_drop_dims(v) for k, v in result.items()} + else: + return self._replace_maybe_drop_dims(result) + + def argmax( + self, + dim: Union[Hashable, Sequence[Hashable]] = None, + axis: int = None, + keep_attrs: bool = None, + skipna: bool = None, + ) -> Union["DataArray", Dict[Hashable, "DataArray"]]: + """Index or indices of the maximum of the DataArray over one or more dimensions. + + If a sequence is passed to 'dim', then result returned as dict of DataArrays, + which can be passed directly to isel(). If a single str is passed to 'dim' then + returns a DataArray with dtype int. + + If there are multiple maxima, the indices of the first one found will be + returned. + + Parameters + ---------- + dim : hashable, sequence of hashable or ..., optional + The dimensions over which to find the maximum. By default, finds maximum over + all dimensions - for now returning an int for backward compatibility, but + this is deprecated, in future will return a dict with indices for all + dimensions; to return a dict with all dimensions now, pass '...'. + axis : int, optional + Axis over which to apply `argmax`. Only one of the 'dim' and 'axis' arguments + can be supplied. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + + Returns + ------- + result : DataArray or dict of DataArray + + See also + -------- + Variable.argmax, DataArray.idxmax + + Examples + -------- + >>> array = xr.DataArray([0, 2, -1, 3], dims="x") + >>> array.max() + + array(3) + >>> array.argmax() + + array(3) + >>> array.argmax(...) + {'x': + array(3)} + >>> array.isel(array.argmax(...)) + + array(3) + + >>> array = xr.DataArray( + ... [[[3, 2, 1], [3, 1, 2], [2, 1, 3]], [[1, 3, 2], [2, 5, 1], [2, 3, 1]]], + ... dims=("x", "y", "z"), + ... 
) + >>> array.max(dim="x") + + array([[3, 3, 2], + [3, 5, 2], + [2, 3, 3]]) + Dimensions without coordinates: y, z + >>> array.argmax(dim="x") + + array([[0, 1, 1], + [0, 1, 0], + [0, 1, 0]]) + Dimensions without coordinates: y, z + >>> array.argmax(dim=["x"]) + {'x': + array([[0, 1, 1], + [0, 1, 0], + [0, 1, 0]]) + Dimensions without coordinates: y, z} + >>> array.max(dim=("x", "z")) + + array([3, 5, 3]) + Dimensions without coordinates: y + >>> array.argmax(dim=["x", "z"]) + {'x': + array([0, 1, 0]) + Dimensions without coordinates: y, 'z': + array([0, 1, 2]) + Dimensions without coordinates: y} + >>> array.isel(array.argmax(dim=["x", "z"])) + + array([3, 5, 3]) + Dimensions without coordinates: y + """ + result = self.variable.argmax(dim, axis, keep_attrs, skipna) + if isinstance(result, dict): + return {k: self._replace_maybe_drop_dims(v) for k, v in result.items()} + else: + return self._replace_maybe_drop_dims(result) + # this needs to be at the end, or mypy will confuse with `str` # https://mypy.readthedocs.io/en/latest/common_issues.html#dealing-with-conflicting-names - str = property(StringAccessor) + str = utils.UncachedAccessor(StringAccessor) # priority most be higher than Variable to properly work with binary ufuncs diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index c30cbaa6d63..175034d69de 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -27,6 +27,7 @@ TypeVar, Union, cast, + overload, ) import numpy as np @@ -57,7 +58,6 @@ ) from .coordinates import ( DatasetCoordinates, - LevelCoordinatesSource, assert_coordinate_consistent, remap_label_indexers, ) @@ -79,12 +79,12 @@ ) from .missing import get_clean_interp_index from .options import OPTIONS, _get_keep_attrs -from .pycompat import dask_array_type +from .pycompat import is_duck_dask_array from .utils import ( Default, Frozen, + HybridMappingProxy, SortedKeysDict, - _check_inplace, _default, decode_numpy_dict_values, drop_dims_from_indexers, @@ -196,7 +196,7 @@ def calculate_dimensions(variables: Mapping[Hashable, Variable]) -> Dict[Hashabl for dim, size in zip(var.dims, var.shape): if dim in scalar_vars: raise ValueError( - "dimension %r already exists as a scalar " "variable" % dim + "dimension %r already exists as a scalar variable" % dim ) if dim not in dims: dims[dim] = size @@ -285,7 +285,7 @@ def merge_indexes( new_variables = {k: v for k, v in variables.items() if k not in vars_to_remove} new_variables.update(vars_to_replace) - # update dimensions if necessary GH: 3512 + # update dimensions if necessary, GH: 3512 for k, v in new_variables.items(): if any(d in dims_to_replace for d in v.dims): new_dims = [dims_to_replace.get(d, d) for d in v.dims] @@ -329,7 +329,7 @@ def split_indexes( else: vars_to_remove.append(d) if not drop: - vars_to_create[str(d) + "_"] = Variable(d, index) + vars_to_create[str(d) + "_"] = Variable(d, index, variables[d].attrs) for d, levs in dim_levels.items(): index = variables[d].to_index() @@ -341,7 +341,7 @@ def split_indexes( if not drop: for lev in levs: idx = index.get_level_values(lev) - vars_to_create[idx.name] = Variable(d, idx) + vars_to_create[idx.name] = Variable(d, idx, variables[d].attrs) new_variables = dict(variables) for v in set(vars_to_remove): @@ -358,6 +358,90 @@ def _assert_empty(args: tuple, msg: str = "%s") -> None: raise ValueError(msg % args) +def _check_chunks_compatibility(var, chunks, preferred_chunks): + for dim in var.dims: + if dim not in chunks or (dim not in preferred_chunks): + continue + + preferred_chunks_dim = 
preferred_chunks.get(dim) + chunks_dim = chunks.get(dim) + + if isinstance(chunks_dim, int): + chunks_dim = (chunks_dim,) + else: + chunks_dim = chunks_dim[:-1] + + if any(s % preferred_chunks_dim for s in chunks_dim): + warnings.warn( + f"Specified Dask chunks {chunks[dim]} would separate " + f"on disks chunk shape {preferred_chunks[dim]} for dimension {dim}. " + "This could degrade performance. " + "Consider rechunking after loading instead.", + stacklevel=2, + ) + + +def _get_chunk(var, chunks): + # chunks need to be explicity computed to take correctly into accout + # backend preferred chunking + import dask.array as da + + if isinstance(var, IndexVariable): + return {} + + if isinstance(chunks, int) or (chunks == "auto"): + chunks = dict.fromkeys(var.dims, chunks) + + preferred_chunks = var.encoding.get("preferred_chunks", {}) + preferred_chunks_list = [ + preferred_chunks.get(dim, shape) for dim, shape in zip(var.dims, var.shape) + ] + + chunks_list = [ + chunks.get(dim, None) or preferred_chunks.get(dim, None) for dim in var.dims + ] + + output_chunks_list = da.core.normalize_chunks( + chunks_list, + shape=var.shape, + dtype=var.dtype, + previous_chunks=preferred_chunks_list, + ) + + output_chunks = dict(zip(var.dims, output_chunks_list)) + _check_chunks_compatibility(var, output_chunks, preferred_chunks) + + return output_chunks + + +def _maybe_chunk( + name, + var, + chunks, + token=None, + lock=None, + name_prefix="xarray-", + overwrite_encoded_chunks=False, +): + from dask.base import tokenize + + if chunks is not None: + chunks = {dim: chunks[dim] for dim in var.dims if dim in chunks} + if var.ndim: + # when rechunking by different amounts, make sure dask names change + # by provinding chunks as an input to tokenize. + # subtle bugs result otherwise. see GH3350 + token2 = tokenize(name, token if token else var._data, chunks) + name2 = f"{name_prefix}{name}-{token2}" + var = var.chunk(chunks, name=name2, lock=lock) + + if overwrite_encoded_chunks and var.chunks is not None: + var.encoding["chunks"] = tuple(x[0] for x in var.chunks) + return var + else: + return var + + def as_dataset(obj: Any) -> "Dataset": """Cast the given object to a Dataset. @@ -427,15 +511,124 @@ def __getitem__(self, key: Mapping[Hashable, Any]) -> "Dataset": class Dataset(Mapping, ImplementsDatasetReduce, DataWithCoords): """A multi-dimensional, in memory, array database. - A dataset resembles an in-memory representation of a NetCDF file, and - consists of variables, coordinates and attributes which together form a - self describing dataset. - - Dataset implements the mapping interface with keys given by variable names - and values given by DataArray objects for each variable name. - - One dimensional variables with name equal to their dimension are index - coordinates used for label based indexing. + A dataset resembles an in-memory representation of a NetCDF file, + and consists of variables, coordinates and attributes which + together form a self describing dataset. + + Dataset implements the mapping interface with keys given by variable + names and values given by DataArray objects for each variable name. + + One dimensional variables with name equal to their dimension are + index coordinates used for label based indexing. + + To load data from a file or file-like object, use the `open_dataset` + function. 
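# Illustrative sketch, not part of this patch: the ``tokenize`` call in
# ``_maybe_chunk`` above ensures that chunking the same variable by different
# amounts produces dask arrays with different names (see GH3350), so their task
# graphs cannot be confused.
import numpy as np
import xarray as xr

ds = xr.Dataset({"a": ("x", np.arange(100))})
assert ds.chunk({"x": 5})["a"].data.name != ds.chunk({"x": 10})["a"].data.name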
+ + Parameters + ---------- + data_vars : dict-like, optional + A mapping from variable names to :py:class:`~xarray.DataArray` + objects, :py:class:`~xarray.Variable` objects or to tuples of + the form ``(dims, data[, attrs])`` which can be used as + arguments to create a new ``Variable``. Each dimension must + have the same length in all variables in which it appears. + + The following notations are accepted: + + - mapping {var name: DataArray} + - mapping {var name: Variable} + - mapping {var name: (dimension name, array-like)} + - mapping {var name: (tuple of dimension names, array-like)} + - mapping {dimension name: array-like} + (it will be automatically moved to coords, see below) + + Each dimension must have the same length in all variables in + which it appears. + coords : dict-like, optional + Another mapping in similar form as the `data_vars` argument, + except the each item is saved on the dataset as a "coordinate". + These variables have an associated meaning: they describe + constant/fixed/independent quantities, unlike the + varying/measured/dependent quantities that belong in + `variables`. Coordinates values may be given by 1-dimensional + arrays or scalars, in which case `dims` do not need to be + supplied: 1D arrays will be assumed to give index values along + the dimension with the same name. + + The following notations are accepted: + + - mapping {coord name: DataArray} + - mapping {coord name: Variable} + - mapping {coord name: (dimension name, array-like)} + - mapping {coord name: (tuple of dimension names, array-like)} + - mapping {dimension name: array-like} + (the dimension name is implicitly set to be the same as the + coord name) + + The last notation implies that the coord name is the same as + the dimension name. + + attrs : dict-like, optional + Global attributes to save on this dataset. + + Examples + -------- + Create data: + + >>> np.random.seed(0) + >>> temperature = 15 + 8 * np.random.randn(2, 2, 3) + >>> precipitation = 10 * np.random.rand(2, 2, 3) + >>> lon = [[-99.83, -99.32], [-99.79, -99.23]] + >>> lat = [[42.25, 42.21], [42.63, 42.59]] + >>> time = pd.date_range("2014-09-06", periods=3) + >>> reference_time = pd.Timestamp("2014-09-05") + + Initialize a dataset with multiple dimensions: + + >>> ds = xr.Dataset( + ... data_vars=dict( + ... temperature=(["x", "y", "time"], temperature), + ... precipitation=(["x", "y", "time"], precipitation), + ... ), + ... coords=dict( + ... lon=(["x", "y"], lon), + ... lat=(["x", "y"], lat), + ... time=time, + ... reference_time=reference_time, + ... ), + ... attrs=dict(description="Weather related data."), + ... ) + >>> ds + + Dimensions: (time: 3, x: 2, y: 2) + Coordinates: + lon (x, y) float64 -99.83 -99.32 -99.79 -99.23 + lat (x, y) float64 42.25 42.21 42.63 42.59 + * time (time) datetime64[ns] 2014-09-06 2014-09-07 2014-09-08 + reference_time datetime64[ns] 2014-09-05 + Dimensions without coordinates: x, y + Data variables: + temperature (x, y, time) float64 29.11 18.2 22.83 ... 18.28 16.15 26.63 + precipitation (x, y, time) float64 5.68 9.256 0.7104 ... 7.992 4.615 7.805 + Attributes: + description: Weather related data. 
+ + Find out where the coldest temperature was and what values the + other variables had: + + >>> ds.isel(ds.temperature.argmin(...)) + + Dimensions: () + Coordinates: + lon float64 -99.32 + lat float64 42.21 + time datetime64[ns] 2014-09-08 + reference_time datetime64[ns] 2014-09-05 + Data variables: + temperature float64 7.182 + precipitation float64 8.326 + Attributes: + description: Weather related data. """ _attrs: Optional[Dict[Hashable, Any]] @@ -472,56 +665,6 @@ def __init__( coords: Mapping[Hashable, Any] = None, attrs: Mapping[Hashable, Any] = None, ): - """To load data from a file or file-like object, use the `open_dataset` - function. - - Parameters - ---------- - data_vars : dict-like, optional - A mapping from variable names to :py:class:`~xarray.DataArray` - objects, :py:class:`~xarray.Variable` objects or to tuples of the - form ``(dims, data[, attrs])`` which can be used as arguments to - create a new ``Variable``. Each dimension must have the same length - in all variables in which it appears. - - The following notations are accepted: - - - mapping {var name: DataArray} - - mapping {var name: Variable} - - mapping {var name: (dimension name, array-like)} - - mapping {var name: (tuple of dimension names, array-like)} - - mapping {dimension name: array-like} - (it will be automatically moved to coords, see below) - - Each dimension must have the same length in all variables in which - it appears. - coords : dict-like, optional - Another mapping in similar form as the `data_vars` argument, - except the each item is saved on the dataset as a "coordinate". - These variables have an associated meaning: they describe - constant/fixed/independent quantities, unlike the - varying/measured/dependent quantities that belong in `variables`. - Coordinates values may be given by 1-dimensional arrays or scalars, - in which case `dims` do not need to be supplied: 1D arrays will be - assumed to give index values along the dimension with the same - name. - - The following notations are accepted: - - - mapping {coord name: DataArray} - - mapping {coord name: Variable} - - mapping {coord name: (dimension name, array-like)} - - mapping {coord name: (tuple of dimension names, array-like)} - - mapping {dimension name: array-like} - (the dimension name is implicitly set to be the same as the coord name) - - The last notation implies that the coord name is the same as the - dimension name. - - attrs : dict-like, optional - Global attributes to save on this dataset. - """ - # TODO(shoyer): expose indexes as a public argument in __init__ if data_vars is None: @@ -576,8 +719,7 @@ def variables(self) -> Mapping[Hashable, Variable]: @property def attrs(self) -> Dict[Hashable, Any]: - """Dictionary of global attributes on this dataset - """ + """Dictionary of global attributes on this dataset""" if self._attrs is None: self._attrs = {} return self._attrs @@ -588,8 +730,7 @@ def attrs(self, value: Mapping[Hashable, Any]) -> None: @property def encoding(self) -> Dict: - """Dictionary of global encoding attributes on this dataset - """ + """Dictionary of global encoding attributes on this dataset""" if self._encoding is None: self._encoding = {} return self._encoding @@ -638,17 +779,15 @@ def load(self, **kwargs) -> "Dataset": Parameters ---------- **kwargs : dict - Additional keyword arguments passed on to ``dask.array.compute``. + Additional keyword arguments passed on to ``dask.compute``. 
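# Illustrative sketch, not part of this patch: the eager-loading behaviour
# documented above. ``compute`` returns a new in-memory object, while ``load``
# fills the same object in place (and also returns it).
import numpy as np
import xarray as xr

ds = xr.Dataset({"a": ("x", np.arange(1000))}).chunk({"x": 100})
eager = ds.compute()  # new Dataset backed by numpy arrays; ``ds`` stays lazy
ds.load()             # the original Dataset is now backed by numpy arrays too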
See Also -------- - dask.array.compute + dask.compute """ # access .data to coerce everything to numpy or dask arrays lazy_data = { - k: v._data - for k, v in self.variables.items() - if isinstance(v._data, dask_array_type) + k: v._data for k, v in self.variables.items() if is_duck_dask_array(v._data) } if lazy_data: import dask.array as da @@ -780,10 +919,19 @@ def _dask_postcompute(results, info, *args): @staticmethod def _dask_postpersist(dsk, info, *args): variables = {} + # postpersist is called in both dask.optimize and dask.persist + # When persisting, we want to filter out unrelated keys for + # each Variable's task graph. + is_persist = len(dsk) == len(info) for is_dask, k, v in info: if is_dask: func, args2 = v - result = func(dsk, *args2) + if is_persist: + name = args2[1][0] + dsk2 = {k: v for k, v in dsk.items() if k[0] == name} + else: + dsk2 = dsk + result = func(dsk2, *args2) else: result = v variables[k] = result @@ -803,23 +951,20 @@ def compute(self, **kwargs) -> "Dataset": Parameters ---------- **kwargs : dict - Additional keyword arguments passed on to ``dask.array.compute``. + Additional keyword arguments passed on to ``dask.compute``. See Also -------- - dask.array.compute + dask.compute """ new = self.copy(deep=False) return new.load(**kwargs) def _persist_inplace(self, **kwargs) -> "Dataset": - """Persist all Dask arrays in memory - """ + """Persist all Dask arrays in memory""" # access .data to coerce everything to numpy or dask arrays lazy_data = { - k: v._data - for k, v in self.variables.items() - if isinstance(v._data, dask_array_type) + k: v._data for k, v in self.variables.items() if is_duck_dask_array(v._data) } if lazy_data: import dask @@ -833,7 +978,7 @@ def _persist_inplace(self, **kwargs) -> "Dataset": return self def persist(self, **kwargs) -> "Dataset": - """ Trigger computation, keeping data as dask arrays + """Trigger computation, keeping data as dask arrays This operation can be used to trigger computation on underlying dask arrays, similar to ``.compute()`` or ``.load()``. However this @@ -1017,16 +1162,17 @@ def copy(self, deep: bool = False, data: Mapping = None) -> "Dataset": >>> da = xr.DataArray(np.random.randn(2, 3)) >>> ds = xr.Dataset( - ... {"foo": da, "bar": ("x", [-1, 2])}, coords={"x": ["one", "two"]}, + ... {"foo": da, "bar": ("x", [-1, 2])}, + ... coords={"x": ["one", "two"]}, ... ) >>> ds.copy() Dimensions: (dim_0: 2, dim_1: 3, x: 2) Coordinates: - * x (x) >> ds_0 = ds.copy(deep=False) @@ -1035,33 +1181,31 @@ def copy(self, deep: bool = False, data: Mapping = None) -> "Dataset": Dimensions: (dim_0: 2, dim_1: 3, x: 2) Coordinates: - * x (x) >> ds Dimensions: (dim_0: 2, dim_1: 3, x: 2) Coordinates: - * x (x) >> ds.copy( - ... data={"foo": np.arange(6).reshape(2, 3), "bar": ["a", "b"]} - ... 
) + >>> ds.copy(data={"foo": np.arange(6).reshape(2, 3), "bar": ["a", "b"]}) Dimensions: (dim_0: 2, dim_1: 3, x: 2) Coordinates: - * x (x) "Dataset": Dimensions: (dim_0: 2, dim_1: 3, x: 2) Coordinates: - * x (x) "Dataset": dims = {k: self.dims[k] for k in needed_dims} - for k in self._coord_names: + # preserves ordering of coordinates + for k in self._variables: + if k not in self._coord_names: + continue + if set(self.variables[k].dims) <= needed_dims: variables[k] = self._variables[k] coord_names.add(k) @@ -1159,8 +1307,7 @@ def _copy_listed(self, names: Iterable[Hashable]) -> "Dataset": return self._replace(variables, coord_names, dims, indexes=indexes) def _construct_dataarray(self, name: Hashable) -> "DataArray": - """Construct a DataArray by indexing this dataset - """ + """Construct a DataArray by indexing this dataset""" from .dataarray import DataArray try: @@ -1173,8 +1320,9 @@ def _construct_dataarray(self, name: Hashable) -> "DataArray": needed_dims = set(variable.dims) coords: Dict[Hashable, Variable] = {} - for k in self.coords: - if set(self.variables[k].dims) <= needed_dims: + # preserve ordering + for k in self._variables: + if k in self._coord_names and set(self.variables[k].dims) <= needed_dims: coords[k] = self.variables[k] if self._indexes is None: @@ -1193,21 +1341,22 @@ def __deepcopy__(self, memo=None) -> "Dataset": return self.copy(deep=True) @property - def _attr_sources(self) -> List[Mapping[Hashable, Any]]: - """List of places to look-up items for attribute-style access - """ - return self._item_sources + [self.attrs] + def _attr_sources(self) -> Iterable[Mapping[Hashable, Any]]: + """Places to look-up items for attribute-style access""" + yield from self._item_sources + yield self.attrs @property - def _item_sources(self) -> List[Mapping[Hashable, Any]]: - """List of places to look-up items for key-completion - """ - return [ - self.data_vars, - self.coords, - {d: self[d] for d in self.dims}, - LevelCoordinatesSource(self), - ] + def _item_sources(self) -> Iterable[Mapping[Hashable, Any]]: + """Places to look-up items for key-completion""" + yield self.data_vars + yield HybridMappingProxy(keys=self._coord_names, mapping=self.coords) + + # virtual coordinates + yield HybridMappingProxy(keys=self.dims, mapping=self) + + # uses empty dict -- everything here can already be found in self.coords. + yield HybridMappingProxy(keys=self._level_coords, mapping={}) def __contains__(self, key: object) -> bool: """The 'in' operator will return true or false depending on whether @@ -1243,13 +1392,25 @@ def loc(self) -> _LocIndexer: """ return _LocIndexer(self) - def __getitem__(self, key: Any) -> "Union[DataArray, Dataset]": + # FIXME https://github.com/python/mypy/issues/7328 + @overload + def __getitem__(self, key: Mapping) -> "Dataset": # type: ignore + ... + + @overload + def __getitem__(self, key: Hashable) -> "DataArray": # type: ignore + ... + + @overload + def __getitem__(self, key: Any) -> "Dataset": + ... + + def __getitem__(self, key): """Access variables or coordinates this dataset as a :py:class:`~xarray.DataArray`. Indexing with a list of names will return a new ``Dataset`` object. 
""" - # TODO(shoyer): type this properly: https://github.com/python/mypy/issues/7328 if utils.is_dict_like(key): return self.isel(**cast(Mapping, key)) @@ -1271,14 +1432,13 @@ def __setitem__(self, key: Hashable, value) -> None: """ if utils.is_dict_like(key): raise NotImplementedError( - "cannot yet use a dictionary as a key " "to set Dataset values" + "cannot yet use a dictionary as a key to set Dataset values" ) self.update({key: value}) def __delitem__(self, key: Hashable) -> None: - """Remove a variable from this dataset. - """ + """Remove a variable from this dataset.""" del self._variables[key] self._coord_names.discard(key) if key in self.indexes: @@ -1291,8 +1451,7 @@ def __delitem__(self, key: Hashable) -> None: __hash__ = None # type: ignore def _all_compat(self, other: "Dataset", compat_str: str) -> bool: - """Helper function for equals and identical - """ + """Helper function for equals and identical""" # some stores (e.g., scipy) do not seem to preserve order, so don't # require matching order for equality @@ -1359,8 +1518,7 @@ def identical(self, other: "Dataset") -> bool: @property def indexes(self) -> Indexes: - """Mapping of pandas.Index objects used for label based indexing - """ + """Mapping of pandas.Index objects used for label based indexing""" if self._indexes is None: self._indexes = default_indexes(self._variables, self._dims) return Indexes(self._indexes) @@ -1374,18 +1532,15 @@ def coords(self) -> DatasetCoordinates: @property def data_vars(self) -> DataVariables: - """Dictionary of DataArray objects corresponding to data variables - """ + """Dictionary of DataArray objects corresponding to data variables""" return DataVariables(self) - def set_coords( - self, names: "Union[Hashable, Iterable[Hashable]]", inplace: bool = None - ) -> "Dataset": + def set_coords(self, names: "Union[Hashable, Iterable[Hashable]]") -> "Dataset": """Given names of one or more variables, set them as coordinates Parameters ---------- - names : hashable or iterable of hashables + names : hashable or iterable of hashable Name(s) of variables in this dataset to convert into coordinates. Returns @@ -1400,7 +1555,6 @@ def set_coords( # DataFrame.set_index? # nb. check in self._variables, not self.data_vars to insure that the # operation is idempotent - _check_inplace(inplace) if isinstance(names, str) or not isinstance(names, Iterable): names = [names] else: @@ -1414,13 +1568,12 @@ def reset_coords( self, names: "Union[Hashable, Iterable[Hashable], None]" = None, drop: bool = False, - inplace: bool = None, ) -> "Dataset": """Given names of coordinates, reset them to become variables Parameters ---------- - names : hashable or iterable of hashables, optional + names : hashable or iterable of hashable, optional Name(s) of non-index coordinates in this dataset to reset into variables. By default, all non-index coordinates are reset. drop : bool, optional @@ -1431,7 +1584,6 @@ def reset_coords( ------- Dataset """ - _check_inplace(inplace) if names is None: names = self._coord_names - set(self.dims) else: @@ -1453,8 +1605,7 @@ def reset_coords( return obj def dump_to_store(self, store: "AbstractDataStore", **kwargs) -> None: - """Store dataset contents to a backends.*DataStore object. 
- """ + """Store dataset contents to a backends.*DataStore object.""" from ..backends.api import dump_to_store # TODO: rename and/or cleanup this method to make it more consistent @@ -1477,18 +1628,18 @@ def to_netcdf( Parameters ---------- - path : str, Path or file-like object, optional + path : str, Path or file-like, optional Path to which to save this dataset. File-like objects are only supported by the scipy engine. If no path is provided, this function returns the resulting netCDF file as bytes; in this case, we need to use scipy, which does not support netCDF version 4 (the default format becomes NETCDF3_64BIT). - mode : {'w', 'a'}, optional + mode : {"w", "a"}, default: "w" Write ('w') or append ('a') mode. If mode='w', any existing file at this location will be overwritten. If mode='a', existing variables will be overwritten. - format : {'NETCDF4', 'NETCDF4_CLASSIC', 'NETCDF3_64BIT', - 'NETCDF3_CLASSIC'}, optional + format : {"NETCDF4", "NETCDF4_CLASSIC", "NETCDF3_64BIT", \ + "NETCDF3_CLASSIC"}, optional File format for the resulting netCDF file: * NETCDF4: Data is stored in an HDF5 file, using netCDF4 API @@ -1511,19 +1662,19 @@ def to_netcdf( group : str, optional Path to the netCDF4 group in the given file to open (only works for format='NETCDF4'). The group(s) will be created if necessary. - engine : {'netcdf4', 'scipy', 'h5netcdf'}, optional + engine : {"netcdf4", "scipy", "h5netcdf"}, optional Engine to use when writing netCDF files. If not provided, the default engine is chosen based on available dependencies, with a preference for 'netcdf4' if writing to a file on disk. encoding : dict, optional Nested dictionary with variable names as keys and dictionaries of variable specific encodings as values, e.g., - ``{'my_variable': {'dtype': 'int16', 'scale_factor': 0.1, - 'zlib': True}, ...}`` + ``{"my_variable": {"dtype": "int16", "scale_factor": 0.1, + "zlib": True}, ...}`` The `h5netcdf` engine supports both the NetCDF4-style compression - encoding parameters ``{'zlib': True, 'complevel': 9}`` and the h5py - ones ``{'compression': 'gzip', 'compression_opts': 9}``. + encoding parameters ``{"zlib": True, "complevel": 9}`` and the h5py + ones ``{"compression": "gzip", "compression_opts": 9}``. This allows using any compression plugin installed in the HDF5 library, e.g. LZF. @@ -1531,14 +1682,14 @@ def to_netcdf( Dimension(s) that should be serialized as unlimited dimensions. By default, no dimensions are treated as unlimited dimensions. Note that unlimited_dims may also be set via - ``dataset.encoding['unlimited_dims']``. - compute: boolean + ``dataset.encoding["unlimited_dims"]``. + compute: bool, default: True If true compute immediately, otherwise return a ``dask.delayed.Delayed`` object that can be computed later. - invalid_netcdf: boolean - Only valid along with engine='h5netcdf'. If True, allow writing - hdf5 files which are valid netcdf as described in - https://github.com/shoyer/h5netcdf. Default: False. + invalid_netcdf: bool, default: False + Only valid along with ``engine="h5netcdf"``. If True, allow writing + hdf5 files which are invalid netcdf as described in + https://github.com/shoyer/h5netcdf. 
""" if encoding is None: encoding = {} @@ -1560,6 +1711,7 @@ def to_netcdf( def to_zarr( self, store: Union[MutableMapping, str, Path] = None, + chunk_store: Union[MutableMapping, str, Path] = None, mode: str = None, synchronizer=None, group: str = None, @@ -1567,6 +1719,7 @@ def to_zarr( compute: bool = True, consolidated: bool = False, append_dim: Hashable = None, + region: Mapping[str, slice] = None, ) -> "ZarrStore": """Write dataset contents to a zarr group. @@ -1578,56 +1731,75 @@ def to_zarr( ---------- store : MutableMapping, str or Path, optional Store or path to directory in file system. - mode : {'w', 'w-', 'a', None} - Persistence mode: 'w' means create (overwrite if exists); - 'w-' means create (fail if exists); - 'a' means append (create if does not exist). + chunk_store : MutableMapping, str or Path, optional + Store or path to directory in file system only for Zarr array chunks. + Requires zarr-python v2.4.0 or later. + mode : {"w", "w-", "a", None}, optional + Persistence mode: "w" means create (overwrite if exists); + "w-" means create (fail if exists); + "a" means override existing variables (create if does not exist). If ``append_dim`` is set, ``mode`` can be omitted as it is - internally set to ``'a'``. Otherwise, ``mode`` will default to + internally set to ``"a"``. Otherwise, ``mode`` will default to `w-` if not set. synchronizer : object, optional - Array synchronizer + Zarr array synchronizer. group : str, optional Group path. (a.k.a. `path` in zarr terminology.) encoding : dict, optional Nested dictionary with variable names as keys and dictionaries of variable specific encodings as values, e.g., - ``{'my_variable': {'dtype': 'int16', 'scale_factor': 0.1,}, ...}`` + ``{"my_variable": {"dtype": "int16", "scale_factor": 0.1,}, ...}`` compute: bool, optional - If True compute immediately, otherwise return a - ``dask.delayed.Delayed`` object that can be computed later. + If True write array data immediately, otherwise return a + ``dask.delayed.Delayed`` object that can be computed to write + array data later. Metadata is always updated eagerly. consolidated: bool, optional If True, apply zarr's `consolidate_metadata` function to the store - after writing. + after writing metadata. append_dim: hashable, optional - If set, the dimension on which the data will be appended. + If set, the dimension along which the data will be appended. All + other dimensions on overriden variables must remain the same size. + region: dict, optional + Optional mapping from dimension names to integer slices along + dataset dimensions to indicate the region of existing zarr array(s) + in which to write this dataset's data. For example, + ``{'x': slice(0, 1000), 'y': slice(10000, 11000)}`` would indicate + that values should be written to the region ``0:1000`` along ``x`` + and ``10000:11000`` along ``y``. + + Two restrictions apply to the use of ``region``: + + - If ``region`` is set, _all_ variables in a dataset must have at + least one dimension in common with the region. Other variables + should be written in a separate call to ``to_zarr()``. + - Dimensions cannot be included in both ``region`` and + ``append_dim`` at the same time. To create empty arrays to fill + in with ``region``, use a separate call to ``to_zarr()`` with + ``compute=False``. See "Appending to existing Zarr stores" in + the reference documentation for full details. 
References ---------- https://zarr.readthedocs.io/ + + Notes + ----- + Zarr chunking behavior: + If chunks are found in the encoding argument or attribute + corresponding to any DataArray, those chunks are used. + If a DataArray is a dask array, it is written with those chunks. + If not other chunks are found, Zarr uses its own heuristics to + choose automatic chunk sizes. """ + from ..backends.api import to_zarr + if encoding is None: encoding = {} - if (mode == "a") or (append_dim is not None): - if mode is None: - mode = "a" - elif mode != "a": - raise ValueError( - "append_dim was set along with mode='{}', either set " - "mode='a' or don't set it.".format(mode) - ) - elif mode is None: - mode = "w-" - if mode not in ["w", "w-", "a"]: - # TODO: figure out how to handle 'r+' - raise ValueError( - "The only supported options for mode are 'w'," "'w-' and 'a'." - ) - from ..backends.api import to_zarr return to_zarr( self, store=store, + chunk_store=chunk_store, mode=mode, synchronizer=synchronizer, group=group, @@ -1635,6 +1807,7 @@ def to_zarr( compute=compute, consolidated=consolidated, append_dim=append_dim, + region=region, ) def __repr__(self) -> str: @@ -1651,7 +1824,8 @@ def info(self, buf=None) -> None: Parameters ---------- - buf : writable buffer, defaults to sys.stdout + buf : file-like, default: sys.stdout + writable buffer See Also -------- @@ -1699,8 +1873,10 @@ def chunks(self) -> Mapping[Hashable, Tuple[int, ...]]: def chunk( self, chunks: Union[ - None, Number, Mapping[Hashable, Union[None, Number, Tuple[Number, ...]]] - ] = None, + Number, + str, + Mapping[Hashable, Union[None, Number, str, Tuple[Number, ...]]], + ] = {}, # {} even though it's technically unsafe, is being used intentionally here (#4667) name_prefix: str = "xarray-", token: str = None, lock: bool = False, @@ -1717,9 +1893,9 @@ def chunk( Parameters ---------- - chunks : int or mapping, optional + chunks : int, 'auto' or mapping, optional Chunk sizes along each dimension, e.g., ``5`` or - ``{'x': 5, 'y': 5}``. + ``{"x": 5, "y": 5}``. name_prefix : str, optional Prefix for the name of any new dask arrays. token : str, optional @@ -1732,45 +1908,33 @@ def chunk( ------- chunked : xarray.Dataset """ - from dask.base import tokenize + if chunks is None: + warnings.warn( + "None value for 'chunks' is deprecated. " + "It will raise an error in the future. Use instead '{}'", + category=FutureWarning, + ) + chunks = {} - if isinstance(chunks, Number): + if isinstance(chunks, (Number, str)): chunks = dict.fromkeys(self.dims, chunks) - if chunks is not None: - bad_dims = chunks.keys() - self.dims.keys() - if bad_dims: - raise ValueError( - "some chunks keys are not dimensions on this " - "object: %s" % bad_dims - ) - - def selkeys(dict_, keys): - if dict_ is None: - return None - return {d: dict_[d] for d in keys if d in dict_} - - def maybe_chunk(name, var, chunks): - chunks = selkeys(chunks, var.dims) - if not chunks: - chunks = None - if var.ndim > 0: - # when rechunking by different amounts, make sure dask names change - # by provinding chunks as an input to tokenize. - # subtle bugs result otherwise. 
see GH3350 - token2 = tokenize(name, token if token else var._data, chunks) - name2 = f"{name_prefix}{name}-{token2}" - return var.chunk(chunks, name=name2, lock=lock) - else: - return var + bad_dims = chunks.keys() - self.dims.keys() + if bad_dims: + raise ValueError( + "some chunks keys are not dimensions on this " "object: %s" % bad_dims + ) - variables = {k: maybe_chunk(k, v, chunks) for k, v in self.variables.items()} + variables = { + k: _maybe_chunk(k, v, chunks, token, lock, name_prefix) + for k, v in self.variables.items() + } return self._replace(variables) def _validate_indexers( - self, indexers: Mapping[Hashable, Any], missing_dims: str = "raise", + self, indexers: Mapping[Hashable, Any], missing_dims: str = "raise" ) -> Iterator[Tuple[Hashable, Union[int, slice, np.ndarray, Variable]]]: - """ Here we make sure + """Here we make sure + indexer has a valid keys + indexer is in a valid data type + string indexers are cast to the appropriate date type if the @@ -1812,8 +1976,7 @@ def _validate_indexers( def _validate_interp_indexers( self, indexers: Mapping[Hashable, Any] ) -> Iterator[Tuple[Hashable, Variable]]: - """Variant of _validate_indexers to be used for interpolation - """ + """Variant of _validate_indexers to be used for interpolation""" for k, v in self._validate_indexers(indexers): if isinstance(v, Variable): if v.ndim == 1: @@ -1896,10 +2059,10 @@ def isel( drop : bool, optional If ``drop=True``, drop coordinates variables indexed by integers instead of making them scalar. - missing_dims : {"raise", "warn", "ignore"}, default "raise" + missing_dims : {"raise", "warn", "ignore"}, default: "raise" What to do if dimensions that should be selected from are not present in the Dataset: - - "exception": raise an exception + - "raise": raise an exception - "warning": raise a warning, and ignore the missing dimensions - "ignore": ignore the missing dimensions **indexers_kwargs : {dim: indexer, ...}, optional @@ -2038,7 +2201,7 @@ def sel( If DataArrays are passed as indexers, xarray-style indexing will be carried out. See :ref:`indexing` for the details. One of indexers or indexers_kwargs must be provided. - method : {None, 'nearest', 'pad'/'ffill', 'backfill'/'bfill'}, optional + method : {None, "nearest", "pad", "ffill", "backfill", "bfill"}, optional Method to use for inexact matches: * None (default): only exact matches @@ -2191,7 +2354,7 @@ def thin( A dict with keys matching dimensions and integer values `n` or a single integer `n` applied over all dimensions. One of indexers or indexers_kwargs must be provided. - ``**indexers_kwargs`` : {dim: n, ...}, optional + **indexers_kwargs : {dim: n, ...}, optional The keyword arguments form of ``indexers``. One of indexers or indexers_kwargs must be provided. @@ -2271,7 +2434,7 @@ def reindex_like( other object need not be the same as the indexes on this dataset. Any mis-matched index values will be filled in with NaN, and any mis-matched dimension names will simply be ignored. - method : {None, 'nearest', 'pad'/'ffill', 'backfill'/'bfill'}, optional + method : {None, "nearest", "pad", "ffill", "backfill", "bfill"}, optional Method to use for filling index values from other not found in this dataset: @@ -2288,8 +2451,9 @@ def reindex_like( ``copy=False`` and reindexing is unnecessary, or can be performed with only slice operations, then the output may share memory with the input. In either case, a new xarray object is always returned. 
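# Illustrative sketch, not part of this patch: ``reindex_like`` with the
# inexact-match options documented above. The data are invented for illustration.
import xarray as xr

coarse = xr.Dataset({"a": ("x", [0.0, 10.0, 20.0])}, coords={"x": [0, 5, 10]})
fine = xr.Dataset(coords={"x": [0, 1, 4, 5, 9, 10]})
coarse.reindex_like(fine, method="nearest", tolerance=2)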
- fill_value : scalar, optional - Value to use for newly missing values + fill_value : scalar or dict-like, optional + Value to use for newly missing values. If a dict-like maps + variable names to fill values. Returns ------- @@ -2325,13 +2489,13 @@ def reindex( Parameters ---------- - indexers : dict. optional + indexers : dict, optional Dictionary with keys given by dimension names and values given by arrays of coordinates tick labels. Any mis-matched coordinate values will be filled in with NaN, and any mis-matched dimension names will simply be ignored. One of indexers or indexers_kwargs must be provided. - method : {None, 'nearest', 'pad'/'ffill', 'backfill'/'bfill'}, optional + method : {None, "nearest", "pad", "ffill", "backfill", "bfill"}, optional Method to use for filling index values in ``indexers`` not found in this dataset: @@ -2348,9 +2512,11 @@ def reindex( ``copy=False`` and reindexing is unnecessary, or can be performed with only slice operations, then the output may share memory with the input. In either case, a new xarray object is always returned. - fill_value : scalar, optional - Value to use for newly missing values - sparse: use sparse-array. By default, False + fill_value : scalar or dict-like, optional + Value to use for newly missing values. If a dict-like, + maps variable names (including coordinates) to fill values. + sparse : bool, default: False + use sparse-array. **indexers_kwargs : {dim: indexer, ...}, optional Keyword arguments in the same form as ``indexers``. One of indexers or indexers_kwargs must be provided. @@ -2384,10 +2550,10 @@ def reindex( Dimensions: (station: 4) Coordinates: - * station (station) >> x.indexes station: Index(['boston', 'nyc', 'seattle', 'denver'], dtype='object', name='station') @@ -2399,10 +2565,10 @@ def reindex( Dimensions: (station: 4) Coordinates: - * station (station) object 'boston' 'austin' 'seattle' 'lincoln' + * station (station) Dimensions: (station: 4) Coordinates: - * station (station) object 'boston' 'austin' 'seattle' 'lincoln' + * station (station) >> x.reindex( + ... {"station": new_index}, fill_value={"temperature": 0, "pressure": 100} + ... ) + + Dimensions: (station: 4) + Coordinates: + * station (station) Dimensions: (time: 6) Coordinates: - * time (time) datetime64[ns] 2019-01-01 2019-01-02 ... 2019-01-06 + * time (time) datetime64[ns] 2019-01-01 2019-01-02 ... 2019-01-06 Data variables: temperature (time) float64 15.57 12.77 nan 0.3081 16.59 15.12 - pressure (time) float64 103.4 122.7 452.0 444.0 399.2 486.0 + pressure (time) float64 481.8 191.7 395.9 264.4 284.0 462.8 Suppose we decide to expand the dataset to cover a wider date range. @@ -2453,10 +2632,10 @@ def reindex( Dimensions: (time: 10) Coordinates: - * time (time) datetime64[ns] 2018-12-29 2018-12-30 ... 2019-01-07 + * time (time) datetime64[ns] 2018-12-29 2018-12-30 ... 2019-01-07 Data variables: temperature (time) float64 nan nan nan 15.57 ... 0.3081 16.59 15.12 nan - pressure (time) float64 nan nan nan 103.4 ... 444.0 399.2 486.0 nan + pressure (time) float64 nan nan nan 481.8 ... 264.4 284.0 462.8 nan The index entries that did not have a value in the original data frame (for example, `2018-12-29`) are by default filled with NaN. If desired, we can fill in the missing values using one of several options. @@ -2469,10 +2648,10 @@ def reindex( Dimensions: (time: 10) Coordinates: - * time (time) datetime64[ns] 2018-12-29 2018-12-30 ... 2019-01-07 + * time (time) datetime64[ns] 2018-12-29 2018-12-30 ... 
2019-01-07 Data variables: temperature (time) float64 15.57 15.57 15.57 15.57 ... 16.59 15.12 nan - pressure (time) float64 103.4 103.4 103.4 103.4 ... 399.2 486.0 nan + pressure (time) float64 481.8 481.8 481.8 481.8 ... 284.0 462.8 nan Please note that the `NaN` value present in the original dataset (at index value `2019-01-03`) will not be filled by any of the value propagation schemes. @@ -2481,18 +2660,18 @@ def reindex( Dimensions: (time: 1) Coordinates: - * time (time) datetime64[ns] 2019-01-03 + * time (time) datetime64[ns] 2019-01-03 Data variables: temperature (time) float64 nan - pressure (time) float64 452.0 + pressure (time) float64 395.9 >>> x3.where(x3.temperature.isnull(), drop=True) Dimensions: (time: 2) Coordinates: - * time (time) datetime64[ns] 2019-01-03 2019-01-07 + * time (time) datetime64[ns] 2019-01-03 2019-01-07 Data variables: temperature (time) float64 nan nan - pressure (time) float64 452.0 nan + pressure (time) float64 395.9 nan This is because filling while reindexing does not look at dataset values, but only compares the original and desired indexes. If you do want to fill in the `NaN` values present in the @@ -2551,25 +2730,25 @@ def interp( kwargs: Mapping[str, Any] = None, **coords_kwargs: Any, ) -> "Dataset": - """ Multidimensional interpolation of Dataset. + """Multidimensional interpolation of Dataset. Parameters ---------- coords : dict, optional Mapping from dimension names to the new coordinates. New coordinate can be a scalar, array-like or DataArray. - If DataArrays are passed as new coordates, their dimensions are - used for the broadcasting. - method: string, optional. - {'linear', 'nearest'} for multidimensional array, - {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} - for 1-dimensional array. 'linear' is used by default. - assume_sorted: boolean, optional + If DataArrays are passed as new coordinates, their dimensions are + used for the broadcasting. Missing values are skipped. + method : str, optional + {"linear", "nearest"} for multidimensional array, + {"linear", "nearest", "zero", "slinear", "quadratic", "cubic"} + for 1-dimensional array. "linear" is used by default. + assume_sorted : bool, optional If False, values of coordinates that are interpolated over can be in any order and they are sorted first. If True, interpolated coordinates are assumed to be an array of monotonically increasing values. - kwargs: dictionary, optional + kwargs: dict, optional Additional keyword arguments passed to scipy's interpolator. Valid options and their behavior depend on if 1-dimensional or multi-dimensional interpolation is used. @@ -2579,7 +2758,7 @@ def interp( Returns ------- - interpolated: xr.Dataset + interpolated : Dataset New dataset on the new coordinates. Notes @@ -2590,6 +2769,80 @@ def interp( -------- scipy.interpolate.interp1d scipy.interpolate.interpn + + Examples + -------- + >>> ds = xr.Dataset( + ... data_vars={ + ... "a": ("x", [5, 7, 4]), + ... "b": ( + ... ("x", "y"), + ... [[1, 4, 2, 9], [2, 7, 6, np.nan], [6, np.nan, 5, 8]], + ... ), + ... }, + ... coords={"x": [0, 1, 2], "y": [10, 12, 14, 16]}, + ... 
) + >>> ds + + Dimensions: (x: 3, y: 4) + Coordinates: + * x (x) int64 0 1 2 + * y (y) int64 10 12 14 16 + Data variables: + a (x) int64 5 7 4 + b (x, y) float64 1.0 4.0 2.0 9.0 2.0 7.0 6.0 nan 6.0 nan 5.0 8.0 + + 1D interpolation with the default method (linear): + + >>> ds.interp(x=[0, 0.75, 1.25, 1.75]) + + Dimensions: (x: 4, y: 4) + Coordinates: + * y (y) int64 10 12 14 16 + * x (x) float64 0.0 0.75 1.25 1.75 + Data variables: + a (x) float64 5.0 6.5 6.25 4.75 + b (x, y) float64 1.0 4.0 2.0 nan 1.75 6.25 ... nan 5.0 nan 5.25 nan + + 1D interpolation with a different method: + + >>> ds.interp(x=[0, 0.75, 1.25, 1.75], method="nearest") + + Dimensions: (x: 4, y: 4) + Coordinates: + * y (y) int64 10 12 14 16 + * x (x) float64 0.0 0.75 1.25 1.75 + Data variables: + a (x) float64 5.0 7.0 7.0 4.0 + b (x, y) float64 1.0 4.0 2.0 9.0 2.0 7.0 ... 6.0 nan 6.0 nan 5.0 8.0 + + 1D extrapolation: + + >>> ds.interp( + ... x=[1, 1.5, 2.5, 3.5], + ... method="linear", + ... kwargs={"fill_value": "extrapolate"}, + ... ) + + Dimensions: (x: 4, y: 4) + Coordinates: + * y (y) int64 10 12 14 16 + * x (x) float64 1.0 1.5 2.5 3.5 + Data variables: + a (x) float64 7.0 5.5 2.5 -0.5 + b (x, y) float64 2.0 7.0 6.0 nan 4.0 nan ... 4.5 nan 12.0 nan 3.5 nan + + 2D interpolation: + + >>> ds.interp(x=[0, 0.75, 1.25, 1.75], y=[11, 13, 15], method="linear") + + Dimensions: (x: 4, y: 3) + Coordinates: + * x (x) float64 0.0 0.75 1.25 1.75 + * y (y) int64 11 13 15 + Data variables: + a (x) float64 5.0 6.5 6.25 4.75 + b (x, y) float64 2.5 3.0 nan 4.0 5.625 nan nan nan nan nan nan nan """ from . import missing @@ -2687,22 +2940,22 @@ def interp_like( other : Dataset or DataArray Object with an 'indexes' attribute giving a mapping from dimension names to an 1d array-like, which provides coordinates upon - which to index the variables in this dataset. - method: string, optional. - {'linear', 'nearest'} for multidimensional array, - {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} + which to index the variables in this dataset. Missing values are skipped. + method : str, optional + {"linear", "nearest"} for multidimensional array, + {"linear", "nearest", "zero", "slinear", "quadratic", "cubic"} for 1-dimensional array. 'linear' is used by default. - assume_sorted: boolean, optional + assume_sorted : bool, optional If False, values of coordinates that are interpolated over can be in any order and they are sorted first. If True, interpolated coordinates are assumed to be an array of monotonically increasing values. - kwargs: dictionary, optional + kwargs: dict, optional Additional keyword passed to scipy's interpolator. Returns ------- - interpolated: xr.Dataset + interpolated : Dataset Another dataset by interpolating this dataset's data along the coordinates of the other object. @@ -2779,7 +3032,6 @@ def _rename_all(self, name_dict, dims_dict): def rename( self, name_dict: Mapping[Hashable, Hashable] = None, - inplace: bool = None, **names: Hashable, ) -> "Dataset": """Returns a new object with renamed variables and dimensions. @@ -2789,7 +3041,7 @@ def rename( name_dict : dict-like, optional Dictionary whose keys are current variable or dimension names and whose values are the desired names. - **names, optional + **names : optional Keyword form of ``name_dict``. One of name_dict or names must be provided. 
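The keyword form of ``rename`` documented above can be illustrated with a small, hypothetical dataset (a minimal sketch, not part of this patch):

```python
import xarray as xr

ds = xr.Dataset({"temp": ("x", [1, 2, 3])}, coords={"x": [10, 20, 30]})

# Equivalent to ds.rename(name_dict={"temp": "temperature", "x": "station"});
# renaming a dimension also renames its dimension coordinate.
renamed = ds.rename(temp="temperature", x="station")
```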
@@ -2805,7 +3057,6 @@ def rename( Dataset.rename_dims DataArray.rename """ - _check_inplace(inplace) name_dict = either_dict_or_kwargs(name_dict, names, "rename") for k in name_dict.keys(): if k not in self and k not in self.dims: @@ -2831,7 +3082,7 @@ def rename_dims( Dictionary whose keys are current dimension names and whose values are the desired names. The desired names must not be the name of an existing dimension or Variable in the Dataset. - **dims, optional + **dims : optional Keyword form of ``dims_dict``. One of dims_dict or dims must be provided. @@ -2875,7 +3126,7 @@ def rename_vars( name_dict : dict-like, optional Dictionary whose keys are current variable or coordinate names and whose values are the desired names. - **names, optional + **names : optional Keyword form of ``name_dict``. One of name_dict or names must be provided. @@ -2903,9 +3154,7 @@ def rename_vars( ) return self._replace(variables, coord_names, dims=dims, indexes=indexes) - def swap_dims( - self, dims_dict: Mapping[Hashable, Hashable], inplace: bool = None - ) -> "Dataset": + def swap_dims(self, dims_dict: Mapping[Hashable, Hashable]) -> "Dataset": """Returns a new object with swapped dimensions. Parameters @@ -2964,7 +3213,6 @@ def swap_dims( """ # TODO: deprecate this method in favor of a (less confusing) # rename_dims() method that only renames dimensions. - _check_inplace(inplace) for k, v in dims_dict.items(): if k not in self.dims: raise ValueError( @@ -3025,13 +3273,13 @@ def expand_dims( and the values are either integers (giving the length of the new dimensions) or array-like (giving the coordinates of the new dimensions). - axis : integer, sequence of integers, or None + axis : int, sequence of int, or None Axis position(s) where new axis is to be inserted (position(s) on the result array). If a list (or tuple) of integers is passed, multiple axes are inserted. In this case, dim arguments should be same length list. If axis=None is passed, all the axes will be inserted to the start of the result array. - **dim_kwargs : int or sequence/ndarray + **dim_kwargs : int or sequence or ndarray The keywords are arbitrary dimensions being inserted and the values are either the lengths of the new dims (if int is given), or their coordinates. Note, this is an alternative to passing a dict to the @@ -3139,7 +3387,6 @@ def set_index( self, indexes: Mapping[Hashable, Union[Hashable, Sequence[Hashable]]] = None, append: bool = False, - inplace: bool = None, **indexes_kwargs: Union[Hashable, Sequence[Hashable]], ) -> "Dataset": """Set Dataset (multi-)indexes using one or more existing coordinates @@ -3154,7 +3401,7 @@ def set_index( append : bool, optional If True, append the supplied index(es) to the existing index(es). Otherwise replace the existing index(es) (default). - **indexes_kwargs: optional + **indexes_kwargs : optional The keyword arguments form of ``indexes``. One of indexes or indexes_kwargs must be provided. @@ -3194,7 +3441,6 @@ def set_index( Dataset.reset_index Dataset.swap_dims """ - _check_inplace(inplace) indexes = either_dict_or_kwargs(indexes, indexes_kwargs, "set_index") variables, coord_names = merge_indexes( indexes, self._variables, self._coord_names, append=append @@ -3205,7 +3451,6 @@ def reset_index( self, dims_or_levels: Union[Hashable, Sequence[Hashable]], drop: bool = False, - inplace: bool = None, ) -> "Dataset": """Reset the specified index(es) or multi-index level(s). 
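For the ``set_index`` signature documented above (now without ``inplace``), a minimal sketch using a hypothetical dataset:

```python
import xarray as xr

ds = xr.Dataset(
    {"v": ("x", [1.0, 2.0, 3.0])},
    coords={"x": [0, 1, 2], "a": ("x", ["p", "q", "r"])},
)

# Replace the default integer index along "x" with the values of coordinate "a".
# reset_index (documented below) demotes an index back to a non-index coordinate.
indexed = ds.set_index(x="a")
```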
@@ -3227,7 +3472,6 @@ def reset_index( -------- Dataset.set_index """ - _check_inplace(inplace) variables, coord_names = split_indexes( dims_or_levels, self._variables, @@ -3240,7 +3484,6 @@ def reset_index( def reorder_levels( self, dim_order: Mapping[Hashable, Sequence[int]] = None, - inplace: bool = None, **dim_order_kwargs: Sequence[int], ) -> "Dataset": """Rearrange index levels using input order. @@ -3251,7 +3494,7 @@ def reorder_levels( Mapping from names matching dimensions and values given by lists representing new level orders. Every given dimension must have a multi-index. - **dim_order_kwargs: optional + **dim_order_kwargs : optional The keyword arguments form of ``dim_order``. One of dim_order or dim_order_kwargs must be provided. @@ -3261,7 +3504,6 @@ def reorder_levels( Another dataset, with this dataset's data but replaced coordinates. """ - _check_inplace(inplace) dim_order = either_dict_or_kwargs(dim_order, dim_order_kwargs, "reorder_levels") variables = self._variables.copy() indexes = dict(self.indexes) @@ -3319,12 +3561,13 @@ def stack( Parameters ---------- - dimensions : Mapping of the form new_name=(dim1, dim2, ...) - Names of new dimensions, and the existing dimensions that they - replace. An ellipsis (`...`) will be replaced by all unlisted dimensions. + dimensions : mapping of hashable to sequence of hashable + Mapping of the form `new_name=(dim1, dim2, ...)`. Names of new + dimensions, and the existing dimensions that they replace. An + ellipsis (`...`) will be replaced by all unlisted dimensions. Passing a list containing an ellipsis (`stacked_dim=[...]`) will stack over all dimensions. - **dimensions_kwargs: + **dimensions_kwargs The keyword arguments form of ``dimensions``. One of dimensions or dimensions_kwargs must be provided. @@ -3358,9 +3601,9 @@ def to_stacked_array( Parameters ---------- - new_dim : Hashable + new_dim : hashable Name of the new stacked coordinate - sample_dims : Sequence[Hashable] + sample_dims : sequence of hashable Dimensions that **will not** be stacked. Each array in the dataset must share these dimensions. For machine learning applications, these define the dimensions over which samples are drawn. @@ -3399,20 +3642,20 @@ def to_stacked_array( Dimensions: (x: 2, y: 3) Coordinates: - * y (y) >> data.to_stacked_array("z", sample_dims=["x"]) - + array([[0, 1, 2, 6], - [3, 4, 5, 7]]) + [3, 4, 5, 7]]) Coordinates: - * z (z) MultiIndex - - variable (z) object 'a' 'a' 'a' 'b' - - y (z) object 'u' 'v' 'w' nan + * z (z) MultiIndex + - variable (z) object 'a' 'a' 'a' 'b' + - y (z) object 'u' 'v' 'w' nan Dimensions without coordinates: x """ @@ -3514,11 +3757,15 @@ def unstack( Parameters ---------- - dim : Hashable or iterable of Hashable, optional + dim : hashable or iterable of hashable, optional Dimension(s) over which to unstack. By default unstacks all MultiIndexes. - fill_value: value to be filled. By default, np.nan - sparse: use sparse-array if True + fill_value : scalar or dict-like, default: nan + value to be filled. If a dict-like, maps variable names to + fill values. If not provided or if the dict-like does not + contain all variables, the dtype's NA value will be used. 
+ sparse : bool, default: False + use sparse-array if True Returns ------- @@ -3559,12 +3806,12 @@ def unstack( result = result._unstack_once(dim, fill_value, sparse) return result - def update(self, other: "CoercibleMapping", inplace: bool = None) -> "Dataset": + def update(self, other: "CoercibleMapping") -> "Dataset": """Update this dataset's variables with those from another dataset. Parameters ---------- - other : Dataset or castable to Dataset + other : Dataset or mapping Variables with which to update this dataset. One of: - Dataset @@ -3585,14 +3832,12 @@ def update(self, other: "CoercibleMapping", inplace: bool = None) -> "Dataset": If any dimensions would have inconsistent sizes in the updated dataset. """ - _check_inplace(inplace) merge_result = dataset_update_method(self, other) return self._replace(inplace=True, **merge_result._asdict()) def merge( self, other: Union["CoercibleMapping", "DataArray"], - inplace: bool = None, overwrite_vars: Union[Hashable, Iterable[Hashable]] = frozenset(), compat: str = "no_conflicts", join: str = "outer", @@ -3607,13 +3852,13 @@ def merge( Parameters ---------- - other : Dataset or castable to Dataset + other : Dataset or mapping Dataset or variables to merge with this dataset. - overwrite_vars : Hashable or iterable of Hashable, optional + overwrite_vars : hashable or iterable of hashable, optional If provided, update variables of these name(s) without checking for conflicts in this dataset. - compat : {'broadcast_equals', 'equals', 'identical', - 'no_conflicts'}, optional + compat : {"broadcast_equals", "equals", "identical", \ + "no_conflicts"}, optional String indicating how to compare variables of the same name for potential conflicts: @@ -3626,7 +3871,7 @@ def merge( must be equal. The returned dataset then contains the combination of all non-null values. - join : {'outer', 'inner', 'left', 'right', 'exact'}, optional + join : {"outer", "inner", "left", "right", "exact"}, optional Method for joining ``self`` and ``other`` along shared dimensions: - 'outer': use the union of the indexes @@ -3634,8 +3879,9 @@ def merge( - 'left': use indexes from ``self`` - 'right': use indexes from ``other`` - 'exact': error instead of aligning non-equal indexes - fill_value: scalar, optional - Value to use for newly missing values + fill_value : scalar or dict-like, optional + Value to use for newly missing values. If a dict-like, maps + variable names (including coordinates) to fill values. Returns ------- @@ -3647,7 +3893,6 @@ def merge( MergeError If any variables conflict (see ``compat``). """ - _check_inplace(inplace) other = other.to_dataset() if isinstance(other, xr.DataArray) else other merge_result = dataset_merge_method( self, @@ -3678,9 +3923,9 @@ def drop_vars( Parameters ---------- - names : hashable or iterable of hashables + names : hashable or iterable of hashable Name(s) of variables to drop. - errors: {'raise', 'ignore'}, optional + errors : {"raise", "ignore"}, optional If 'raise' (default), raises a ValueError error if any of the variable passed are not in the dataset. If 'ignore', any given names that are in the dataset are dropped and no error is raised. 
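The dict-like ``fill_value`` now documented for ``merge`` (and ``unstack``) above can be sketched as follows; the datasets are hypothetical:

```python
import xarray as xr

a = xr.Dataset({"t": ("x", [1.0, 2.0])}, coords={"x": [0, 1]})
b = xr.Dataset({"p": ("x", [10.0])}, coords={"x": [1]})

# The outer join leaves "p" undefined at x=0; a dict-like fill_value fills
# missing values per variable instead of using NaN everywhere.
merged = a.merge(b, join="outer", fill_value={"p": 0.0})
```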
@@ -3759,9 +4004,9 @@ def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs): Parameters ---------- - labels : Mapping[Hashable, Any] + labels : mapping of hashable to Any Index labels to drop - errors: {'raise', 'ignore'}, optional + errors : {"raise", "ignore"}, optional If 'raise' (default), raises a ValueError error if any of the index labels passed are not in the dataset. If 'ignore', any given labels that are in the @@ -3785,7 +4030,7 @@ def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs): * y (y) >> ds.drop_sel(y="b") Dimensions: (x: 2, y: 2) @@ -3793,7 +4038,7 @@ def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs): * y (y) "Dataset": Parameters ---------- - *dims : Hashable, optional + *dims : hashable, optional By default, reverse the dimensions on each array. Otherwise, reorder the dimensions to this order. @@ -3911,13 +4156,13 @@ def dropna( Parameters ---------- - dim : Hashable + dim : hashable Dimension along which to drop missing values. Dropping along multiple dimensions simultaneously is not yet supported. - how : {'any', 'all'}, optional + how : {"any", "all"}, default: "any" * any : if any NA values are present, drop that label * all : if all values are NA, drop that label - thresh : int, default None + thresh : int, default: None If supplied, require this many non-NA values. subset : iterable of hashable, optional Which variables to check for missing values. By default, all @@ -3999,7 +4244,7 @@ def fillna(self, value: Any) -> "Dataset": Dimensions: (x: 4) Coordinates: - * x (x) int64 0 1 2 3 + * x (x) int64 0 1 2 3 Data variables: A (x) float64 nan 2.0 nan 0.0 B (x) float64 3.0 4.0 nan 1.0 @@ -4012,7 +4257,7 @@ def fillna(self, value: Any) -> "Dataset": Dimensions: (x: 4) Coordinates: - * x (x) int64 0 1 2 3 + * x (x) int64 0 1 2 3 Data variables: A (x) float64 0.0 2.0 0.0 0.0 B (x) float64 3.0 4.0 0.0 1.0 @@ -4026,7 +4271,7 @@ def fillna(self, value: Any) -> "Dataset": Dimensions: (x: 4) Coordinates: - * x (x) int64 0 1 2 3 + * x (x) int64 0 1 2 3 Data variables: A (x) float64 0.0 2.0 0.0 0.0 B (x) float64 3.0 4.0 1.0 1.0 @@ -4073,18 +4318,18 @@ def interpolate_na( - 'barycentric', 'krog', 'pchip', 'spline', 'akima': use their respective :py:class:`scipy.interpolate` classes. - use_coordinate : bool, str, default True + use_coordinate : bool, str, default: True Specifies which index to use as the x values in the interpolation formulated as `y = f(x)`. If False, values are treated as if eqaully-spaced along ``dim``. If True, the IndexVariable `dim` is used. If ``use_coordinate`` is a string, it specifies the name of a coordinate variariable to use as the index. - limit : int, default None + limit : int, default: None Maximum number of consecutive NaNs to fill. Must be greater than 0 or None for no limit. This filling is done regardless of the size of the gap in the data. To only interpolate over gaps less than a given length, see ``max_gap``. - max_gap: int, float, str, pandas.Timedelta, numpy.timedelta64, datetime.timedelta, default None. + max_gap : int, float, str, pandas.Timedelta, numpy.timedelta64, datetime.timedelta, default: None Maximum size of gap, a continuous sequence of NaNs, that will be filled. Use None for no limit. When interpolating along a datetime64 dimension and ``use_coordinate=True``, ``max_gap`` can be one of the following: @@ -4119,8 +4364,52 @@ def interpolate_na( -------- numpy.interp scipy.interpolate + + Examples + -------- + >>> ds = xr.Dataset( + ... { + ... 
"A": ("x", [np.nan, 2, 3, np.nan, 0]), + ... "B": ("x", [3, 4, np.nan, 1, 7]), + ... "C": ("x", [np.nan, np.nan, np.nan, 5, 0]), + ... "D": ("x", [np.nan, 3, np.nan, -1, 4]), + ... }, + ... coords={"x": [0, 1, 2, 3, 4]}, + ... ) + >>> ds + + Dimensions: (x: 5) + Coordinates: + * x (x) int64 0 1 2 3 4 + Data variables: + A (x) float64 nan 2.0 3.0 nan 0.0 + B (x) float64 3.0 4.0 nan 1.0 7.0 + C (x) float64 nan nan nan 5.0 0.0 + D (x) float64 nan 3.0 nan -1.0 4.0 + + >>> ds.interpolate_na(dim="x", method="linear") + + Dimensions: (x: 5) + Coordinates: + * x (x) int64 0 1 2 3 4 + Data variables: + A (x) float64 nan 2.0 3.0 1.5 0.0 + B (x) float64 3.0 4.0 2.5 1.0 7.0 + C (x) float64 nan nan nan 5.0 0.0 + D (x) float64 nan 3.0 1.0 -1.0 4.0 + + >>> ds.interpolate_na(dim="x", method="linear", fill_value="extrapolate") + + Dimensions: (x: 5) + Coordinates: + * x (x) int64 0 1 2 3 4 + Data variables: + A (x) float64 1.0 2.0 3.0 1.5 0.0 + B (x) float64 3.0 4.0 2.5 1.0 7.0 + C (x) float64 20.0 15.0 10.0 5.0 0.0 + D (x) float64 5.0 3.0 1.0 -1.0 4.0 """ - from .missing import interp_na, _apply_over_vars_with_dim + from .missing import _apply_over_vars_with_dim, interp_na new = _apply_over_vars_with_dim( interp_na, @@ -4144,7 +4433,7 @@ def ffill(self, dim: Hashable, limit: int = None) -> "Dataset": dim : Hashable Specifies the dimension along which to propagate values when filling. - limit : int, default None + limit : int, default: None The maximum number of consecutive NaN values to forward fill. In other words, if there is a gap with more than this number of consecutive NaNs, it will only be partially filled. Must be greater @@ -4154,7 +4443,7 @@ def ffill(self, dim: Hashable, limit: int = None) -> "Dataset": ------- Dataset """ - from .missing import ffill, _apply_over_vars_with_dim + from .missing import _apply_over_vars_with_dim, ffill new = _apply_over_vars_with_dim(ffill, self, dim=dim, limit=limit) return new @@ -4169,7 +4458,7 @@ def bfill(self, dim: Hashable, limit: int = None) -> "Dataset": dim : str Specifies the dimension along which to propagate values when filling. - limit : int, default None + limit : int, default: None The maximum number of consecutive NaN values to backward fill. In other words, if there is a gap with more than this number of consecutive NaNs, it will only be partially filled. Must be greater @@ -4179,7 +4468,7 @@ def bfill(self, dim: Hashable, limit: int = None) -> "Dataset": ------- Dataset """ - from .missing import bfill, _apply_over_vars_with_dim + from .missing import _apply_over_vars_with_dim, bfill new = _apply_over_vars_with_dim(bfill, self, dim=dim, limit=limit) return new @@ -4210,7 +4499,6 @@ def reduce( keep_attrs: bool = None, keepdims: bool = False, numeric_only: bool = False, - allow_lazy: bool = None, **kwargs: Any, ) -> "Dataset": """Reduce this dataset by applying `func` along some dimension(s). @@ -4228,7 +4516,7 @@ def reduce( If True, the dataset's attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new object will be returned without attributes. - keepdims : bool, default False + keepdims : bool, default: False If True, the dimensions which are reduced are left in the result as dimensions of size one. Coordinates that use these dimensions are removed. 
@@ -4285,7 +4573,6 @@ def reduce( dim=reduce_dims, keep_attrs=keep_attrs, keepdims=keepdims, - allow_lazy=allow_lazy, **kwargs, ) @@ -4334,22 +4621,25 @@ def map( Dimensions: (dim_0: 2, dim_1: 3, x: 2) Dimensions without coordinates: dim_0, dim_1, x Data variables: - foo (dim_0, dim_1) float64 -0.3751 -1.951 -1.945 0.2948 0.711 -0.3948 + foo (dim_0, dim_1) float64 1.764 0.4002 0.9787 2.241 1.868 -0.9773 bar (x) int64 -1 2 >>> ds.map(np.fabs) Dimensions: (dim_0: 2, dim_1: 3, x: 2) Dimensions without coordinates: dim_0, dim_1, x Data variables: - foo (dim_0, dim_1) float64 0.3751 1.951 1.945 0.2948 0.711 0.3948 + foo (dim_0, dim_1) float64 1.764 0.4002 0.9787 2.241 1.868 0.9773 bar (x) float64 1.0 2.0 """ + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) variables = { k: maybe_wrap_array(v, func(v, *args, **kwargs)) for k, v in self.data_vars.items() } - if keep_attrs is None: - keep_attrs = _get_keep_attrs(default=False) + if keep_attrs: + for k, v in variables.items(): + v._copy_attrs_from(self.data_vars[k]) attrs = self.attrs if keep_attrs else None return type(self)(variables, attrs=attrs) @@ -4382,12 +4672,12 @@ def assign( Parameters ---------- - variables : mapping, value pairs + variables : mapping of hashable to Any Mapping from variables names to the new values. If the new values are callable, they are computed on the Dataset and assigned to new data variables. If the values are not callable, (e.g. a DataArray, scalar, or array), they are simply assigned. - **variables_kwargs: + **variables_kwargs The keyword arguments form of ``variables``. One of variables or variables_kwargs must be provided. @@ -4425,11 +4715,11 @@ def assign( Dimensions: (lat: 2, lon: 2) Coordinates: - * lat (lat) int64 10 20 - * lon (lon) int64 150 160 + * lat (lat) int64 10 20 + * lon (lon) int64 150 160 Data variables: - temperature_c (lat, lon) float64 18.04 12.51 17.64 9.313 - precipitation (lat, lon) float64 0.4751 0.6827 0.3697 0.03524 + temperature_c (lat, lon) float64 10.98 14.3 12.06 10.9 + precipitation (lat, lon) float64 0.4237 0.6459 0.4376 0.8918 Where the value is a callable, evaluated on dataset: @@ -4437,12 +4727,12 @@ def assign( Dimensions: (lat: 2, lon: 2) Coordinates: - * lat (lat) int64 10 20 - * lon (lon) int64 150 160 + * lat (lat) int64 10 20 + * lon (lon) int64 150 160 Data variables: - temperature_c (lat, lon) float64 18.04 12.51 17.64 9.313 - precipitation (lat, lon) float64 0.4751 0.6827 0.3697 0.03524 - temperature_f (lat, lon) float64 64.47 54.51 63.75 48.76 + temperature_c (lat, lon) float64 10.98 14.3 12.06 10.9 + precipitation (lat, lon) float64 0.4237 0.6459 0.4376 0.8918 + temperature_f (lat, lon) float64 51.76 57.75 53.7 51.62 Alternatively, the same behavior can be achieved by directly referencing an existing dataarray: @@ -4450,12 +4740,12 @@ def assign( Dimensions: (lat: 2, lon: 2) Coordinates: - * lat (lat) int64 10 20 - * lon (lon) int64 150 160 + * lat (lat) int64 10 20 + * lon (lon) int64 150 160 Data variables: - temperature_c (lat, lon) float64 18.04 12.51 17.64 9.313 - precipitation (lat, lon) float64 0.4751 0.6827 0.3697 0.03524 - temperature_f (lat, lon) float64 64.47 54.51 63.75 48.76 + temperature_c (lat, lon) float64 10.98 14.3 12.06 10.9 + precipitation (lat, lon) float64 0.4237 0.6459 0.4376 0.8918 + temperature_f (lat, lon) float64 51.76 57.75 53.7 51.62 """ variables = either_dict_or_kwargs(variables, variables_kwargs, "assign") @@ -4500,44 +4790,91 @@ def to_array(self, dim="variable", name=None): data, coords, dims, 
attrs=self.attrs, name=name, indexes=indexes ) - def _to_dataframe(self, ordered_dims): + def _normalize_dim_order( + self, dim_order: List[Hashable] = None + ) -> Dict[Hashable, int]: + """ + Check the validity of the provided dimensions if any and return the mapping + between dimension name and their size. + + Parameters + ---------- + dim_order + Dimension order to validate (default to the alphabetical order if None). + + Returns + ------- + result + Validated dimensions mapping. + + """ + if dim_order is None: + dim_order = list(self.dims) + elif set(dim_order) != set(self.dims): + raise ValueError( + "dim_order {} does not match the set of dimensions of this " + "Dataset: {}".format(dim_order, list(self.dims)) + ) + + ordered_dims = {k: self.dims[k] for k in dim_order} + + return ordered_dims + + def _to_dataframe(self, ordered_dims: Mapping[Hashable, int]): columns = [k for k in self.variables if k not in self.dims] data = [ self._variables[k].set_dims(ordered_dims).values.reshape(-1) for k in columns ] - index = self.coords.to_index(ordered_dims) + index = self.coords.to_index([*ordered_dims]) return pd.DataFrame(dict(zip(columns, data)), index=index) - def to_dataframe(self): + def to_dataframe(self, dim_order: List[Hashable] = None) -> pd.DataFrame: """Convert this dataset into a pandas.DataFrame. Non-index variables in this dataset form the columns of the - DataFrame. The DataFrame is be indexed by the Cartesian product of + DataFrame. The DataFrame is indexed by the Cartesian product of this dataset's indices. + + Parameters + ---------- + dim_order + Hierarchical dimension order for the resulting dataframe. All + arrays are transposed to this order and then written out as flat + vectors in contiguous order, so the last dimension in this list + will be contiguous in the resulting DataFrame. This has a major + influence on which operations are efficient on the resulting + dataframe. + + If provided, must include all dimensions of this dataset. By + default, dimensions are sorted alphabetically. + + Returns + ------- + result + Dataset as a pandas DataFrame. + """ - return self._to_dataframe(self.dims) + + ordered_dims = self._normalize_dim_order(dim_order=dim_order) + + return self._to_dataframe(ordered_dims=ordered_dims) def _set_sparse_data_from_dataframe( - self, dataframe: pd.DataFrame, dims: tuple + self, idx: pd.Index, arrays: List[Tuple[Hashable, np.ndarray]], dims: tuple ) -> None: from sparse import COO - idx = dataframe.index if isinstance(idx, pd.MultiIndex): coords = np.stack([np.asarray(code) for code in idx.codes], axis=0) - is_sorted = idx.is_lexsorted + is_sorted = idx.is_lexsorted() shape = tuple(lev.size for lev in idx.levels) else: coords = np.arange(idx.size).reshape(1, -1) is_sorted = True shape = (idx.size,) - for name, series in dataframe.items(): - # Cast to a NumPy array first, in case the Series is a pandas - # Extension array (which doesn't have a valid NumPy dtype) - values = np.asarray(series) - + for name, values in arrays: # In virtually all real use cases, the sparse array will now have # missing values and needs a fill_value. 
For consistency, don't # special case the rare exceptions (e.g., dtype=int without a @@ -4556,18 +4893,36 @@ def _set_sparse_data_from_dataframe( self[name] = (dims, data) def _set_numpy_data_from_dataframe( - self, dataframe: pd.DataFrame, dims: tuple + self, idx: pd.Index, arrays: List[Tuple[Hashable, np.ndarray]], dims: tuple ) -> None: - idx = dataframe.index - if isinstance(idx, pd.MultiIndex): - # expand the DataFrame to include the product of all levels - full_idx = pd.MultiIndex.from_product(idx.levels, names=idx.names) - dataframe = dataframe.reindex(full_idx) - shape = tuple(lev.size for lev in idx.levels) - else: - shape = (idx.size,) - for name, series in dataframe.items(): - data = np.asarray(series).reshape(shape) + if not isinstance(idx, pd.MultiIndex): + for name, values in arrays: + self[name] = (dims, values) + return + + shape = tuple(lev.size for lev in idx.levels) + indexer = tuple(idx.codes) + + # We already verified that the MultiIndex has all unique values, so + # there are missing values if and only if the size of output arrays is + # larger that the index. + missing_values = np.prod(shape) > idx.shape[0] + + for name, values in arrays: + # NumPy indexing is much faster than using DataFrame.reindex() to + # fill in missing values: + # https://stackoverflow.com/a/35049899/809705 + if missing_values: + dtype, fill_value = dtypes.maybe_promote(values.dtype) + data = np.full(shape, fill_value, dtype) + else: + # If there are no missing values, keep the existing dtype + # instead of promoting to support NA, e.g., keep integer + # columns as integers. + # TODO: consider removing this special case, which doesn't + # exist for sparse=True. + data = np.zeros(shape, values.dtype) + data[indexer] = values self[name] = (dims, data) @classmethod @@ -4584,9 +4939,9 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> "Datas Parameters ---------- - dataframe : pandas.DataFrame + dataframe : DataFrame DataFrame from which to copy data and indices. - sparse : bool + sparse : bool, default: False If true, create a sparse arrays instead of dense numpy arrays. This can potentially save a large amount of memory if the DataFrame has a MultiIndex. Requires the sparse package (sparse.pydata.org). @@ -4598,6 +4953,7 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> "Datas See also -------- xarray.DataArray.from_series + pandas.DataFrame.to_xarray """ # TODO: Add an option to remove dimensions along which the variables # are constant, to enable consistent serialization to/from a dataframe, @@ -4606,7 +4962,19 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> "Datas if not dataframe.columns.is_unique: raise ValueError("cannot convert DataFrame with non-unique columns") - idx, dataframe = remove_unused_levels_categories(dataframe.index, dataframe) + idx = remove_unused_levels_categories(dataframe.index) + + if isinstance(idx, pd.MultiIndex) and not idx.is_unique: + raise ValueError( + "cannot convert a DataFrame with a non-unique MultiIndex into xarray" + ) + + # Cast to a NumPy array first, in case the Series is a pandas Extension + # array (which doesn't have a valid NumPy dtype) + # TODO: allow users to control how this casting happens, e.g., by + # forwarding arguments to pandas.Series.to_numpy? 
+ arrays = [(k, np.asarray(v)) for k, v in dataframe.items()] + obj = cls() if isinstance(idx, pd.MultiIndex): @@ -4622,9 +4990,9 @@ def from_dataframe(cls, dataframe: pd.DataFrame, sparse: bool = False) -> "Datas obj[index_name] = (dims, idx) if sparse: - obj._set_sparse_data_from_dataframe(dataframe, dims) + obj._set_sparse_data_from_dataframe(idx, arrays, dims) else: - obj._set_numpy_data_from_dataframe(dataframe, dims) + obj._set_numpy_data_from_dataframe(idx, arrays, dims) return obj def to_dask_dataframe(self, dim_order=None, set_index=False): @@ -4644,11 +5012,11 @@ def to_dask_dataframe(self, dim_order=None, set_index=False): influence on which operations are efficient on the resulting dask dataframe. - If provided, must include all dimensions on this dataset. By + If provided, must include all dimensions of this dataset. By default, dimensions are sorted alphabetically. set_index : bool, optional If set_index=True, the dask DataFrame is indexed by this dataset's - coordinate. Since dask DataFrames to not support multi-indexes, + coordinate. Since dask DataFrames do not support multi-indexes, set_index only works if the dataset only contains one dimension. Returns @@ -4659,15 +5027,7 @@ def to_dask_dataframe(self, dim_order=None, set_index=False): import dask.array as da import dask.dataframe as dd - if dim_order is None: - dim_order = list(self.dims) - elif set(dim_order) != set(self.dims): - raise ValueError( - "dim_order {} does not match the set of dimensions on this " - "Dataset: {}".format(dim_order, list(self.dims)) - ) - - ordered_dims = {k: self.dims[k] for k in dim_order} + ordered_dims = self._normalize_dim_order(dim_order=dim_order) columns = list(ordered_dims) columns.extend(k for k in self.coords if k not in self.dims) @@ -4694,6 +5054,8 @@ def to_dask_dataframe(self, dim_order=None, set_index=False): df = dd.concat(series_list, axis=1) if set_index: + dim_order = [*ordered_dims] + if len(dim_order) == 1: (dim,) = dim_order df = df.set_index(dim) @@ -4740,27 +5102,35 @@ def from_dict(cls, d): """ Convert a dictionary into an xarray.Dataset. - Input dict can take several forms:: + Input dict can take several forms: - d = {'t': {'dims': ('t'), 'data': t}, - 'a': {'dims': ('t'), 'data': x}, - 'b': {'dims': ('t'), 'data': y}} + .. code:: python - d = {'coords': {'t': {'dims': 't', 'data': t, - 'attrs': {'units':'s'}}}, - 'attrs': {'title': 'air temperature'}, - 'dims': 't', - 'data_vars': {'a': {'dims': 't', 'data': x, }, - 'b': {'dims': 't', 'data': y}}} + d = { + "t": {"dims": ("t"), "data": t}, + "a": {"dims": ("t"), "data": x}, + "b": {"dims": ("t"), "data": y}, + } - where 't' is the name of the dimesion, 'a' and 'b' are names of data + d = { + "coords": {"t": {"dims": "t", "data": t, "attrs": {"units": "s"}}}, + "attrs": {"title": "air temperature"}, + "dims": "t", + "data_vars": { + "a": {"dims": "t", "data": x}, + "b": {"dims": "t", "data": y}, + }, + } + + where "t" is the name of the dimesion, "a" and "b" are names of data variables and t, x, and y are lists, numpy.arrays or pandas objects. 
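A hedged usage sketch of the second dict form shown above (the values are made up):

```python
import xarray as xr

d = {
    "coords": {"t": {"dims": "t", "data": [0, 1, 2], "attrs": {"units": "s"}}},
    "attrs": {"title": "air temperature"},
    "dims": "t",
    "data_vars": {
        "a": {"dims": "t", "data": [10.0, 20.0, 30.0]},
        "b": {"dims": "t", "data": [1.0, 2.0, 3.0]},
    },
}
ds = xr.Dataset.from_dict(d)
```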
Parameters ---------- - d : dict, with a minimum structure of {'var_0': {'dims': [..], \ - 'data': [..]}, \ - ...} + d : dict-like + Mapping with a minimum structure of + ``{"var_0": {"dims": [..], "data": [..]}, \ + ...}`` Returns ------- @@ -4800,15 +5170,20 @@ def from_dict(cls, d): return obj @staticmethod - def _unary_op(f, keep_attrs=False): + def _unary_op(f): @functools.wraps(f) def func(self, *args, **kwargs): variables = {} + keep_attrs = kwargs.pop("keep_attrs", None) + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) for k, v in self._variables.items(): if k in self._coord_names: variables[k] = v else: variables[k] = f(v, *args, **kwargs) + if keep_attrs: + variables[k].attrs = v._attrs attrs = self._attrs if keep_attrs else None return self._replace_with_new_dims(variables, attrs=attrs) @@ -4939,17 +5314,15 @@ def diff(self, dim, n=1, label="upper"): >>> ds.diff("x") Dimensions: (x: 3) - Coordinates: - * x (x) int64 1 2 3 + Dimensions without coordinates: x Data variables: foo (x) int64 0 1 0 >>> ds.diff("x", 2) Dimensions: (x: 2) - Coordinates: - * x (x) int64 2 3 + Dimensions without coordinates: x Data variables: - foo (x) int64 1 -1 + foo (x) int64 1 -1 See Also -------- @@ -4970,9 +5343,7 @@ def diff(self, dim, n=1, label="upper"): elif label == "lower": kwargs_new = kwargs_start else: - raise ValueError( - "The 'label' argument has to be either " "'upper' or 'lower'" - ) + raise ValueError("The 'label' argument has to be either 'upper' or 'lower'") variables = {} @@ -5004,13 +5375,14 @@ def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs): Parameters ---------- - shifts : Mapping with the form of {dim: offset} + shifts : mapping of hashable to int Integer offset to shift along each of the given dimensions. Positive offsets shift to the right; negative offsets shift to the left. - fill_value: scalar, optional - Value to use for newly missing values - **shifts_kwargs: + fill_value : scalar or dict-like, optional + Value to use for newly missing values. If a dict-like, maps + variable names (including coordinates) to fill values. + **shifts_kwargs The keyword arguments form of ``shifts``. One of shifts or shifts_kwargs must be provided. @@ -5031,8 +5403,7 @@ def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs): >>> ds.shift(x=2) Dimensions: (x: 5) - Coordinates: - * x (x) int64 0 1 2 3 4 + Dimensions without coordinates: x Data variables: foo (x) object nan nan 'a' 'b' 'c' """ @@ -5044,8 +5415,14 @@ def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs): variables = {} for name, var in self.variables.items(): if name in self.data_vars: + fill_value_ = ( + fill_value.get(name, dtypes.NA) + if isinstance(fill_value, dict) + else fill_value + ) + var_shifts = {k: v for k, v in shifts.items() if k in var.dims} - variables[name] = var.shift(fill_value=fill_value, shifts=var_shifts) + variables[name] = var.shift(fill_value=fill_value_, shifts=var_shifts) else: variables[name] = var @@ -5090,10 +5467,9 @@ def roll(self, shifts=None, roll_coords=None, **shifts_kwargs): >>> ds.roll(x=2) Dimensions: (x: 5) - Coordinates: - * x (x) int64 3 4 0 1 2 + Dimensions without coordinates: x Data variables: - foo (x) object 'd' 'e' 'a' 'b' 'c' + foo (x) "Dataset": - """ Integrate along the given coordinate using the trapezoidal rule. + """Integrate along the given coordinate using the trapezoidal rule. .. note:: This feature is limited to simple cartesian geometry, i.e. 
coord @@ -5449,14 +5825,14 @@ def integrate( ---------- coord: hashable, or a sequence of hashable Coordinate(s) used for the integration. - datetime_unit + datetime_unit: str, optional Can be specify the unit if datetime coordinate is used. One of {'Y', 'M', 'W', 'D', 'h', 'm', 's', 'ms', 'us', 'ns', 'ps', 'fs', 'as'} Returns ------- - integrated: Dataset + integrated : Dataset See also -------- @@ -5547,22 +5923,13 @@ def _integrate_one(self, coord, datetime_unit=None): @property def real(self): - return self._unary_op(lambda x: x.real, keep_attrs=True)(self) + return self.map(lambda x: x.real, keep_attrs=True) @property def imag(self): - return self._unary_op(lambda x: x.imag, keep_attrs=True)(self) + return self.map(lambda x: x.imag, keep_attrs=True) - @property - def plot(self): - """ - Access plotting functions for Datasets. - Use it as a namespace to use xarray.plot functions as Dataset methods - - >>> ds.plot.scatter(...) # equivalent to xarray.plot.scatter(ds,...) - - """ - return _Dataset_PlotMethods(self) + plot = utils.UncachedAccessor(_Dataset_PlotMethods) def filter_by_attrs(self, **kwargs): """Returns a ``Dataset`` with variables that match specific conditions. @@ -5577,7 +5944,7 @@ def filter_by_attrs(self, **kwargs): Parameters ---------- - **kwargs : key=value + **kwargs key : str Attribute name. value : callable or obj @@ -5594,9 +5961,6 @@ def filter_by_attrs(self, **kwargs): Examples -------- >>> # Create an example dataset: - >>> import numpy as np - >>> import pandas as pd - >>> import xarray as xr >>> temp = 15 + 8 * np.random.randn(2, 2, 3) >>> precip = 10 * np.random.rand(2, 2, 3) >>> lon = [[-99.83, -99.32], [-99.79, -99.23]] @@ -5621,14 +5985,13 @@ def filter_by_attrs(self, **kwargs): Dimensions: (time: 3, x: 2, y: 2) Coordinates: - * x (x) int64 0 1 - * time (time) datetime64[ns] 2014-09-06 2014-09-07 2014-09-08 + lon (x, y) float64 -99.83 -99.32 -99.79 -99.23 lat (x, y) float64 42.25 42.21 42.63 42.59 - * y (y) int64 0 1 + * time (time) datetime64[ns] 2014-09-06 2014-09-07 2014-09-08 reference_time datetime64[ns] 2014-09-05 - lon (x, y) float64 -99.83 -99.32 -99.79 -99.23 + Dimensions without coordinates: x, y Data variables: - precipitation (x, y, time) float64 4.178 2.307 6.041 6.046 0.06648 ... + precipitation (x, y, time) float64 5.68 9.256 0.7104 ... 7.992 4.615 7.805 >>> # Get all variables that have a standard_name attribute. >>> standard_name = lambda v: v is not None >>> ds.filter_by_attrs(standard_name=standard_name) @@ -5637,13 +6000,12 @@ def filter_by_attrs(self, **kwargs): Coordinates: lon (x, y) float64 -99.83 -99.32 -99.79 -99.23 lat (x, y) float64 42.25 42.21 42.63 42.59 - * x (x) int64 0 1 - * y (y) int64 0 1 * time (time) datetime64[ns] 2014-09-06 2014-09-07 2014-09-08 reference_time datetime64[ns] 2014-09-05 + Dimensions without coordinates: x, y Data variables: - temperature (x, y, time) float64 25.86 20.82 6.954 23.13 10.25 11.68 ... - precipitation (x, y, time) float64 5.702 0.9422 2.075 1.178 3.284 ... + temperature (x, y, time) float64 29.11 18.2 22.83 ... 18.28 16.15 26.63 + precipitation (x, y, time) float64 5.68 9.256 0.7104 ... 7.992 4.615 7.805 """ selection = [] @@ -5661,7 +6023,7 @@ def filter_by_attrs(self, **kwargs): return self[selection] def unify_chunks(self) -> "Dataset": - """ Unify chunk size along all chunked dimensions of this Dataset. + """Unify chunk size along all chunked dimensions of this Dataset. 
Returns ------- @@ -5711,57 +6073,109 @@ def map_blocks( func: "Callable[..., T_DSorDA]", args: Sequence[Any] = (), kwargs: Mapping[str, Any] = None, + template: Union["DataArray", "Dataset"] = None, ) -> "T_DSorDA": """ - Apply a function to each chunk of this Dataset. This method is experimental and - its signature may change. + Apply a function to each block of this Dataset. + + .. warning:: + This method is experimental and its signature may change. Parameters ---------- - func: callable - User-provided function that accepts a Dataset as its first parameter. The - function will receive a subset of this Dataset, corresponding to one chunk - along each chunked dimension. ``func`` will be executed as - ``func(obj_subset, *args, **kwargs)``. - - The function will be first run on mocked-up data, that looks like this - Dataset but has sizes 0, to determine properties of the returned object such - as dtype, variable names, new dimensions and new indexes (if any). + func : callable + User-provided function that accepts a Dataset as its first + parameter. The function will receive a subset or 'block' of this Dataset (see below), + corresponding to one chunk along each chunked dimension. ``func`` will be + executed as ``func(subset_dataset, *subset_args, **kwargs)``. This function must return either a single DataArray or a single Dataset. - This function cannot change size of existing dimensions, or add new chunked - dimensions. - args: Sequence - Passed verbatim to func after unpacking, after the sliced DataArray. xarray - objects, if any, will not be split by chunks. Passing dask collections is - not allowed. - kwargs: Mapping + This function cannot add a new chunked dimension. + args : sequence + Passed to func after unpacking and subsetting any xarray objects by blocks. + xarray objects in args must be aligned with obj, otherwise an error is raised. + kwargs : mapping Passed verbatim to func after unpacking. xarray objects, if any, will not be - split by chunks. Passing dask collections is not allowed. + subset to blocks. Passing dask collections in kwargs is not allowed. + template : DataArray or Dataset, optional + xarray object representing the final result after compute is called. If not provided, + the function will be first run on mocked-up data, that looks like this object but + has sizes 0, to determine properties of the returned object such as dtype, + variable names, attributes, new dimensions and new indexes (if any). + ``template`` must be provided if the function changes the size of existing dimensions. + When provided, ``attrs`` on variables in `template` are copied over to the result. Any + ``attrs`` set by ``func`` will be ignored. + Returns ------- - A single DataArray or Dataset with dask backend, reassembled from the outputs of - the function. + A single DataArray or Dataset with dask backend, reassembled from the outputs of the + function. Notes ----- - This method is designed for when one needs to manipulate a whole xarray object - within each chunk. In the more common case where one can work on numpy arrays, - it is recommended to use apply_ufunc. + This function is designed for when ``func`` needs to manipulate a whole xarray object + subset to each block. In the more common case where ``func`` can work on numpy arrays, it is + recommended to use ``apply_ufunc``. - If none of the variables in this Dataset is backed by dask, calling this method - is equivalent to calling ``func(self, *args, **kwargs)``. 
+ If none of the variables in this object is backed by dask arrays, calling this function is + equivalent to calling ``func(obj, *args, **kwargs)``. See Also -------- - dask.array.map_blocks, xarray.apply_ufunc, xarray.map_blocks, + dask.array.map_blocks, xarray.apply_ufunc, xarray.Dataset.map_blocks, xarray.DataArray.map_blocks + + Examples + -------- + + Calculate an anomaly from climatology using ``.groupby()``. Using + ``xr.map_blocks()`` allows for parallel operations with knowledge of ``xarray``, + its indices, and its methods like ``.groupby()``. + + >>> def calculate_anomaly(da, groupby_type="time.month"): + ... gb = da.groupby(groupby_type) + ... clim = gb.mean(dim="time") + ... return gb - clim + ... + >>> time = xr.cftime_range("1990-01", "1992-01", freq="M") + >>> month = xr.DataArray(time.month, coords={"time": time}, dims=["time"]) + >>> np.random.seed(123) + >>> array = xr.DataArray( + ... np.random.rand(len(time)), + ... dims=["time"], + ... coords={"time": time, "month": month}, + ... ).chunk() + >>> ds = xr.Dataset({"a": array}) + >>> ds.map_blocks(calculate_anomaly, template=ds).compute() + + Dimensions: (time: 24) + Coordinates: + * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 + month (time) int64 1 2 3 4 5 6 7 8 9 10 11 12 1 2 3 4 5 6 7 8 9 10 11 12 + Data variables: + a (time) float64 0.1289 0.1132 -0.0856 ... 0.2287 0.1906 -0.05901 + + Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments + to the function being applied in ``xr.map_blocks()``: + + >>> ds.map_blocks( + ... calculate_anomaly, + ... kwargs={"groupby_type": "time.year"}, + ... template=ds, + ... ) + + Dimensions: (time: 24) + Coordinates: + * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 + month (time) int64 dask.array + Data variables: + a (time) float64 dask.array """ from .parallel import map_blocks - return map_blocks(func, self, args, kwargs) + return map_blocks(func, self, args, kwargs, template) def polyfit( self, @@ -5791,13 +6205,13 @@ def polyfit( invalid values, False otherwise. rcond : float, optional Relative condition number to the fit. - w : Union[Hashable, Any], optional + w : hashable or Any, optional Weights to apply to the y-coordinate of the sample points. Can be an array-like object or the name of a coordinate in the dataset. full : bool, optional Whether to return the residuals, matrix rank and singular values in addition to the coefficients. - cov : Union[bool, str], optional + cov : bool or str, optional Whether to return to the covariance matrix in addition to the coefficients. The matrix is not scaled if `cov='unscaled'`. @@ -5811,13 +6225,21 @@ def polyfit( The coefficients of the best fit for each variable in this dataset. [var]_polyfit_residuals The residuals of the least-square computation for each variable (only included if `full=True`) + When the matrix rank is deficient, np.nan is returned. [dim]_matrix_rank The effective rank of the scaled Vandermonde coefficient matrix (only included if `full=True`) + The rank is computed ignoring the NaN values that might be skipped. [dim]_singular_values The singular values of the scaled Vandermonde coefficient matrix (only included if `full=True`) [var]_polyfit_covariance The covariance matrix of the polynomial coefficient estimates (only included if `full=False` and `cov=True`) + Warns + ----- + RankWarning + The rank of the coefficient matrix in the least-squares fit is deficient. + The warning is not raised with in-memory (not dask) data and `full=True`. 
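A minimal sketch of these ``polyfit`` outputs, using a hypothetical linear dataset:

```python
import numpy as np
import xarray as xr

x = np.arange(10.0)
ds = xr.Dataset({"y": ("x", 2.0 * x + 1.0)}, coords={"x": x})

# "y_polyfit_coefficients" holds the slope (~2.0) and intercept (~1.0) along a
# new "degree" dimension; full=True additionally returns residuals, rank and
# singular values as described above.
fit = ds.polyfit(dim="x", deg=1)
```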
+ See also -------- numpy.polyfit @@ -5825,7 +6247,7 @@ def polyfit( variables = {} skipna_da = skipna - x = get_clean_interp_index(self, dim) + x = get_clean_interp_index(self, dim, strict=False) xname = "{}_".format(self[dim].name) order = int(deg) + 1 lhs = np.vander(x, order) @@ -5851,10 +6273,6 @@ def polyfit( degree_dim = utils.get_temp_dimname(self.dims, "degree") rank = np.linalg.matrix_rank(lhs) - if rank != order and not full: - warnings.warn( - "Polyfit may be poorly conditioned", np.RankWarning, stacklevel=4 - ) if full: rank = xr.DataArray(rank, name=xname + "matrix_rank") @@ -5863,7 +6281,7 @@ def polyfit( sing = xr.DataArray( sing, dims=(degree_dim,), - coords={degree_dim: np.arange(order)[::-1]}, + coords={degree_dim: np.arange(rank - 1, -1, -1)}, name=xname + "singular_values", ) variables[sing.name] = sing @@ -5872,14 +6290,17 @@ def polyfit( if dim not in da.dims: continue - if skipna is None: - if isinstance(da.data, dask_array_type): - skipna_da = True - else: - skipna_da = np.any(da.isnull()) + if is_duck_dask_array(da.data) and ( + rank != order or full or skipna is None + ): + # Current algorithm with dask and skipna=False neither supports + # deficient ranks nor does it output the "full" info (issue dask/dask#6516) + skipna_da = True + elif skipna is None: + skipna_da = np.any(da.isnull()) dims_to_stack = [dimname for dimname in da.dims if dimname != dim] - stacked_coords = {} + stacked_coords: Dict[Hashable, DataArray] = {} if dims_to_stack: stacked_dim = utils.get_temp_dimname(dims_to_stack, "stacked") rhs = da.transpose(dim, *dims_to_stack).stack( @@ -5894,9 +6315,15 @@ def polyfit( if w is not None: rhs *= w[:, np.newaxis] - coeffs, residuals = duck_array_ops.least_squares( - lhs, rhs.data, rcond=rcond, skipna=skipna_da - ) + with warnings.catch_warnings(): + if full: # Copy np.polyfit behavior + warnings.simplefilter("ignore", np.RankWarning) + else: # Raise only once per variable + warnings.simplefilter("once", np.RankWarning) + + coeffs, residuals = duck_array_ops.least_squares( + lhs, rhs.data, rcond=rcond, skipna=skipna_da + ) if isinstance(name, str): name = "{}_".format(name) @@ -5936,7 +6363,7 @@ def polyfit( "The number of data points must exceed order to scale the covariance matrix." ) fac = residuals / (x.shape[0] - order) - covariance = xr.DataArray(Vbase, dims=("cov_i", "cov_j"),) * fac + covariance = xr.DataArray(Vbase, dims=("cov_i", "cov_j")) * fac variables[name + "polyfit_covariance"] = covariance return Dataset(data_vars=variables, attrs=self.attrs.copy()) @@ -5969,10 +6396,11 @@ def pad( Parameters ---------- - pad_width : Mapping with the form of {dim: (pad_before, pad_after)} - Number of values padded along each dimension. + pad_width : mapping of hashable to tuple of int + Mapping with the form of {dim: (pad_before, pad_after)} + describing the number of values padded along each dimension. {dim: pad} is a shortcut for pad_before = pad_after = pad - mode : str + mode : str, default: "constant" One of the following string values (taken from numpy docs). 'constant' (default) @@ -6005,7 +6433,7 @@ def pad( Pads with the wrap of the vector along the axis. The first values are used to pad the end and the end values are used to pad the beginning. - stat_length : int, tuple or mapping of the form {dim: tuple} + stat_length : int, tuple or mapping of hashable to tuple, default: None Used in 'maximum', 'mean', 'median', and 'minimum'. Number of values at edge of each axis used to calculate the statistic value. {dim_1: (before_1, after_1), ... 
dim_N: (before_N, after_N)} unique @@ -6015,7 +6443,7 @@ def pad( (stat_length,) or int is a shortcut for before = after = statistic length for all axes. Default is ``None``, to use the entire axis. - constant_values : scalar, tuple or mapping of the form {dim: tuple} + constant_values : scalar, tuple or mapping of hashable to tuple, default: 0 Used in 'constant'. The values to set the padded values for each axis. ``{dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)}`` unique @@ -6025,7 +6453,7 @@ def pad( ``(constant,)`` or ``constant`` is a shortcut for ``before = after = constant`` for all dimensions. Default is 0. - end_values : scalar, tuple or mapping of the form {dim: tuple} + end_values : scalar, tuple or mapping of hashable to tuple, default: 0 Used in 'linear_ramp'. The values used for the ending value of the linear_ramp and that will form the edge of the padded array. ``{dim_1: (before_1, after_1), ... dim_N: (before_N, after_N)}`` unique @@ -6035,12 +6463,12 @@ def pad( ``(constant,)`` or ``constant`` is a shortcut for ``before = after = constant`` for all axes. Default is 0. - reflect_type : {'even', 'odd'}, optional - Used in 'reflect', and 'symmetric'. The 'even' style is the + reflect_type : {"even", "odd"}, optional + Used in "reflect", and "symmetric". The "even" style is the default with an unaltered reflection around the edge value. For - the 'odd' style, the extended part of the array is created by + the "odd" style, the extended part of the array is created by subtracting the reflected values from two times the edge value. - **pad_width_kwargs: + **pad_width_kwargs The keyword arguments form of ``pad_width``. One of ``pad_width`` or ``pad_width_kwargs`` must be provided. @@ -6062,8 +6490,8 @@ def pad( Examples -------- - >>> ds = xr.Dataset({'foo': ('x', range(5))}) - >>> ds.pad(x=(1,2)) + >>> ds = xr.Dataset({"foo": ("x", range(5))}) + >>> ds.pad(x=(1, 2)) Dimensions: (x: 8) Dimensions without coordinates: x @@ -6128,18 +6556,18 @@ def idxmin( dim : str, optional Dimension over which to apply `idxmin`. This is optional for 1D variables, but required for variables with 2 or more dimensions. - skipna : bool or None, default None + skipna : bool or None, default: None If True, skip missing values (as marked by NaN). By default, only skips missing values for ``float``, ``complex``, and ``object`` dtypes; other dtypes either do not have a sentinel missing value (``int``) or ``skipna=True`` has not been implemented (``datetime64`` or ``timedelta64``). - fill_value : Any, default NaN + fill_value : Any, default: NaN Value to be filled in case all of the values along a dimension are null. By default this is NaN. The fill value and result are automatically converted to a compatible dtype if possible. Ignored if ``skipna`` is False. - keep_attrs : bool, default False + keep_attrs : bool, default: False If True, the attributes (``attrs``) will be copied from the original object to the new one. If False (default), the new object will be returned without attributes. @@ -6157,17 +6585,20 @@ def idxmin( Examples -------- - >>> array1 = xr.DataArray([0, 2, 1, 0, -2], dims="x", - ... coords={"x": ['a', 'b', 'c', 'd', 'e']}) - >>> array2 = xr.DataArray([[2.0, 1.0, 2.0, 0.0, -2.0], - ... [-4.0, np.NaN, 2.0, np.NaN, -2.0], - ... [np.NaN, np.NaN, 1., np.NaN, np.NaN]], - ... dims=["y", "x"], - ... coords={"y": [-1, 0, 1], - ... "x": ['a', 'b', 'c', 'd', 'e']} - ... ) - >>> ds = xr.Dataset({'int': array1, 'float': array2}) - >>> ds.min(dim='x') + >>> array1 = xr.DataArray( + ... 
[0, 2, 1, 0, -2], dims="x", coords={"x": ["a", "b", "c", "d", "e"]} + ... ) + >>> array2 = xr.DataArray( + ... [ + ... [2.0, 1.0, 2.0, 0.0, -2.0], + ... [-4.0, np.NaN, 2.0, np.NaN, -2.0], + ... [np.NaN, np.NaN, 1.0, np.NaN, np.NaN], + ... ], + ... dims=["y", "x"], + ... coords={"y": [-1, 0, 1], "x": ["a", "b", "c", "d", "e"]}, + ... ) + >>> ds = xr.Dataset({"int": array1, "float": array2}) + >>> ds.min(dim="x") Dimensions: (y: 3) Coordinates: @@ -6175,7 +6606,7 @@ def idxmin( Data variables: int int64 -2 float (y) float64 -2.0 -4.0 1.0 - >>> ds.argmin(dim='x') + >>> ds.argmin(dim="x") Dimensions: (y: 3) Coordinates: @@ -6183,14 +6614,14 @@ def idxmin( Data variables: int int64 4 float (y) int64 4 0 2 - >>> ds.idxmin(dim='x') + >>> ds.idxmin(dim="x") Dimensions: (y: 3) Coordinates: * y (y) int64 -1 0 1 Data variables: int >> array1 = xr.DataArray([0, 2, 1, 0, -2], dims="x", - ... coords={"x": ['a', 'b', 'c', 'd', 'e']}) - >>> array2 = xr.DataArray([[2.0, 1.0, 2.0, 0.0, -2.0], - ... [-4.0, np.NaN, 2.0, np.NaN, -2.0], - ... [np.NaN, np.NaN, 1., np.NaN, np.NaN]], - ... dims=["y", "x"], - ... coords={"y": [-1, 0, 1], - ... "x": ['a', 'b', 'c', 'd', 'e']} - ... ) - >>> ds = xr.Dataset({'int': array1, 'float': array2}) - >>> ds.max(dim='x') + >>> array1 = xr.DataArray( + ... [0, 2, 1, 0, -2], dims="x", coords={"x": ["a", "b", "c", "d", "e"]} + ... ) + >>> array2 = xr.DataArray( + ... [ + ... [2.0, 1.0, 2.0, 0.0, -2.0], + ... [-4.0, np.NaN, 2.0, np.NaN, -2.0], + ... [np.NaN, np.NaN, 1.0, np.NaN, np.NaN], + ... ], + ... dims=["y", "x"], + ... coords={"y": [-1, 0, 1], "x": ["a", "b", "c", "d", "e"]}, + ... ) + >>> ds = xr.Dataset({"int": array1, "float": array2}) + >>> ds.max(dim="x") Dimensions: (y: 3) Coordinates: @@ -6270,7 +6704,7 @@ def idxmax( Data variables: int int64 2 float (y) float64 2.0 2.0 1.0 - >>> ds.argmax(dim='x') + >>> ds.argmax(dim="x") Dimensions: (y: 3) Coordinates: @@ -6278,7 +6712,7 @@ def idxmax( Data variables: int int64 1 float (y) int64 0 2 2 - >>> ds.idxmax(dim='x') + >>> ds.idxmax(dim="x") Dimensions: (y: 3) Coordinates: @@ -6294,8 +6728,134 @@ def idxmax( skipna=skipna, fill_value=fill_value, keep_attrs=keep_attrs, - ), + ) ) + def argmin(self, dim=None, axis=None, **kwargs): + """Indices of the minima of the member variables. + + If there are multiple minima, the indices of the first one found will be + returned. + + Parameters + ---------- + dim : str, optional + The dimension over which to find the minimum. By default, finds minimum over + all dimensions - for now returning an int for backward compatibility, but + this is deprecated, in future will be an error, since DataArray.argmin will + return a dict with indices for all dimensions, which does not make sense for + a Dataset. + axis : int, optional + Axis over which to apply `argmin`. Only one of the 'dim' and 'axis' arguments + can be supplied. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). 
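When ``dim`` names a single dimension, ``argmin`` reduces to integer positions; a minimal sketch with a hypothetical dataset:

```python
import xarray as xr

ds = xr.Dataset({"a": ("x", [3, 1, 2]), "b": ("x", [0, 5, -1])})

# Integer position of each variable's minimum along "x" (here: a -> 1, b -> 2)
ds.argmin(dim="x")
```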
+ + Returns + ------- + result : Dataset + + See also + -------- + DataArray.argmin + + """ + if dim is None and axis is None: + warnings.warn( + "Once the behaviour of DataArray.argmin() and Variable.argmin() with " + "neither dim nor axis argument changes to return a dict of indices of " + "each dimension, for consistency it will be an error to call " + "Dataset.argmin() with no argument, since we don't return a dict of " + "Datasets.", + DeprecationWarning, + stacklevel=2, + ) + if ( + dim is None + or axis is not None + or (not isinstance(dim, Sequence) and dim is not ...) + or isinstance(dim, str) + ): + # Return int index if single dimension is passed, and is not part of a + # sequence + argmin_func = getattr(duck_array_ops, "argmin") + return self.reduce(argmin_func, dim=dim, axis=axis, **kwargs) + else: + raise ValueError( + "When dim is a sequence or ..., DataArray.argmin() returns a dict. " + "dicts cannot be contained in a Dataset, so cannot call " + "Dataset.argmin() with a sequence or ... for dim" + ) + + def argmax(self, dim=None, axis=None, **kwargs): + """Indices of the maxima of the member variables. + + If there are multiple maxima, the indices of the first one found will be + returned. + + Parameters + ---------- + dim : str, optional + The dimension over which to find the maximum. By default, finds maximum over + all dimensions - for now returning an int for backward compatibility, but + this is deprecated, in future will be an error, since DataArray.argmax will + return a dict with indices for all dimensions, which does not make sense for + a Dataset. + axis : int, optional + Axis over which to apply `argmax`. Only one of the 'dim' and 'axis' arguments + can be supplied. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + + Returns + ------- + result : Dataset + + See also + -------- + DataArray.argmax + + """ + if dim is None and axis is None: + warnings.warn( + "Once the behaviour of DataArray.argmax() and Variable.argmax() with " + "neither dim nor axis argument changes to return a dict of indices of " + "each dimension, for consistency it will be an error to call " + "Dataset.argmax() with no argument, since we don't return a dict of " + "Datasets.", + DeprecationWarning, + stacklevel=2, + ) + if ( + dim is None + or axis is not None + or (not isinstance(dim, Sequence) and dim is not ...) + or isinstance(dim, str) + ): + # Return int index if single dimension is passed, and is not part of a + # sequence + argmax_func = getattr(duck_array_ops, "argmax") + return self.reduce(argmax_func, dim=dim, axis=axis, **kwargs) + else: + raise ValueError( + "When dim is a sequence or ..., DataArray.argmin() returns a dict. " + "dicts cannot be contained in a Dataset, so cannot call " + "Dataset.argmin() with a sequence or ... 
for dim" + ) + ops.inject_all_ops_and_reduce_methods(Dataset, array_only=False) diff --git a/xarray/core/dtypes.py b/xarray/core/dtypes.py index 4db2990accc..167f00fa932 100644 --- a/xarray/core/dtypes.py +++ b/xarray/core/dtypes.py @@ -137,8 +137,7 @@ def get_neg_infinity(dtype): def is_datetime_like(dtype): - """Check if a dtype is a subclass of the numpy datetime types - """ + """Check if a dtype is a subclass of the numpy datetime types""" return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 1340b456cf2..e6c3aae5bf8 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -4,8 +4,10 @@ accept or return xarray objects. """ import contextlib +import datetime import inspect import warnings +from distutils.version import LooseVersion from functools import partial import numpy as np @@ -13,10 +15,17 @@ from . import dask_array_compat, dask_array_ops, dtypes, npcompat, nputils from .nputils import nanfirst, nanlast -from .pycompat import dask_array_type +from .pycompat import ( + cupy_array_type, + dask_array_type, + is_duck_dask_array, + sparse_array_type, +) +from .utils import is_duck_array try: import dask.array as dask_array + from dask.base import tokenize except ImportError: dask_array = None # type: ignore @@ -37,7 +46,7 @@ def f(*args, **kwargs): dispatch_args = args[0] else: dispatch_args = args[array_args] - if any(isinstance(a, dask_array_type) for a in dispatch_args): + if any(is_duck_dask_array(a) for a in dispatch_args): try: wrapped = getattr(dask_module, name) except AttributeError as e: @@ -55,7 +64,7 @@ def f(*args, **kwargs): def fail_on_dask_array_input(values, msg=None, func_name=None): - if isinstance(values, dask_array_type): + if is_duck_dask_array(values): if msg is None: msg = "%r is not yet a valid method on dask arrays" if func_name is None: @@ -127,7 +136,7 @@ def notnull(data): def gradient(x, coord, axis, edge_order): - if isinstance(x, dask_array_type): + if is_duck_dask_array(x): return dask_array.gradient(x, coord, axis=axis, edge_order=edge_order) return np.gradient(x, coord, axis=axis, edge_order=edge_order) @@ -149,17 +158,41 @@ def trapz(y, x, axis): ) -def asarray(data): - return ( - data - if (isinstance(data, dask_array_type) or hasattr(data, "__array_function__")) - else np.asarray(data) - ) +def astype(data, dtype, **kwargs): + try: + import sparse + except ImportError: + sparse = None + + if ( + sparse is not None + and isinstance(data, sparse_array_type) + and LooseVersion(sparse.__version__) < LooseVersion("0.11.0") + and "casting" in kwargs + ): + warnings.warn( + "The current version of sparse does not support the 'casting' argument. It will be ignored in the call to astype().", + RuntimeWarning, + stacklevel=4, + ) + kwargs.pop("casting") + + return data.astype(dtype, **kwargs) + + +def asarray(data, xp=np): + return data if is_duck_array(data) else xp.asarray(data) def as_shared_dtype(scalars_or_arrays): """Cast a arrays to a shared dtype using xarray's type promotion rules.""" - arrays = [asarray(x) for x in scalars_or_arrays] + + if any([isinstance(x, cupy_array_type) for x in scalars_or_arrays]): + import cupy as cp + + arrays = [asarray(x, xp=cp) for x in scalars_or_arrays] + else: + arrays = [asarray(x) for x in scalars_or_arrays] # Pass arrays directly instead of dtypes to result_type so scalars # get handled properly. 
# Note that result_type() safely gets the dtype from dask arrays without @@ -170,10 +203,10 @@ def as_shared_dtype(scalars_or_arrays): def lazy_array_equiv(arr1, arr2): """Like array_equal, but doesn't actually compare values. - Returns True when arr1, arr2 identical or their dask names are equal. - Returns False when shapes are not equal. - Returns None when equality cannot determined: one or both of arr1, arr2 are numpy arrays; - or their dask names are not equal + Returns True when arr1, arr2 identical or their dask tokens are equal. + Returns False when shapes are not equal. + Returns None when equality cannot determined: one or both of arr1, arr2 are numpy arrays; + or their dask tokens are not equal """ if arr1 is arr2: return True @@ -181,13 +214,9 @@ def lazy_array_equiv(arr1, arr2): arr2 = asarray(arr2) if arr1.shape != arr2.shape: return False - if ( - dask_array - and isinstance(arr1, dask_array_type) - and isinstance(arr2, dask_array_type) - ): - # GH3068 - if arr1.name == arr2.name: + if dask_array and is_duck_dask_array(arr1) and is_duck_dask_array(arr2): + # GH3068, GH4221 + if tokenize(arr1) == tokenize(arr2): return True else: return None @@ -195,20 +224,21 @@ def lazy_array_equiv(arr1, arr2): def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8): - """Like np.allclose, but also allows values to be NaN in both arrays - """ + """Like np.allclose, but also allows values to be NaN in both arrays""" arr1 = asarray(arr1) arr2 = asarray(arr2) + lazy_equiv = lazy_array_equiv(arr1, arr2) if lazy_equiv is None: - return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all()) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered") + return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all()) else: return lazy_equiv def array_equiv(arr1, arr2): - """Like np.array_equal, but also allows values to be NaN in both arrays - """ + """Like np.array_equal, but also allows values to be NaN in both arrays""" arr1 = asarray(arr1) arr2 = asarray(arr2) lazy_equiv = lazy_array_equiv(arr1, arr2) @@ -238,8 +268,7 @@ def array_notnull_equiv(arr1, arr2): def count(data, axis=None): - """Count the number of non-NA in this array along the given axis or axes - """ + """Count the number of non-NA in this array along the given axis or axes""" return np.sum(np.logical_not(isnull(data)), axis=axis) @@ -298,12 +327,16 @@ def f(values, axis=None, skipna=None, **kwargs): nanname = "nan" + name func = getattr(nanops, nanname) else: + if name in ["sum", "prod"]: + kwargs.pop("min_count", None) func = _dask_or_eager_func(name, dask_module=dask_module) try: - return func(values, axis=axis, **kwargs) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", "All-NaN slice encountered") + return func(values, axis=axis, **kwargs) except AttributeError: - if not isinstance(values, dask_array_type): + if not is_duck_dask_array(values): raise try: # dask/dask#3133 dask sometimes needs dtype argument # if func does not accept dtype, then raises TypeError @@ -334,11 +367,12 @@ def f(values, axis=None, skipna=None, **kwargs): median.numeric_only = True prod = _create_nan_agg_method("prod") prod.numeric_only = True -sum.available_min_count = True +prod.available_min_count = True cumprod_1d = _create_nan_agg_method("cumprod") cumprod_1d.numeric_only = True cumsum_1d = _create_nan_agg_method("cumsum") cumsum_1d.numeric_only = True +unravel_index = _dask_or_eager_func("unravel_index") _mean = _create_nan_agg_method("mean") @@ -462,8 
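The switch above from comparing dask `name` attributes to `dask.base.tokenize` (GH3068, GH4221) means two graphs that hash to the same token compare lazily equal without computing anything. A rough illustration, assuming dask is installed (data is illustrative):

```python
import numpy as np
import dask.array as da
from dask.base import tokenize

x = np.arange(9).reshape(3, 3)
a = da.from_array(x, chunks=2)
b = da.from_array(x, chunks=2)

# from_array derives the graph key from the data, so identical inputs
# should produce identical tokens and be treated as lazily equal ...
tokenize(a) == tokenize(b)

# ... while a derived graph gets a new token, so the lazy check is
# inconclusive and equality falls back to an elementwise comparison.
tokenize(a + 0) == tokenize(a)
```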
+496,7 @@ def timedelta_to_numeric(value, datetime_unit="ns", dtype=float): def _to_pytimedelta(array, unit="us"): - index = pd.TimedeltaIndex(array.ravel(), unit=unit) - return index.to_pytimedelta().reshape(array.shape) + return array.astype(f"timedelta64[{unit}]").astype(datetime.timedelta) def np_timedelta64_to_float(array, datetime_unit): @@ -492,8 +525,7 @@ def pd_timedelta_to_float(value, datetime_unit): def py_timedelta_to_float(array, datetime_unit): - """Convert a timedelta object to a float, possibly at a loss of resolution. - """ + """Convert a timedelta object to a float, possibly at a loss of resolution.""" array = np.asarray(array) array = np.reshape([a.total_seconds() for a in array.ravel()], array.shape) * 1e6 conversion_factor = np.timedelta64(1, "us") / np.timedelta64(1, datetime_unit) @@ -518,7 +550,7 @@ def mean(array, axis=None, skipna=None, **kwargs): + offset ) elif _contains_cftime_datetimes(array): - if isinstance(array, dask_array_type): + if is_duck_dask_array(array): raise NotImplementedError( "Computing the mean of an array containing " "cftime.datetime objects is not yet implemented on " @@ -565,8 +597,7 @@ def cumsum(array, axis=None, **kwargs): def first(values, axis, skipna=None): - """Return the first non-NA elements in this array along the given axis - """ + """Return the first non-NA elements in this array along the given axis""" if (skipna or skipna is None) and values.dtype.kind not in "iSU": # only bother for dtypes that can hold NaN _fail_on_dask_array_input_skipna(values) @@ -575,8 +606,7 @@ def first(values, axis, skipna=None): def last(values, axis, skipna=None): - """Return the last non-NA elements in this array along the given axis - """ + """Return the last non-NA elements in this array along the given axis""" if (skipna or skipna is None) and values.dtype.kind not in "iSU": # only bother for dtypes that can hold NaN _fail_on_dask_array_input_skipna(values) @@ -589,16 +619,15 @@ def rolling_window(array, axis, window, center, fill_value): Make an ndarray with a rolling window of axis-th dimension. The rolling dimension will be placed at the last dimension. """ - if isinstance(array, dask_array_type): + if is_duck_dask_array(array): return dask_array_ops.rolling_window(array, axis, window, center, fill_value) else: # np.ndarray return nputils.rolling_window(array, axis, window, center, fill_value) def least_squares(lhs, rhs, rcond=None, skipna=False): - """Return the coefficients and residuals of a least-squares fit. 
- """ - if isinstance(rhs, dask_array_type): + """Return the coefficients and residuals of a least-squares fit.""" + if is_duck_dask_array(rhs): return dask_array_ops.least_squares(lhs, rhs, rcond=rcond, skipna=skipna) else: return nputils.least_squares(lhs, rhs, rcond=rcond, skipna=skipna) diff --git a/xarray/core/extensions.py b/xarray/core/extensions.py index e81070d18fd..ee4c3ebc9e6 100644 --- a/xarray/core/extensions.py +++ b/xarray/core/extensions.py @@ -88,35 +88,32 @@ def register_dataset_accessor(name): Examples -------- - In your library code:: - - import xarray as xr - - @xr.register_dataset_accessor('geo') - class GeoAccessor: - def __init__(self, xarray_obj): - self._obj = xarray_obj - - @property - def center(self): - # return the geographic center point of this dataset - lon = self._obj.latitude - lat = self._obj.longitude - return (float(lon.mean()), float(lat.mean())) - - def plot(self): - # plot this array's data on a map, e.g., using Cartopy - pass + In your library code: + + >>> @xr.register_dataset_accessor("geo") + ... class GeoAccessor: + ... def __init__(self, xarray_obj): + ... self._obj = xarray_obj + ... + ... @property + ... def center(self): + ... # return the geographic center point of this dataset + ... lon = self._obj.latitude + ... lat = self._obj.longitude + ... return (float(lon.mean()), float(lat.mean())) + ... + ... def plot(self): + ... # plot this array's data on a map, e.g., using Cartopy + ... pass Back in an interactive IPython session: - >>> ds = xarray.Dataset( - ... {"longitude": np.linspace(0, 10), "latitude": np.linspace(0, 20)} - ... ) - >>> ds.geo.center - (5.0, 10.0) - >>> ds.geo.plot() - # plots data on a map + >>> ds = xr.Dataset( + ... {"longitude": np.linspace(0, 10), "latitude": np.linspace(0, 20)} + ... 
) + >>> ds.geo.center + (10.0, 5.0) + >>> ds.geo.plot() # plots data on a map See also -------- diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index 534d253ecc8..282620e3569 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -3,7 +3,7 @@ import contextlib import functools from datetime import datetime, timedelta -from itertools import zip_longest +from itertools import chain, zip_longest from typing import Hashable import numpy as np @@ -13,6 +13,7 @@ from .duck_array_ops import array_equiv from .options import OPTIONS from .pycompat import dask_array_type, sparse_array_type +from .utils import is_duck_array def pretty_print(x, numchars: int): @@ -140,7 +141,7 @@ def format_item(x, timedelta_format=None, quote_strings=True): return format_timedelta(x, timedelta_format=timedelta_format) elif isinstance(x, (str, bytes)): return repr(x) if quote_strings else x - elif isinstance(x, (float, np.float)): + elif np.issubdtype(type(x), np.floating): return f"{x:.4}" else: return str(x) @@ -261,6 +262,8 @@ def inline_variable_array_repr(var, max_width): return inline_dask_repr(var.data) elif isinstance(var._data, sparse_array_type): return inline_sparse_repr(var.data) + elif hasattr(var._data, "_repr_inline_"): + return var._data._repr_inline_(max_width) elif hasattr(var._data, "__array_function__"): return maybe_truncate(repr(var._data).replace("\n", " "), max_width) else: @@ -298,12 +301,10 @@ def _summarize_coord_multiindex(coord, col_width, marker): def _summarize_coord_levels(coord, col_width, marker="-"): return "\n".join( - [ - summarize_variable( - lname, coord.get_level_variable(lname), col_width, marker=marker - ) - for lname in coord.level_names - ] + summarize_variable( + lname, coord.get_level_variable(lname), col_width, marker=marker + ) + for lname in coord.level_names ) @@ -364,12 +365,25 @@ def _calculate_col_width(col_items): return col_width -def _mapping_repr(mapping, title, summarizer, col_width=None): +def _mapping_repr(mapping, title, summarizer, col_width=None, max_rows=None): if col_width is None: col_width = _calculate_col_width(mapping) + if max_rows is None: + max_rows = OPTIONS["display_max_rows"] summary = [f"{title}:"] if mapping: - summary += [summarizer(k, v, col_width) for k, v in mapping.items()] + len_mapping = len(mapping) + if len_mapping > max_rows: + summary = [f"{summary[0]} ({max_rows}/{len_mapping})"] + first_rows = max_rows // 2 + max_rows % 2 + items = list(mapping.items()) + summary += [summarizer(k, v, col_width) for k, v in items[:first_rows]] + if max_rows > 1: + last_rows = max_rows // 2 + summary += [pretty_print(" ...", col_width) + " ..."] + summary += [summarizer(k, v, col_width) for k, v in items[-last_rows:]] + else: + summary += [summarizer(k, v, col_width) for k, v in mapping.items()] else: summary += [EMPTY_REPR] return "\n".join(summary) @@ -424,6 +438,17 @@ def set_numpy_options(*args, **kwargs): np.set_printoptions(**original) +def limit_lines(string: str, *, limit: int): + """ + If the string is more lines than the limit, + this returns the middle lines replaced by an ellipsis + """ + lines = string.splitlines() + if len(lines) > limit: + string = "\n".join(chain(lines[: limit // 2], ["..."], lines[-limit // 2 :])) + return string + + def short_numpy_repr(array): array = np.asarray(array) @@ -446,10 +471,8 @@ def short_data_repr(array): internal_data = getattr(array, "variable", array)._data if isinstance(array, np.ndarray): return short_numpy_repr(array) - elif hasattr(internal_data, 
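`_mapping_repr` above now truncates long variable listings to `OPTIONS["display_max_rows"]` entries, keeping the first and last rows around an ellipsis. A small sketch of the effect, assuming the option is also registered with `xr.set_options` in this changeset (illustrative data):

```python
import numpy as np
import xarray as xr

# A dataset with many variables produces a long "Data variables" listing.
ds = xr.Dataset({f"var{i}": ("x", np.arange(3)) for i in range(50)})

# With a small limit, the repr shows a few leading and trailing variables
# and replaces the middle rows with "...".
with xr.set_options(display_max_rows=6):
    print(ds)
```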
"__array_function__") or isinstance( - internal_data, dask_array_type - ): - return repr(array.data) + elif is_duck_array(internal_data): + return limit_lines(repr(array.data), limit=40) elif array._in_memory or array.size < 1e5: return short_numpy_repr(array) else: @@ -516,13 +539,6 @@ def diff_dim_summary(a, b): def _diff_mapping_repr(a_mapping, b_mapping, compat, title, summarizer, col_width=None): - def is_array_like(value): - return ( - hasattr(value, "ndim") - and hasattr(value, "shape") - and hasattr(value, "dtype") - ) - def extra_items_repr(extra_keys, mapping, ab_side): extra_repr = [summarizer(k, mapping[k], col_width) for k in extra_keys] if extra_repr: @@ -541,11 +557,14 @@ def extra_items_repr(extra_keys, mapping, ab_side): for k in a_keys & b_keys: try: # compare xarray variable - compatible = getattr(a_mapping[k], compat)(b_mapping[k]) + if not callable(compat): + compatible = getattr(a_mapping[k], compat)(b_mapping[k]) + else: + compatible = compat(a_mapping[k], b_mapping[k]) is_variable = True except AttributeError: # compare attribute value - if is_array_like(a_mapping[k]) or is_array_like(b_mapping[k]): + if is_duck_array(a_mapping[k]) or is_duck_array(b_mapping[k]): compatible = array_equiv(a_mapping[k], b_mapping[k]) else: compatible = a_mapping[k] == b_mapping[k] @@ -562,7 +581,7 @@ def extra_items_repr(extra_keys, mapping, ab_side): for m in (a_mapping, b_mapping): attr_s = "\n".join( - [summarize_attr(ak, av) for ak, av in m[k].attrs.items()] + summarize_attr(ak, av) for ak, av in m[k].attrs.items() ) attrs_summary.append(attr_s) @@ -574,7 +593,7 @@ def extra_items_repr(extra_keys, mapping, ab_side): diff_items += [ab_side + s[1:] for ab_side, s in zip(("L", "R"), temp)] if diff_items: - summary += ["Differing {}:".format(title.lower())] + diff_items + summary += [f"Differing {title.lower()}:"] + diff_items summary += extra_items_repr(a_keys - b_keys, a_mapping, "left") summary += extra_items_repr(b_keys - a_keys, b_mapping, "right") @@ -598,8 +617,13 @@ def extra_items_repr(extra_keys, mapping, ab_side): def _compat_to_str(compat): + if callable(compat): + compat = compat.__name__ + if compat == "equals": return "equal" + elif compat == "allclose": + return "close" else: return compat @@ -613,8 +637,12 @@ def diff_array_repr(a, b, compat): ] summary.append(diff_dim_summary(a, b)) + if callable(compat): + equiv = compat + else: + equiv = array_equiv - if not array_equiv(a.data, b.data): + if not equiv(a.data, b.data): temp = [wrap_indent(short_numpy_repr(obj), start=" ") for obj in (a, b)] diff_data_repr = [ ab_side + "\n" + ab_data_repr diff --git a/xarray/core/formatting_html.py b/xarray/core/formatting_html.py index 8678a58b381..3392aef8da3 100644 --- a/xarray/core/formatting_html.py +++ b/xarray/core/formatting_html.py @@ -1,18 +1,22 @@ import uuid from collections import OrderedDict -from functools import partial +from functools import lru_cache, partial from html import escape import pkg_resources from .formatting import inline_variable_array_repr, short_data_repr -CSS_FILE_PATH = "/".join(("static", "css", "style.css")) -CSS_STYLE = pkg_resources.resource_string("xarray", CSS_FILE_PATH).decode("utf8") +STATIC_FILES = ("static/html/icons-svg-inline.html", "static/css/style.css") -ICONS_SVG_PATH = "/".join(("static", "html", "icons-svg-inline.html")) -ICONS_SVG = pkg_resources.resource_string("xarray", ICONS_SVG_PATH).decode("utf8") +@lru_cache(None) +def _load_static_files(): + """Lazily load the resource files into memory the first time they are needed""" + 
return [ + pkg_resources.resource_string("xarray", fname).decode("utf8") + for fname in STATIC_FILES + ] def short_data_repr_html(array): @@ -20,7 +24,9 @@ def short_data_repr_html(array): internal_data = getattr(array, "variable", array)._data if hasattr(internal_data, "_repr_html_"): return internal_data._repr_html_() - return escape(short_data_repr(array)) + else: + text = escape(short_data_repr(array)) + return f"
<pre>{text}</pre>
" def format_dims(dims, coord_names): @@ -123,7 +129,7 @@ def summarize_variable(name, var, is_index=False, dtype=None, preview=None): f"" f"
{attrs_ul}
" - f"
{data_repr}
" + f"
{data_repr}
" ) @@ -182,8 +188,9 @@ def dim_section(obj): def array_section(obj): # "unique" id to expand/collapse the section data_id = "section-" + str(uuid.uuid4()) - collapsed = "" - preview = escape(inline_variable_array_repr(obj.variable, max_width=70)) + collapsed = "checked" + variable = getattr(obj, "variable", obj) + preview = escape(inline_variable_array_repr(variable, max_width=70)) data_repr = short_data_repr_html(obj) data_icon = _icon("icon-database") @@ -192,7 +199,7 @@ def array_section(obj): f"" f"" f"
{preview}
" - f"
{data_repr}
" + f"
{data_repr}
" "" ) @@ -221,14 +228,21 @@ def array_section(obj): ) -def _obj_repr(header_components, sections): +def _obj_repr(obj, header_components, sections): + """Return HTML repr of an xarray object. + + If CSS is not injected (untrusted notebook), fallback to the plain text repr. + + """ header = f"
{''.join(h for h in header_components)}
" sections = "".join(f"
  • {s}
  • " for s in sections) + icons_svg, css_style = _load_static_files() return ( "
    " - f"{ICONS_SVG}" - "
    " + f"{icons_svg}" + f"
    {escape(repr(obj))}
    " + "" @@ -240,12 +254,12 @@ def array_repr(arr): dims = OrderedDict((k, v) for k, v in zip(arr.dims, arr.shape)) obj_type = "xarray.{}".format(type(arr).__name__) - arr_name = "'{}'".format(arr.name) if getattr(arr, "name", None) else "" + arr_name = f"'{arr.name}'" if getattr(arr, "name", None) else "" coord_names = list(arr.coords) if hasattr(arr, "coords") else [] header_components = [ - "
    {}
    ".format(obj_type), - "
    {}
    ".format(arr_name), + f"
    {obj_type}
    ", + f"
    {arr_name}
    ", format_dims(dims, coord_names), ] @@ -256,7 +270,7 @@ def array_repr(arr): sections.append(attr_section(arr.attrs)) - return _obj_repr(header_components, sections) + return _obj_repr(arr, header_components, sections) def dataset_repr(ds): @@ -271,4 +285,4 @@ def dataset_repr(ds): attr_section(ds.attrs), ] - return _obj_repr(header_components, sections) + return _obj_repr(ds, header_components, sections) diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 5a5f4c0d296..e1e5a0fabe8 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -29,7 +29,7 @@ def check_reduce_dims(reduce_dims, dimensions): if reduce_dims is not ...: if is_scalar(reduce_dims): reduce_dims = [reduce_dims] - if any([dim not in dimensions for dim in reduce_dims]): + if any(dim not in dimensions for dim in reduce_dims): raise ValueError( "cannot reduce over dimensions %r. expected either '...' to reduce over all dimensions or one or more of %r." % (reduce_dims, dimensions) @@ -43,7 +43,7 @@ def unique_value_groups(ar, sort=True): ---------- ar : array-like Input array. This will be flattened if it is not already 1-D. - sort : boolean, optional + sort : bool, optional Whether or not to sort unique values. Returns @@ -64,8 +64,8 @@ def unique_value_groups(ar, sort=True): def _dummy_copy(xarray_obj): - from .dataset import Dataset from .dataarray import DataArray + from .dataset import Dataset if isinstance(xarray_obj, Dataset): res = Dataset( @@ -102,8 +102,7 @@ def _is_one_or_none(obj): def _consolidate_slices(slices): - """Consolidate adjacent slices in a list of slices. - """ + """Consolidate adjacent slices in a list of slices.""" result = [] last_slice = slice(None) for slice_ in slices: @@ -128,7 +127,7 @@ def _inverse_permutation_indices(positions): Parameters ---------- - positions : list of np.ndarray or slice objects. + positions : list of ndarray or slice If slice objects, all are assumed to be slices. Returns @@ -272,8 +271,8 @@ def __init__( squeeze=False, grouper=None, bins=None, - restore_coord_dims=None, - cut_kwargs={}, + restore_coord_dims=True, + cut_kwargs=None, ): """Create a GroupBy object @@ -283,22 +282,24 @@ def __init__( Object to group. group : DataArray Array with the group values. - squeeze : boolean, optional + squeeze : bool, optional If "group" is a coordinate of object, `squeeze` controls whether the subarrays have a dimension of length 1 along that coordinate or if the dimension is squeezed out. - grouper : pd.Grouper, optional + grouper : pandas.Grouper, optional Used for grouping values along the `group` array. bins : array-like, optional If `bins` is specified, the groups will be discretized into the specified bins by `pandas.cut`. - restore_coord_dims : bool, optional + restore_coord_dims : bool, default: True If True, also restore the dimension order of multi-dimensional coordinates. cut_kwargs : dict, optional Extra keyword arguments to pass to `pandas.cut` """ + if cut_kwargs is None: + cut_kwargs = {} from .dataarray import DataArray if grouper is not None and bins is not None: @@ -308,7 +309,8 @@ def __init__( if not hashable(group): raise TypeError( "`group` must be an xarray.DataArray or the " - "name of an xarray variable or dimension" + "name of an xarray variable or dimension." + f"Received {group!r} instead." 
) group = obj[group] if len(group) == 0: @@ -319,7 +321,7 @@ def __init__( group = _DummyGroup(obj, group.name, group.coords) if getattr(group, "name", None) is None: - raise ValueError("`group` must have a name") + group.name = "group" group, obj, stacked_dim, inserted_dims = _ensure_1d(group, obj) (group_dim,) = group.dims @@ -387,21 +389,6 @@ def __init__( "Failed to group data. Are you grouping by a variable that is all NaN?" ) - if ( - isinstance(obj, DataArray) - and restore_coord_dims is None - and any(obj[c].ndim > 1 for c in obj.coords) - ): - warnings.warn( - "This DataArray contains multi-dimensional " - "coordinates. In the future, the dimension order " - "of these coordinates will be restored as well " - "unless you specify restore_coord_dims=False.", - FutureWarning, - stacklevel=2, - ) - restore_coord_dims = False - # specification for the groupby operation self._obj = obj self._group = group @@ -545,8 +532,10 @@ def fillna(self, value): Parameters ---------- - value : valid type for the grouped object's fillna method - Used to fill all matching missing values by group. + value + Used to fill all matching missing values by group. Needs + to be of a valid type for the wrapped object's fillna + method. Returns ------- @@ -568,13 +557,13 @@ def quantile( Parameters ---------- - q : float in range of [0,1] (or sequence of floats) + q : float or sequence of float Quantile to compute, which must be between 0 and 1 inclusive. - dim : `...`, str or sequence of str, optional + dim : ..., str or sequence of str, optional Dimension(s) over which to apply quantile. Defaults to the grouped dimension. - interpolation : {'linear', 'lower', 'higher', 'midpoint', 'nearest'} + interpolation : {"linear", "lower", "higher", "midpoint", "nearest"}, default: "linear" This optional parameter specifies the interpolation method to use when the desired quantile lies between two data points ``i < j``: @@ -610,7 +599,7 @@ def quantile( >>> da = xr.DataArray( ... [[1.3, 8.4, 0.7, 6.9], [0.7, 4.2, 9.4, 1.5], [6.5, 7.3, 2.6, 1.9]], ... coords={"x": [0, 0, 1], "y": [1, 1, 2, 2]}, - ... dims=("y", "y"), + ... dims=("x", "y"), ... ) >>> ds = xr.Dataset({"a": da}) >>> da.groupby("x").quantile(0) @@ -618,8 +607,8 @@ def quantile( array([[0.7, 4.2, 0.7, 1.5], [6.5, 7.3, 2.6, 1.9]]) Coordinates: - quantile float64 0.0 * y (y) int64 1 1 2 2 + quantile float64 0.0 * x (x) int64 0 1 >>> ds.groupby("y").quantile(0, dim=...) @@ -635,6 +624,7 @@ def quantile( [4.2 , 6.3 , 8.4 ], [0.7 , 5.05, 9.4 ], [1.5 , 4.2 , 6.9 ]], + [[6.5 , 6.5 , 6.5 ], [7.3 , 7.3 , 7.3 ], [2.6 , 2.6 , 2.6 ], @@ -672,8 +662,8 @@ def where(self, cond, other=dtypes.NA): Parameters ---------- - cond : DataArray or Dataset with boolean dtype - Locations at which to preserve this objects values. + cond : DataArray or Dataset + Locations at which to preserve this objects values. dtypes have to be `bool` other : scalar, DataArray or Dataset, optional Value to use for locations in this object where ``cond`` is False. By default, inserts missing values. 
@@ -698,13 +688,11 @@ def _first_or_last(self, op, skipna, keep_attrs): return self.reduce(op, self._group_dim, skipna=skipna, keep_attrs=keep_attrs) def first(self, skipna=None, keep_attrs=None): - """Return the first element of each group along the group dimension - """ + """Return the first element of each group along the group dimension""" return self._first_or_last(duck_array_ops.first, skipna, keep_attrs) def last(self, skipna=None, keep_attrs=None): - """Return the last element of each group along the group dimension - """ + """Return the last element of each group along the group dimension""" return self._first_or_last(duck_array_ops.last, skipna, keep_attrs) def assign_coords(self, coords=None, **coords_kwargs): @@ -729,8 +717,7 @@ def _maybe_reorder(xarray_obj, dim, positions): class DataArrayGroupBy(GroupBy, ImplementsArrayReduce): - """GroupBy object specialized to grouping DataArray objects - """ + """GroupBy object specialized to grouping DataArray objects""" def _iter_grouped_shortcut(self): """Fast version of `_iter_grouped` that yields Variables without @@ -781,7 +768,7 @@ def map(self, func, shortcut=False, args=(), **kwargs): Parameters ---------- - func : function + func : callable Callable to apply to each array. shortcut : bool, optional Whether or not to shortcut evaluation under the assumptions that: @@ -795,9 +782,9 @@ def map(self, func, shortcut=False, args=(), **kwargs): If these conditions are satisfied `shortcut` provides significant speedup. This should be the case for many common groupby operations (e.g., applying numpy ufuncs). - ``*args`` : tuple, optional + *args : tuple, optional Positional arguments passed to `func`. - ``**kwargs`` + **kwargs Used to call `func(ar, **kwargs)` for each array `ar`. Returns @@ -827,7 +814,7 @@ def apply(self, func, shortcut=False, args=(), **kwargs): ) return self.map(func, shortcut=shortcut, args=args, **kwargs) - def _combine(self, applied, restore_coord_dims=False, shortcut=False): + def _combine(self, applied, shortcut=False): """Recombine the applied objects like the original.""" applied_example, applied = peek_at(applied) coord, dim, positions = self._infer_concat_args(applied_example) @@ -859,11 +846,11 @@ def reduce( Parameters ---------- - func : function + func : callable Function which can be called in the form `func(x, axis=axis, **kwargs)` to return the result of collapsing an np.ndarray over an integer valued axis. - dim : `...`, str or sequence of str, optional + dim : ..., str or sequence of str, optional Dimension(s) over which to apply `func`. axis : int or sequence of int, optional Axis(es) over which to apply `func`. Only one of the 'dimension' @@ -919,7 +906,7 @@ def map(self, func, args=(), shortcut=None, **kwargs): Parameters ---------- - func : function + func : callable Callable to apply to each sub-dataset. args : tuple, optional Positional arguments to pass to `func`. @@ -970,11 +957,11 @@ def reduce(self, func, dim=None, keep_attrs=None, **kwargs): Parameters ---------- - func : function + func : callable Function which can be called in the form `func(x, axis=axis, **kwargs)` to return the result of collapsing an np.ndarray over an integer valued axis. - dim : `...`, str or sequence of str, optional + dim : ..., str or sequence of str, optional Dimension(s) over which to apply `func`. axis : int or sequence of int, optional Axis(es) over which to apply `func`. 
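The `map` docstrings above describe applying a callable to each group and reassembling the results; a brief usage sketch with illustrative data:

```python
import xarray as xr

da = xr.DataArray(
    [1.0, 2.0, 10.0, 20.0],
    dims="x",
    coords={"letters": ("x", list("abab"))},
)

# func receives one sub-array per group; the results are concatenated
# back along the grouped dimension.
anomalies = da.groupby("letters").map(lambda g: g - g.mean())
```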
Only one of the 'dimension' diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index a4a5fa2c466..a5d1896e74c 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -9,7 +9,7 @@ from .variable import Variable -def remove_unused_levels_categories(index, dataframe=None): +def remove_unused_levels_categories(index: pd.Index) -> pd.Index: """ Remove unused levels from MultiIndex and unused categories from CategoricalIndex """ @@ -25,14 +25,15 @@ def remove_unused_levels_categories(index, dataframe=None): else: level = level[index.codes[i]] levels.append(level) + # TODO: calling from_array() reorders MultiIndex levels. It would + # be best to avoid this, if possible, e.g., by using + # MultiIndex.remove_unused_levels() (which does not reorder) on the + # part of the MultiIndex that is not categorical, or by fixing this + # upstream in pandas. index = pd.MultiIndex.from_arrays(levels, names=index.names) elif isinstance(index, pd.CategoricalIndex): index = index.remove_unused_categories() - - if dataframe is None: - return index - dataframe = dataframe.set_index(index) - return dataframe.index, dataframe + return index class Indexes(collections.abc.Mapping): @@ -99,7 +100,7 @@ def isel_variable_and_index( if len(variable.dims) > 1: raise NotImplementedError( - "indexing multi-dimensional variable with indexes is not " "supported yet" + "indexing multi-dimensional variable with indexes is not supported yet" ) new_variable = variable.isel(indexers) @@ -129,8 +130,7 @@ def roll_index(index: pd.Index, count: int, axis: int = 0) -> pd.Index: def propagate_indexes( indexes: Optional[Dict[Hashable, pd.Index]], exclude: Optional[Any] = None ) -> Optional[Dict[Hashable, pd.Index]]: - """ Creates new indexes dict from existing dict optionally excluding some dimensions. - """ + """Creates new indexes dict from existing dict optionally excluding some dimensions.""" if exclude is None: exclude = () diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index ab049a0a4b4..843feb04479 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -11,7 +11,12 @@ from . 
import duck_array_ops, nputils, utils from .npcompat import DTypeLike -from .pycompat import dask_array_type, integer_types, sparse_array_type +from .pycompat import ( + dask_array_type, + integer_types, + is_duck_dask_array, + sparse_array_type, +) from .utils import is_dict_like, maybe_cast_to_coords_dtype @@ -50,8 +55,8 @@ def _expand_slice(slice_, size): def _sanitize_slice_element(x): - from .variable import Variable from .dataarray import DataArray + from .variable import Variable if isinstance(x, (Variable, DataArray)): x = x.values @@ -63,11 +68,6 @@ def _sanitize_slice_element(x): ) x = x[()] - if isinstance(x, np.timedelta64): - # pandas does not support indexing with np.timedelta64 yet: - # https://github.com/pandas-dev/pandas/issues/20393 - x = pd.Timedelta(x) - return x @@ -116,7 +116,7 @@ def convert_label_indexer(index, label, index_name="", method=None, tolerance=No if isinstance(label, slice): if method is not None or tolerance is not None: raise NotImplementedError( - "cannot use ``method`` argument if any indexers are " "slice objects" + "cannot use ``method`` argument if any indexers are slice objects" ) indexer = index.slice_indexer( _sanitize_slice_element(label.start), @@ -173,8 +173,10 @@ def convert_label_indexer(index, label, index_name="", method=None, tolerance=No else _asarray_tuplesafe(label) ) if label.ndim == 0: + # see https://github.com/pydata/xarray/pull/4292 for details + label_value = label[()] if label.dtype.kind in "mM" else label.item() if isinstance(index, pd.MultiIndex): - indexer, new_index = index.get_loc_level(label.item(), level=0) + indexer, new_index = index.get_loc_level(label_value, level=0) elif isinstance(index, pd.CategoricalIndex): if method is not None: raise ValueError( @@ -184,11 +186,9 @@ def convert_label_indexer(index, label, index_name="", method=None, tolerance=No raise ValueError( "'tolerance' is not a valid kwarg when indexing using a CategoricalIndex." ) - indexer = index.get_loc(label.item()) + indexer = index.get_loc(label_value) else: - indexer = index.get_loc( - label.item(), method=method, tolerance=tolerance - ) + indexer = index.get_loc(label_value, method=method, tolerance=tolerance) elif label.dtype.kind == "b": indexer = label else: @@ -275,25 +275,38 @@ def remap_label_indexers(data_obj, indexers, method=None, tolerance=None): return pos_indexers, new_indexes +def _normalize_slice(sl, size): + """Ensure that given slice only contains positive start and stop values + (stop can be -1 for full-size slices with negative steps, e.g. [-10::-1])""" + return slice(*sl.indices(size)) + + def slice_slice(old_slice, applied_slice, size): """Given a slice and the size of the dimension to which it will be applied, index it with another slice to return a new slice equivalent to applying the slices sequentially """ - step = (old_slice.step or 1) * (applied_slice.step or 1) - - # For now, use the hack of turning old_slice into an ndarray to reconstruct - # the slice start and stop. This is not entirely ideal, but it is still - # definitely better than leaving the indexer as an array. 
- items = _expand_slice(old_slice, size)[applied_slice] - if len(items) > 0: - start = items[0] - stop = items[-1] + int(np.sign(step)) - if stop < 0: - stop = None - else: - start = 0 - stop = 0 + old_slice = _normalize_slice(old_slice, size) + + size_after_old_slice = len(range(old_slice.start, old_slice.stop, old_slice.step)) + if size_after_old_slice == 0: + # nothing left after applying first slice + return slice(0) + + applied_slice = _normalize_slice(applied_slice, size_after_old_slice) + + start = old_slice.start + applied_slice.start * old_slice.step + if start < 0: + # nothing left after applying second slice + # (can only happen for old_slice.step < 0, e.g. [10::-1], [20:]) + return slice(0) + + stop = old_slice.start + applied_slice.stop * old_slice.step + if stop < 0: + stop = None + + step = old_slice.step * applied_slice.step + return slice(start, stop, step) @@ -464,8 +477,7 @@ def __init__(self, key): class ExplicitlyIndexed: - """Mixin to mark support for Indexer subclasses in indexing. - """ + """Mixin to mark support for Indexer subclasses in indexing.""" __slots__ = () @@ -502,8 +514,7 @@ def __getitem__(self, key): class LazilyOuterIndexedArray(ExplicitlyIndexedNDArrayMixin): - """Wrap an array to make basic and outer indexing lazy. - """ + """Wrap an array to make basic and outer indexing lazy.""" __slots__ = ("array", "key") @@ -579,8 +590,7 @@ def __repr__(self): class LazilyVectorizedIndexedArray(ExplicitlyIndexedNDArrayMixin): - """Wrap an array to make vectorized indexing lazy. - """ + """Wrap an array to make vectorized indexing lazy.""" __slots__ = ("array", "key") @@ -662,6 +672,12 @@ def __setitem__(self, key, value): self._ensure_copied() self.array[key] = value + def __deepcopy__(self, memo): + # CopyOnWriteArray is used to wrap backend array objects, which might + # point to files on disk, so we can't rely on the default deepcopy + # implementation. + return type(self)(self.array) + class MemoryCachedArray(ExplicitlyIndexedNDArrayMixin): __slots__ = ("array",) @@ -767,7 +783,7 @@ def _outer_to_numpy_indexer(key, shape): def _combine_indexers(old_key, shape, new_key): - """ Combine two indexers. + """Combine two indexers. Parameters ---------- @@ -852,7 +868,7 @@ def decompose_indexer( def _decompose_slice(key, size): - """ convert a slice to successive two slices. The first slice always has + """convert a slice to successive two slices. The first slice always has a positive step. """ start, stop, step = key.indices(size) @@ -897,10 +913,14 @@ def _decompose_vectorized_indexer( Even if the backend array only supports outer indexing, it is more efficient to load a subslice of the array than loading the entire array, - >>> backend_indexer = OuterIndexer([0, 1, 3], [2, 3]) - >>> array = array[backend_indexer] # load subslice of the array - >>> np_indexer = VectorizedIndexer([0, 2, 1], [0, 1, 0]) - >>> array[np_indexer] # vectorized indexing for on-memory np.ndarray. + >>> array = np.arange(36).reshape(6, 6) + >>> backend_indexer = OuterIndexer((np.array([0, 1, 3]), np.array([2, 3]))) + >>> # load subslice of the array + ... array = NumpyIndexingAdapter(array)[backend_indexer] + >>> np_indexer = VectorizedIndexer((np.array([0, 2, 1]), np.array([0, 1, 0]))) + >>> # vectorized indexing for on-memory np.ndarray. + ... 
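The rewritten `slice_slice` above composes two slices arithmetically instead of materializing an index array. The identity it relies on can be checked directly; a quick sketch with illustrative slices (the negative-step and empty-result edge cases handled in the patch are omitted here):

```python
import numpy as np

size = 20
first, second = slice(2, None, 2), slice(1, None, 3)

# Normalize to concrete start/stop/step for this size, then compose:
# start = s1.start + s2.start * s1.step, step = s1.step * s2.step, etc.
s1 = slice(*first.indices(size))
n_after_s1 = len(range(s1.start, s1.stop, s1.step))
s2 = slice(*second.indices(n_after_s1))
combined = slice(
    s1.start + s2.start * s1.step,
    s1.start + s2.stop * s1.step,
    s1.step * s2.step,
)

x = np.arange(size)
assert np.array_equal(x[first][second], x[combined])
```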
NumpyIndexingAdapter(array)[np_indexer] + array([ 2, 21, 8]) """ assert isinstance(indexer, VectorizedIndexer) @@ -975,10 +995,16 @@ def _decompose_outer_indexer( Even if the backend array only supports basic indexing, it is more efficient to load a subslice of the array than loading the entire array, - >>> backend_indexer = BasicIndexer(slice(0, 3), slice(2, 3)) - >>> array = array[backend_indexer] # load subslice of the array - >>> np_indexer = OuterIndexer([0, 2, 1], [0, 1, 0]) - >>> array[np_indexer] # outer indexing for on-memory np.ndarray. + >>> array = np.arange(36).reshape(6, 6) + >>> backend_indexer = BasicIndexer((slice(0, 3), slice(2, 4))) + >>> # load subslice of the array + ... array = NumpyIndexingAdapter(array)[backend_indexer] + >>> np_indexer = OuterIndexer((np.array([0, 2, 1]), np.array([0, 1, 0]))) + >>> # outer indexing for on-memory np.ndarray. + ... NumpyIndexingAdapter(array)[np_indexer] + array([[ 2, 3, 2], + [14, 15, 14], + [ 8, 9, 8]]) """ if indexing_support == IndexingSupport.VECTORIZED: return indexer, BasicIndexer(()) @@ -1111,7 +1137,7 @@ def _masked_result_drop_slice(key, data=None): new_keys = [] for k in key: if isinstance(k, np.ndarray): - if isinstance(data, dask_array_type): + if is_duck_dask_array(data): new_keys.append(_dask_array_with_chunks_hint(k, chunks_hint)) elif isinstance(data, sparse_array_type): import sparse @@ -1308,7 +1334,7 @@ class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin): __slots__ = ("array",) def __init__(self, array): - """ This adapter is created in Variable.__getitem__ in + """This adapter is created in Variable.__getitem__ in Variable._broadcast_indexes. """ self.array = array @@ -1363,8 +1389,7 @@ def transpose(self, order): class PandasIndexAdapter(ExplicitlyIndexedNDArrayMixin): - """Wrap a pandas.Index to preserve dtypes and handle explicit indexing. - """ + """Wrap a pandas.Index to preserve dtypes and handle explicit indexing.""" __slots__ = ("array", "_dtype") diff --git a/xarray/core/merge.py b/xarray/core/merge.py index fea94246471..d29a9e1ff02 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -56,7 +56,7 @@ ) -def broadcast_dimension_size(variables: List[Variable],) -> Dict[Hashable, int]: +def broadcast_dimension_size(variables: List[Variable]) -> Dict[Hashable, int]: """Extract dimension sizes from a dictionary of variables. Raises ValueError if any dimensions have different sizes. @@ -71,8 +71,7 @@ def broadcast_dimension_size(variables: List[Variable],) -> Dict[Hashable, int]: class MergeError(ValueError): - """Error class for merge failures due to incompatible arguments. - """ + """Error class for merge failures due to incompatible arguments.""" # inherits from ValueError for backward compatibility # TODO: move this to an xarray.exceptions module? @@ -90,12 +89,12 @@ def unique_variable( ---------- name : hashable Name for this variable. - variables : list of xarray.Variable + variables : list of Variable List of Variable objects, all of which go by the same name in different inputs. - compat : {'identical', 'equals', 'broadcast_equals', 'no_conflicts', 'override'}, optional + compat : {"identical", "equals", "broadcast_equals", "no_conflicts", "override"}, optional Type of equality check to use. - equals: None or bool, + equals : None or bool, optional corresponding to result of compat test Returns @@ -170,7 +169,9 @@ def merge_collected( Parameters ---------- - + grouped : mapping + prioritized : mapping + compat : str Type of equality check to use when checking for conflicts. 
Returns @@ -335,7 +336,7 @@ def determine_coords( Parameters ---------- - list_of_mappings : list of dict or Dataset objects + list_of_mappings : list of dict or list of Dataset Of the same form as the arguments to expand_variable_dicts. Returns @@ -371,7 +372,7 @@ def coerce_pandas_values(objects: Iterable["CoercibleMapping"]) -> List["Dataset Parameters ---------- - objects : list of Dataset or mappings + objects : list of Dataset or mapping The mappings may contain any sort of objects coercible to xarray.Variables as keys, including pandas objects. @@ -410,11 +411,11 @@ def _get_priority_vars_and_indexes( Parameters ---------- - objects : list of dictionaries of variables + objects : list of dict-like of Variable Dictionaries in which to find the priority variables. priority_arg : int or None Integer object whose variable should take priority. - compat : {'identical', 'equals', 'broadcast_equals', 'no_conflicts'}, optional + compat : {"identical", "equals", "broadcast_equals", "no_conflicts"}, optional Compatibility checks to use when merging variables. Returns @@ -492,8 +493,7 @@ def assert_valid_explicit_coords(variables, dims, explicit_coords): def merge_attrs(variable_attrs, combine_attrs): - """Combine attributes from different variables according to combine_attrs - """ + """Combine attributes from different variables according to combine_attrs""" if not variable_attrs: # no attributes to merge return None @@ -501,7 +501,7 @@ def merge_attrs(variable_attrs, combine_attrs): if combine_attrs == "drop": return {} elif combine_attrs == "override": - return variable_attrs[0] + return dict(variable_attrs[0]) elif combine_attrs == "no_conflicts": result = dict(variable_attrs[0]) for attrs in variable_attrs[1:]: @@ -550,15 +550,15 @@ def merge_core( Parameters ---------- - objects : list of mappings + objects : list of mapping All values must be convertable to labeled arrays. - compat : {'identical', 'equals', 'broadcast_equals', 'no_conflicts', 'override'}, optional + compat : {"identical", "equals", "broadcast_equals", "no_conflicts", "override"}, optional Compatibility checks to use when merging variables. - join : {'outer', 'inner', 'left', 'right'}, optional + join : {"outer", "inner", "left", "right"}, optional How to combine objects with different indexes. - combine_attrs : {'drop', 'identical', 'no_conflicts', 'override'}, optional + combine_attrs : {"drop", "identical", "no_conflicts", "override"}, optional How to combine attributes of objects - priority_arg : integer, optional + priority_arg : int, optional Optional argument in `objects` that takes precedence over the others. explicit_coords : set, optional An explicit list of variables from `objects` that are coordinates. @@ -636,45 +636,47 @@ def merge( Parameters ---------- - objects : Iterable[Union[xarray.Dataset, xarray.DataArray, dict]] + objects : iterable of Dataset or iterable of DataArray or iterable of dict-like Merge together all variables from these objects. If any of them are DataArray objects, they must have a name. - compat : {'identical', 'equals', 'broadcast_equals', 'no_conflicts', 'override'}, optional + compat : {"identical", "equals", "broadcast_equals", "no_conflicts", "override"}, optional String indicating how to compare variables of the same name for potential conflicts: - - 'broadcast_equals': all values must be equal when variables are + - "broadcast_equals": all values must be equal when variables are broadcast against each other to ensure common dimensions. 
- - 'equals': all values and dimensions must be the same. - - 'identical': all values, dimensions and attributes must be the + - "equals": all values and dimensions must be the same. + - "identical": all values, dimensions and attributes must be the same. - - 'no_conflicts': only values which are not null in both datasets + - "no_conflicts": only values which are not null in both datasets must be equal. The returned dataset then contains the combination of all non-null values. - - 'override': skip comparing and pick variable from first dataset - join : {'outer', 'inner', 'left', 'right', 'exact'}, optional + - "override": skip comparing and pick variable from first dataset + join : {"outer", "inner", "left", "right", "exact"}, optional String indicating how to combine differing indexes in objects. - - 'outer': use the union of object indexes - - 'inner': use the intersection of object indexes - - 'left': use indexes from the first object with each dimension - - 'right': use indexes from the last object with each dimension - - 'exact': instead of aligning, raise `ValueError` when indexes to be + - "outer": use the union of object indexes + - "inner": use the intersection of object indexes + - "left": use indexes from the first object with each dimension + - "right": use indexes from the last object with each dimension + - "exact": instead of aligning, raise `ValueError` when indexes to be aligned are not equal - - 'override': if indexes are of same size, rewrite indexes to be + - "override": if indexes are of same size, rewrite indexes to be those of the first object with that dimension. Indexes for the same dimension must have the same size in all objects. - fill_value : scalar, optional - Value to use for newly missing values - combine_attrs : {'drop', 'identical', 'no_conflicts', 'override'}, - default 'drop' + fill_value : scalar or dict-like, optional + Value to use for newly missing values. If a dict-like, maps + variable names to fill values. Use a data array's name to + refer to its values. + combine_attrs : {"drop", "identical", "no_conflicts", "override"}, \ + default: "drop" String indicating how to combine attrs of the objects being merged: - - 'drop': empty attrs on returned Dataset. - - 'identical': all attrs must be the same on every object. - - 'no_conflicts': attrs from all objects are combined, any that have + - "drop": empty attrs on returned Dataset. + - "identical": all attrs must be the same on every object. + - "no_conflicts": attrs from all objects are combined, any that have the same name must also have the same value. - - 'override': skip comparing and copy attrs from the first dataset to + - "override": skip comparing and copy attrs from the first dataset to the result. 
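As a quick illustration of the `combine_attrs` options described above (attributes are illustrative; expected results noted in comments):

```python
import xarray as xr

a = xr.Dataset(attrs={"source": "model_a", "history": "run 1"})
b = xr.Dataset(attrs={"source": "model_b"})

xr.merge([a, b]).attrs                            # "drop" (default): {}
xr.merge([a, b], combine_attrs="override").attrs  # attrs copied from the first object
# combine_attrs="no_conflicts" would raise here, since "source" differs.
```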
Returns @@ -709,32 +711,32 @@ def merge( array([[1., 2.], [3., 5.]]) Coordinates: - * lat (lat) float64 35.0 40.0 - * lon (lon) float64 100.0 120.0 + * lat (lat) float64 35.0 40.0 + * lon (lon) float64 100.0 120.0 >>> y array([[5., 6.], [7., 8.]]) Coordinates: - * lat (lat) float64 35.0 42.0 - * lon (lon) float64 100.0 150.0 + * lat (lat) float64 35.0 42.0 + * lon (lon) float64 100.0 150.0 >>> z array([[0., 3.], [4., 9.]]) Coordinates: - * time (time) float64 30.0 60.0 - * lon (lon) float64 100.0 150.0 + * time (time) float64 30.0 60.0 + * lon (lon) float64 100.0 150.0 >>> xr.merge([x, y, z]) Dimensions: (lat: 3, lon: 3, time: 2) Coordinates: - * lat (lat) float64 35.0 40.0 42.0 - * lon (lon) float64 100.0 120.0 150.0 - * time (time) float64 30.0 60.0 + * lat (lat) float64 35.0 40.0 42.0 + * lon (lon) float64 100.0 120.0 150.0 + * time (time) float64 30.0 60.0 Data variables: var1 (lat, lon) float64 1.0 2.0 nan 3.0 5.0 nan nan nan nan var2 (lat, lon) float64 5.0 nan 6.0 nan nan nan 7.0 nan 8.0 @@ -744,9 +746,9 @@ def merge( Dimensions: (lat: 3, lon: 3, time: 2) Coordinates: - * lat (lat) float64 35.0 40.0 42.0 - * lon (lon) float64 100.0 120.0 150.0 - * time (time) float64 30.0 60.0 + * lat (lat) float64 35.0 40.0 42.0 + * lon (lon) float64 100.0 120.0 150.0 + * time (time) float64 30.0 60.0 Data variables: var1 (lat, lon) float64 1.0 2.0 nan 3.0 5.0 nan nan nan nan var2 (lat, lon) float64 5.0 nan 6.0 nan nan nan 7.0 nan 8.0 @@ -756,9 +758,9 @@ def merge( Dimensions: (lat: 3, lon: 3, time: 2) Coordinates: - * lat (lat) float64 35.0 40.0 42.0 - * lon (lon) float64 100.0 120.0 150.0 - * time (time) float64 30.0 60.0 + * lat (lat) float64 35.0 40.0 42.0 + * lon (lon) float64 100.0 120.0 150.0 + * time (time) float64 30.0 60.0 Data variables: var1 (lat, lon) float64 1.0 2.0 nan 3.0 5.0 nan nan nan nan var2 (lat, lon) float64 5.0 nan 6.0 nan nan nan 7.0 nan 8.0 @@ -768,9 +770,9 @@ def merge( Dimensions: (lat: 3, lon: 3, time: 2) Coordinates: - * lat (lat) float64 35.0 40.0 42.0 - * lon (lon) float64 100.0 120.0 150.0 - * time (time) float64 30.0 60.0 + * lat (lat) float64 35.0 40.0 42.0 + * lon (lon) float64 100.0 120.0 150.0 + * time (time) float64 30.0 60.0 Data variables: var1 (lat, lon) float64 1.0 2.0 -999.0 3.0 ... -999.0 -999.0 -999.0 var2 (lat, lon) float64 5.0 -999.0 6.0 -999.0 ... 
-999.0 7.0 -999.0 8.0 @@ -780,9 +782,9 @@ def merge( Dimensions: (lat: 2, lon: 2, time: 2) Coordinates: - * lat (lat) float64 35.0 40.0 - * lon (lon) float64 100.0 120.0 - * time (time) float64 30.0 60.0 + * lat (lat) float64 35.0 40.0 + * lon (lon) float64 100.0 120.0 + * time (time) float64 30.0 60.0 Data variables: var1 (lat, lon) float64 1.0 2.0 3.0 5.0 var2 (lat, lon) float64 5.0 6.0 7.0 8.0 @@ -792,9 +794,9 @@ def merge( Dimensions: (lat: 1, lon: 1, time: 2) Coordinates: - * lat (lat) float64 35.0 - * lon (lon) float64 100.0 - * time (time) float64 30.0 60.0 + * lat (lat) float64 35.0 + * lon (lon) float64 100.0 + * time (time) float64 30.0 60.0 Data variables: var1 (lat, lon) float64 1.0 var2 (lat, lon) float64 5.0 @@ -804,9 +806,9 @@ def merge( Dimensions: (lat: 1, lon: 1, time: 2) Coordinates: - * lat (lat) float64 35.0 - * lon (lon) float64 100.0 - * time (time) float64 30.0 60.0 + * lat (lat) float64 35.0 + * lon (lon) float64 100.0 + * time (time) float64 30.0 60.0 Data variables: var1 (lat, lon) float64 1.0 var2 (lat, lon) float64 5.0 @@ -816,9 +818,9 @@ def merge( Dimensions: (lat: 3, lon: 3, time: 2) Coordinates: - * lat (lat) float64 35.0 40.0 42.0 - * lon (lon) float64 100.0 120.0 150.0 - * time (time) float64 30.0 60.0 + * lat (lat) float64 35.0 40.0 42.0 + * lon (lon) float64 100.0 120.0 150.0 + * time (time) float64 30.0 60.0 Data variables: var1 (lat, lon) float64 1.0 2.0 nan 3.0 5.0 nan nan nan nan var2 (lat, lon) float64 5.0 nan 6.0 nan nan nan 7.0 nan 8.0 @@ -841,7 +843,7 @@ def merge( from .dataarray import DataArray from .dataset import Dataset - dict_like_objects = list() + dict_like_objects = [] for obj in objects: if not isinstance(obj, (DataArray, Dataset, dict)): raise TypeError( @@ -871,8 +873,7 @@ def dataset_merge_method( join: str, fill_value: Any, ) -> _MergeResult: - """Guts of the Dataset.merge method. - """ + """Guts of the Dataset.merge method.""" # we are locked into supporting overwrite_vars for the Dataset.merge # method due for backwards compatibility # TODO: consider deprecating it? @@ -929,9 +930,11 @@ def dataset_update_method( if coord_names: other[key] = value.drop_vars(coord_names) + # use ds.coords and not ds.indexes, else str coords are cast to object + indexes = {key: dataset.coords[key] for key in dataset.indexes.keys()} return merge_core( [dataset, other], priority_arg=1, - indexes=dataset.indexes, + indexes=indexes, combine_attrs="override", ) diff --git a/xarray/core/missing.py b/xarray/core/missing.py index f973b4a5468..8d112b4603c 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -1,5 +1,6 @@ import datetime as dt import warnings +from distutils.version import LooseVersion from functools import partial from numbers import Number from typing import Any, Callable, Dict, Hashable, Sequence, Union @@ -10,8 +11,9 @@ from . 
import utils from .common import _contains_datetime_like_objects, ones_like from .computation import apply_ufunc -from .duck_array_ops import dask_array_type, datetime_to_numeric, timedelta_to_numeric +from .duck_array_ops import datetime_to_numeric, timedelta_to_numeric from .options import _get_keep_attrs +from .pycompat import is_duck_dask_array from .utils import OrderedSet, is_scalar from .variable import Variable, broadcast_variables @@ -44,8 +46,7 @@ def _get_nan_block_lengths(obj, dim: Hashable, index: Variable): class BaseInterpolator: - """Generic interpolator class for normalizing interpolation methods - """ + """Generic interpolator class for normalizing interpolation methods""" cons_kwargs: Dict[str, Any] call_kwargs: Dict[str, Any] @@ -195,8 +196,7 @@ def __init__( def _apply_over_vars_with_dim(func, self, dim=None, **kwargs): - """Wrapper for datasets - """ + """Wrapper for datasets""" ds = type(self)(coords=self.coords, attrs=self.attrs) for name, var in self.data_vars.items(): @@ -208,7 +208,9 @@ def _apply_over_vars_with_dim(func, self, dim=None, **kwargs): return ds -def get_clean_interp_index(arr, dim: Hashable, use_coordinate: Union[str, bool] = True): +def get_clean_interp_index( + arr, dim: Hashable, use_coordinate: Union[str, bool] = True, strict: bool = True +): """Return index to use for x values in interpolation or curve fitting. Parameters @@ -221,6 +223,8 @@ def get_clean_interp_index(arr, dim: Hashable, use_coordinate: Union[str, bool] If use_coordinate is True, the coordinate that shares the name of the dimension along which interpolation is being performed will be used as the x values. If False, the x values are set as an equally spaced sequence. + strict : bool + Whether to raise errors if the index is either non-unique or non-monotonic (default). Returns ------- @@ -257,11 +261,12 @@ def get_clean_interp_index(arr, dim: Hashable, use_coordinate: Union[str, bool] if isinstance(index, pd.MultiIndex): index.name = dim - if not index.is_monotonic: - raise ValueError(f"Index {index.name!r} must be monotonically increasing") + if strict: + if not index.is_monotonic: + raise ValueError(f"Index {index.name!r} must be monotonically increasing") - if not index.is_unique: - raise ValueError(f"Index {index.name!r} has duplicate values") + if not index.is_unique: + raise ValueError(f"Index {index.name!r} has duplicate values") # Special case for non-standard calendar indexes # Numerical datetime values are defined with respect to 1970-01-01T00:00:00 in units of nanoseconds @@ -282,7 +287,7 @@ def get_clean_interp_index(arr, dim: Hashable, use_coordinate: Union[str, bool] # xarray/numpy raise a ValueError raise TypeError( f"Index {index.name!r} must be castable to float64 to support " - f"interpolation, got {type(index).__name__}." + f"interpolation or curve fitting, got {type(index).__name__}." ) return index @@ -298,8 +303,7 @@ def interp_na( keep_attrs: bool = None, **kwargs, ): - """Interpolate values according to different methods. 
- """ + """Interpolate values according to different methods.""" from xarray.coding.cftimeindex import CFTimeIndex if dim is None: @@ -433,6 +437,16 @@ def bfill(arr, dim=None, limit=None): ).transpose(*arr.dims) +def _import_interpolant(interpolant, method): + """Import interpolant from scipy.interpolate.""" + try: + from scipy import interpolate + + return getattr(interpolate, interpolant) + except ImportError as e: + raise ImportError(f"Interpolation with method {method} requires scipy.") from e + + def _get_interpolator(method, vectorizeable_only=False, **kwargs): """helper function to select the appropriate interpolator class @@ -455,12 +469,6 @@ def _get_interpolator(method, vectorizeable_only=False, **kwargs): "akima", ] - has_scipy = True - try: - from scipy import interpolate - except ImportError: - has_scipy = False - # prioritize scipy.interpolate if ( method == "linear" @@ -471,32 +479,29 @@ def _get_interpolator(method, vectorizeable_only=False, **kwargs): interp_class = NumpyInterpolator elif method in valid_methods: - if not has_scipy: - raise ImportError("Interpolation with method `%s` requires scipy" % method) - if method in interp1d_methods: kwargs.update(method=method) interp_class = ScipyInterpolator elif vectorizeable_only: raise ValueError( - "{} is not a vectorizeable interpolator. " - "Available methods are {}".format(method, interp1d_methods) + f"{method} is not a vectorizeable interpolator. " + f"Available methods are {interp1d_methods}" ) elif method == "barycentric": - interp_class = interpolate.BarycentricInterpolator + interp_class = _import_interpolant("BarycentricInterpolator", method) elif method == "krog": - interp_class = interpolate.KroghInterpolator + interp_class = _import_interpolant("KroghInterpolator", method) elif method == "pchip": - interp_class = interpolate.PchipInterpolator + interp_class = _import_interpolant("PchipInterpolator", method) elif method == "spline": kwargs.update(method=method) interp_class = SplineInterpolator elif method == "akima": - interp_class = interpolate.Akima1DInterpolator + interp_class = _import_interpolant("Akima1DInterpolator", method) else: - raise ValueError("%s is not a valid scipy interpolator" % method) + raise ValueError(f"{method} is not a valid scipy interpolator") else: - raise ValueError("%s is not a valid interpolator" % method) + raise ValueError(f"{method} is not a valid interpolator") return interp_class, kwargs @@ -508,18 +513,13 @@ def _get_interpolator_nd(method, **kwargs): """ valid_methods = ["linear", "nearest"] - try: - from scipy import interpolate - except ImportError: - raise ImportError("Interpolation with method `%s` requires scipy" % method) - if method in valid_methods: kwargs.update(method=method) - interp_class = interpolate.interpn + interp_class = _import_interpolant("interpn", method) else: raise ValueError( - "%s is not a valid interpolator for interpolating " - "over multiple dimensions." % method + f"{method} is not a valid interpolator for interpolating " + "over multiple dimensions." ) return interp_class, kwargs @@ -539,24 +539,25 @@ def _get_valid_fill_mask(arr, dim, limit): ) <= limit -def _assert_single_chunk(var, axes): - for axis in axes: - if len(var.chunks[axis]) > 1 or var.chunks[axis][0] < var.shape[axis]: - raise NotImplementedError( - "Chunking along the dimension to be interpolated " - "({}) is not yet supported.".format(axis) - ) - - def _localize(var, indexes_coords): - """ Speed up for linear and nearest neighbor method. 
+ """Speed up for linear and nearest neighbor method. Only consider a subspace that is needed for the interpolation """ indexes = {} for dim, [x, new_x] in indexes_coords.items(): + if np.issubdtype(new_x.dtype, np.datetime64) and LooseVersion( + np.__version__ + ) < LooseVersion("1.18"): + # np.nanmin/max changed behaviour for datetime types in numpy 1.18, + # see https://github.com/pydata/xarray/pull/3924/files + minval = np.min(new_x.values) + maxval = np.max(new_x.values) + else: + minval = np.nanmin(new_x.values) + maxval = np.nanmax(new_x.values) index = x.to_index() - imin = index.get_loc(np.min(new_x.values), method="nearest") - imax = index.get_loc(np.max(new_x.values), method="nearest") + imin = index.get_loc(minval, method="nearest") + imax = index.get_loc(maxval, method="nearest") indexes[dim] = slice(max(imin - 2, 0), imax + 2) indexes_coords[dim] = (x[indexes[dim]], new_x) @@ -564,7 +565,7 @@ def _localize(var, indexes_coords): def _floatize_x(x, new_x): - """ Make x and new_x float. + """Make x and new_x float. This is particulary useful for datetime dtype. x, new_x: tuple of np.ndarray """ @@ -584,7 +585,7 @@ def _floatize_x(x, new_x): def interp(var, indexes_coords, method, **kwargs): - """ Make an interpolation of Variable + """Make an interpolation of Variable Parameters ---------- @@ -612,36 +613,42 @@ def interp(var, indexes_coords, method, **kwargs): if not indexes_coords: return var.copy() - # simple speed up for the local interpolation - if method in ["linear", "nearest"]: - var, indexes_coords = _localize(var, indexes_coords) - # default behavior kwargs["bounds_error"] = kwargs.get("bounds_error", False) - # target dimensions - dims = list(indexes_coords) - x, new_x = zip(*[indexes_coords[d] for d in dims]) - destination = broadcast_variables(*new_x) - - # transpose to make the interpolated axis to the last position - broadcast_dims = [d for d in var.dims if d not in dims] - original_dims = broadcast_dims + dims - new_dims = broadcast_dims + list(destination[0].dims) - interped = interp_func( - var.transpose(*original_dims).data, x, destination, method, kwargs - ) + result = var + # decompose the interpolation into a succession of independant interpolation + for indexes_coords in decompose_interp(indexes_coords): + var = result + + # simple speed up for the local interpolation + if method in ["linear", "nearest"]: + var, indexes_coords = _localize(var, indexes_coords) + + # target dimensions + dims = list(indexes_coords) + x, new_x = zip(*[indexes_coords[d] for d in dims]) + destination = broadcast_variables(*new_x) + + # transpose to make the interpolated axis to the last position + broadcast_dims = [d for d in var.dims if d not in dims] + original_dims = broadcast_dims + dims + new_dims = broadcast_dims + list(destination[0].dims) + interped = interp_func( + var.transpose(*original_dims).data, x, destination, method, kwargs + ) - result = Variable(new_dims, interped, attrs=var.attrs) + result = Variable(new_dims, interped, attrs=var.attrs) - # dimension of the output array - out_dims = OrderedSet() - for d in var.dims: - if d in dims: - out_dims.update(indexes_coords[d][1].dims) - else: - out_dims.add(d) - return result.transpose(*tuple(out_dims)) + # dimension of the output array + out_dims = OrderedSet() + for d in var.dims: + if d in dims: + out_dims.update(indexes_coords[d][1].dims) + else: + out_dims.add(d) + result = result.transpose(*out_dims) + return result def interp_func(var, x, new_x, method, kwargs): @@ -659,7 +666,7 @@ def interp_func(var, x, new_x, 
method, kwargs): New coordinates. Should not contain NaN. method: string {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} for - 1-dimensional itnterpolation. + 1-dimensional interpolation. {'linear', 'nearest'} for multidimensional interpolation **kwargs: Optional keyword arguments to be passed to scipy.interpolator @@ -685,24 +692,61 @@ def interp_func(var, x, new_x, method, kwargs): else: func, kwargs = _get_interpolator_nd(method, **kwargs) - if isinstance(var, dask_array_type): + if is_duck_dask_array(var): import dask.array as da - _assert_single_chunk(var, range(var.ndim - len(x), var.ndim)) - chunks = var.chunks[: -len(x)] + new_x[0].shape - drop_axis = range(var.ndim - len(x), var.ndim) - new_axis = range(var.ndim - len(x), var.ndim - len(x) + new_x[0].ndim) - return da.map_blocks( - _interpnd, + nconst = var.ndim - len(x) + + out_ind = list(range(nconst)) + list(range(var.ndim, var.ndim + new_x[0].ndim)) + + # blockwise args format + x_arginds = [[_x, (nconst + index,)] for index, _x in enumerate(x)] + x_arginds = [item for pair in x_arginds for item in pair] + new_x_arginds = [ + [_x, [var.ndim + index for index in range(_x.ndim)]] for _x in new_x + ] + new_x_arginds = [item for pair in new_x_arginds for item in pair] + + args = ( var, - x, - new_x, - func, - kwargs, - dtype=var.dtype, - chunks=chunks, - new_axis=new_axis, - drop_axis=drop_axis, + range(var.ndim), + *x_arginds, + *new_x_arginds, + ) + + _, rechunked = da.unify_chunks(*args) + + args = tuple([elem for pair in zip(rechunked, args[1::2]) for elem in pair]) + + new_x = rechunked[1 + (len(rechunked) - 1) // 2 :] + + new_axes = { + var.ndim + i: new_x[0].chunks[i] + if new_x[0].chunks is not None + else new_x[0].shape[i] + for i in range(new_x[0].ndim) + } + + # if useful, re-use localize for each chunk of new_x + localize = (method in ["linear", "nearest"]) and (new_x[0].chunks is not None) + + # scipy.interpolate.interp1d always forces to float. 
+ # Use the same check for blockwise as well: + if not issubclass(var.dtype.type, np.inexact): + dtype = np.float_ + else: + dtype = var.dtype + + return da.blockwise( + _dask_aware_interpnd, + out_ind, + *args, + interp_func=func, + interp_kwargs=kwargs, + localize=localize, + concatenate=True, + dtype=dtype, + new_axes=new_axes, ) return _interpnd(var, x, new_x, func, kwargs) @@ -733,3 +777,67 @@ def _interpnd(var, x, new_x, func, kwargs): # move back the interpolation axes to the last position rslt = rslt.transpose(range(-rslt.ndim + 1, 1)) return rslt.reshape(rslt.shape[:-1] + new_x[0].shape) + + +def _dask_aware_interpnd(var, *coords, interp_func, interp_kwargs, localize=True): + """Wrapper for `_interpnd` through `blockwise` + + The first half of the arrays in `coords` are original coordinates, + the other half are destination coordinates + """ + n_x = len(coords) // 2 + nconst = len(var.shape) - n_x + + # _interpnd expects coords to be Variables + x = [Variable([f"dim_{nconst + dim}"], _x) for dim, _x in enumerate(coords[:n_x])] + new_x = [ + Variable([f"dim_{len(var.shape) + dim}" for dim in range(len(_x.shape))], _x) + for _x in coords[n_x:] + ] + + if localize: + # _localize expects var to be a Variable + var = Variable([f"dim_{dim}" for dim in range(len(var.shape))], var) + + indexes_coords = {_x.dims[0]: (_x, _new_x) for _x, _new_x in zip(x, new_x)} + + # simple speed up for the local interpolation + var, indexes_coords = _localize(var, indexes_coords) + x, new_x = zip(*[indexes_coords[d] for d in indexes_coords]) + + # put var back as a ndarray + var = var.data + + return _interpnd(var, x, new_x, interp_func, interp_kwargs) + + +def decompose_interp(indexes_coords): + """Decompose the interpolation into a succession of independent interpolations, keeping the order""" + + dest_dims = [ + dest[1].dims if dest[1].ndim > 0 else [dim] + for dim, dest in indexes_coords.items() + ] + partial_dest_dims = [] + partial_indexes_coords = {} + for i, index_coords in enumerate(indexes_coords.items()): + partial_indexes_coords.update([index_coords]) + + if i == len(dest_dims) - 1: + break + + partial_dest_dims += [dest_dims[i]] + other_dims = dest_dims[i + 1 :] + + s_partial_dest_dims = {dim for dims in partial_dest_dims for dim in dims} + s_other_dims = {dim for dims in other_dims for dim in dims} + + if not s_partial_dest_dims.intersection(s_other_dims): + # this interpolation is orthogonal to the rest + + yield partial_indexes_coords + + partial_dest_dims = [] + partial_indexes_coords = {} + + yield partial_indexes_coords diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index f9989c2c8c9..5eb88bcd096 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -1,3 +1,5 @@ +import warnings + import numpy as np from . import dtypes, nputils, utils @@ -6,6 +8,7 @@ try: import dask.array as dask_array + from . import dask_array_compat except ImportError: dask_array = None @@ -25,13 +28,9 @@ def _maybe_null_out(result, axis, mask, min_count=1): """ xarray version of pandas.core.nanops._maybe_null_out """ - if hasattr(axis, "__len__"): # if tuple or list - raise ValueError( - "min_count is not available for reduction with more than one dimensions. 
- ) if axis is not None and getattr(result, "ndim", False): - null_mask = (mask.shape[axis] - mask.sum(axis) - min_count) < 0 + null_mask = (np.take(mask.shape, axis).prod() - mask.sum(axis) - min_count) < 0 if null_mask.any(): dtype, fill_value = dtypes.maybe_promote(result.dtype) result = result.astype(dtype) @@ -46,7 +45,7 @@ def _maybe_null_out(result, axis, mask, min_count=1): def _nan_argminmax_object(func, fill_value, value, axis=None, **kwargs): - """ In house nanargmin, nanargmax for object arrays. Always return integer + """In house nanargmin, nanargmax for object arrays. Always return integer type """ valid_count = count(value, axis=axis) @@ -118,7 +117,7 @@ def nansum(a, axis=None, dtype=None, out=None, min_count=None): def _nanmean_ddof_object(ddof, value, axis=None, dtype=None, **kwargs): """ In house nanmean. ddof argument will be used in _nanvar method """ - from .duck_array_ops import count, fillna, _dask_or_eager_func, where_method + from .duck_array_ops import _dask_or_eager_func, count, fillna, where_method valid_count = count(value, axis=axis) value = fillna(value, 0) @@ -136,10 +135,14 @@ def nanmean(a, axis=None, dtype=None, out=None): if a.dtype.kind == "O": return _nanmean_ddof_object(0, a, axis=axis, dtype=dtype) - if isinstance(a, dask_array_type): - return dask_array.nanmean(a, axis=axis, dtype=dtype) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", r"Mean of empty slice", category=RuntimeWarning + ) + if isinstance(a, dask_array_type): + return dask_array.nanmean(a, axis=axis, dtype=dtype) - return np.nanmean(a, axis=axis, dtype=dtype) + return np.nanmean(a, axis=axis, dtype=dtype) def nanmedian(a, axis=None, out=None): diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index fa6df63e0ea..c65c22f5384 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -90,8 +90,7 @@ def _is_contiguous(positions): def _advanced_indexer_subspaces(key): - """Indices of the advanced indexes subspaces for mixed indexing and vindex. - """ + """Indices of the advanced indexes subspaces for mixed indexing and vindex.""" if not isinstance(key, tuple): key = (key,) advanced_index_positions = [ @@ -135,14 +134,22 @@ def __setitem__(self, key, value): def rolling_window(a, axis, window, center, fill_value): """ rolling window with padding. 
""" pads = [(0, 0) for s in a.shape] - if center: - start = int(window / 2) # 10 -> 5, 9 -> 4 - end = window - 1 - start - pads[axis] = (start, end) - else: - pads[axis] = (window - 1, 0) + if not hasattr(axis, "__len__"): + axis = [axis] + window = [window] + center = [center] + + for ax, win, cent in zip(axis, window, center): + if cent: + start = int(win / 2) # 10 -> 5, 9 -> 4 + end = win - 1 - start + pads[ax] = (start, end) + else: + pads[ax] = (win - 1, 0) a = np.pad(a, pads, mode="constant", constant_values=fill_value) - return _rolling_window(a, window, axis) + for ax, win in zip(axis, window): + a = _rolling_window(a, win, ax) + return a def _rolling_window(a, window, axis=-1): @@ -166,14 +173,19 @@ def _rolling_window(a, window, axis=-1): Examples -------- >>> x = np.arange(10).reshape((2, 5)) - >>> np.rolling_window(x, 3, axis=-1) - array([[[0, 1, 2], [1, 2, 3], [2, 3, 4]], - [[5, 6, 7], [6, 7, 8], [7, 8, 9]]]) + >>> _rolling_window(x, 3, axis=-1) + array([[[0, 1, 2], + [1, 2, 3], + [2, 3, 4]], + + [[5, 6, 7], + [6, 7, 8], + [7, 8, 9]]]) Calculate rolling mean of last dimension: - >>> np.mean(np.rolling_window(x, 3, axis=-1), -1) - array([[ 1., 2., 3.], - [ 6., 7., 8.]]) + >>> np.mean(_rolling_window(x, 3, axis=-1), -1) + array([[1., 2., 3.], + [6., 7., 8.]]) This function is taken from https://github.com/numpy/numpy/pull/31 but slightly modified to accept axis option. @@ -224,10 +236,17 @@ def _nanpolyfit_1d(arr, x, rcond=None): out = np.full((x.shape[1] + 1,), np.nan) mask = np.isnan(arr) if not np.all(mask): - out[:-1], out[-1], _, _ = np.linalg.lstsq(x[~mask, :], arr[~mask], rcond=rcond) + out[:-1], resid, rank, _ = np.linalg.lstsq(x[~mask, :], arr[~mask], rcond=rcond) + out[-1] = resid if resid.size > 0 else np.nan + warn_on_deficient_rank(rank, x.shape[1]) return out +def warn_on_deficient_rank(rank, order): + if rank != order: + warnings.warn("Polyfit may be poorly conditioned", np.RankWarning, stacklevel=2) + + def least_squares(lhs, rhs, rcond=None, skipna=False): if skipna: added_dim = rhs.ndim == 1 @@ -240,16 +259,21 @@ def least_squares(lhs, rhs, rcond=None, skipna=False): _nanpolyfit_1d, 0, rhs[:, nan_cols], lhs ) if np.any(~nan_cols): - out[:-1, ~nan_cols], out[-1, ~nan_cols], _, _ = np.linalg.lstsq( + out[:-1, ~nan_cols], resids, rank, _ = np.linalg.lstsq( lhs, rhs[:, ~nan_cols], rcond=rcond ) + out[-1, ~nan_cols] = resids if resids.size > 0 else np.nan + warn_on_deficient_rank(rank, lhs.shape[1]) coeffs = out[:-1, :] residuals = out[-1, :] if added_dim: coeffs = coeffs.reshape(coeffs.shape[0]) residuals = residuals.reshape(residuals.shape[0]) else: - coeffs, residuals, _, _ = np.linalg.lstsq(lhs, rhs, rcond=rcond) + coeffs, residuals, rank, _ = np.linalg.lstsq(lhs, rhs, rcond=rcond) + if residuals.size == 0: + residuals = coeffs[0] * np.nan + warn_on_deficient_rank(rank, lhs.shape[1]) return coeffs, residuals diff --git a/xarray/core/ops.py b/xarray/core/ops.py index b789f93b4f1..d56b0d59df0 100644 --- a/xarray/core/ops.py +++ b/xarray/core/ops.py @@ -42,13 +42,10 @@ NUMPY_SAME_METHODS = ["item", "searchsorted"] # methods which don't modify the data shape, so the result should still be # wrapped in an Variable/DataArray -NUMPY_UNARY_METHODS = ["astype", "argsort", "clip", "conj", "conjugate"] -PANDAS_UNARY_FUNCTIONS = ["isnull", "notnull"] +NUMPY_UNARY_METHODS = ["argsort", "clip", "conj", "conjugate"] # methods which remove an axis REDUCE_METHODS = ["all", "any"] NAN_REDUCE_METHODS = [ - "argmax", - "argmin", "max", "min", "mean", @@ -92,12 +89,7 @@ Parameters 
---------- -{extra_args} -skipna : bool, optional - If True, skip missing values (as marked by NaN). By default, only - skips missing values for float dtypes; other dtypes either do not - have a sentinel missing value (int) or skipna=True has not been - implemented (object, datetime64 or timedelta64).{min_count_docs} +{extra_args}{skip_na_docs}{min_count_docs} keep_attrs : bool, optional If True, the attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new object will be @@ -113,8 +105,15 @@ indicated dimension(s) removed. """ +_SKIPNA_DOCSTRING = """ +skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64).""" + _MINCOUNT_DOCSTRING = """ -min_count : int, default None +min_count : int, default: None The required number of valid values to perform the operation. If fewer than min_count non-NA values are present the result will be NA. New in version 0.10.8: Added with the default being None.""" @@ -140,22 +139,22 @@ def fillna(data, other, join="left", dataset_join="left"): Parameters ---------- - join : {'outer', 'inner', 'left', 'right'}, optional + join : {"outer", "inner", "left", "right"}, optional Method for joining the indexes of the passed objects along each dimension - - 'outer': use the union of object indexes - - 'inner': use the intersection of object indexes - - 'left': use indexes from the first object with each dimension - - 'right': use indexes from the last object with each dimension - - 'exact': raise `ValueError` instead of aligning when indexes to be + - "outer": use the union of object indexes + - "inner": use the intersection of object indexes + - "left": use indexes from the first object with each dimension + - "right": use indexes from the last object with each dimension + - "exact": raise `ValueError` instead of aligning when indexes to be aligned are not equal - dataset_join : {'outer', 'inner', 'left', 'right'}, optional + dataset_join : {"outer", "inner", "left", "right"}, optional Method for joining variables of Dataset objects with mismatched data variables. 
- - 'outer': take variables from both Dataset objects - - 'inner': take only overlapped variables - - 'left': take only variables from the first object - - 'right': take only variables from the last object + - "outer": take variables from both Dataset objects + - "inner": take only overlapped variables + - "left": take only variables from the first object + - "right": take only variables from the last object """ from .computation import apply_ufunc @@ -262,6 +261,7 @@ def inject_reduce_methods(cls): for name, f, include_skipna in methods: numeric_only = getattr(f, "numeric_only", False) available_min_count = getattr(f, "available_min_count", False) + skip_na_docs = _SKIPNA_DOCSTRING if include_skipna else "" min_count_docs = _MINCOUNT_DOCSTRING if available_min_count else "" func = cls._reduce_method(f, include_skipna, numeric_only) @@ -270,6 +270,7 @@ def inject_reduce_methods(cls): name=name, cls=cls.__name__, extra_args=cls._reduce_extra_args_docstring.format(name=name), + skip_na_docs=skip_na_docs, min_count_docs=min_count_docs, ) setattr(cls, name, func) @@ -332,10 +333,6 @@ def inject_all_ops_and_reduce_methods(cls, priority=50, array_only=True): for name in NUMPY_UNARY_METHODS: setattr(cls, name, cls._unary_op(_method_wrapper(name))) - for name in PANDAS_UNARY_FUNCTIONS: - f = _func_slash_method_wrapper(getattr(duck_array_ops, name), name=name) - setattr(cls, name, cls._unary_op(f)) - f = _func_slash_method_wrapper(duck_array_ops.around, name="round") setattr(cls, "round", cls._unary_op(f)) diff --git a/xarray/core/options.py b/xarray/core/options.py index 5d81ca40a6e..d421b4c4f17 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -1,26 +1,28 @@ import warnings -DISPLAY_WIDTH = "display_width" ARITHMETIC_JOIN = "arithmetic_join" +CMAP_DIVERGENT = "cmap_divergent" +CMAP_SEQUENTIAL = "cmap_sequential" +DISPLAY_MAX_ROWS = "display_max_rows" +DISPLAY_STYLE = "display_style" +DISPLAY_WIDTH = "display_width" ENABLE_CFTIMEINDEX = "enable_cftimeindex" FILE_CACHE_MAXSIZE = "file_cache_maxsize" -WARN_FOR_UNCLOSED_FILES = "warn_for_unclosed_files" -CMAP_SEQUENTIAL = "cmap_sequential" -CMAP_DIVERGENT = "cmap_divergent" KEEP_ATTRS = "keep_attrs" -DISPLAY_STYLE = "display_style" +WARN_FOR_UNCLOSED_FILES = "warn_for_unclosed_files" OPTIONS = { - DISPLAY_WIDTH: 80, ARITHMETIC_JOIN: "inner", + CMAP_DIVERGENT: "RdBu_r", + CMAP_SEQUENTIAL: "viridis", + DISPLAY_MAX_ROWS: 12, + DISPLAY_STYLE: "html", + DISPLAY_WIDTH: 80, ENABLE_CFTIMEINDEX: True, FILE_CACHE_MAXSIZE: 128, - WARN_FOR_UNCLOSED_FILES: False, - CMAP_SEQUENTIAL: "viridis", - CMAP_DIVERGENT: "RdBu_r", KEEP_ATTRS: "default", - DISPLAY_STYLE: "html", + WARN_FOR_UNCLOSED_FILES: False, } _JOIN_OPTIONS = frozenset(["inner", "outer", "left", "right", "exact"]) @@ -32,13 +34,14 @@ def _positive_integer(value): _VALIDATORS = { - DISPLAY_WIDTH: _positive_integer, ARITHMETIC_JOIN: _JOIN_OPTIONS.__contains__, + DISPLAY_MAX_ROWS: _positive_integer, + DISPLAY_STYLE: _DISPLAY_OPTIONS.__contains__, + DISPLAY_WIDTH: _positive_integer, ENABLE_CFTIMEINDEX: lambda value: isinstance(value, bool), FILE_CACHE_MAXSIZE: _positive_integer, - WARN_FOR_UNCLOSED_FILES: lambda value: isinstance(value, bool), KEEP_ATTRS: lambda choice: choice in [True, False, "default"], - DISPLAY_STYLE: _DISPLAY_OPTIONS.__contains__, + WARN_FOR_UNCLOSED_FILES: lambda value: isinstance(value, bool), } @@ -57,8 +60,8 @@ def _warn_on_setting_enable_cftimeindex(enable_cftimeindex): _SETTERS = { - FILE_CACHE_MAXSIZE: _set_file_cache_maxsize, ENABLE_CFTIMEINDEX: 
_warn_on_setting_enable_cftimeindex, + FILE_CACHE_MAXSIZE: _set_file_cache_maxsize, } @@ -71,7 +74,7 @@ def _get_keep_attrs(default): return global_choice else: raise ValueError( - "The global option keep_attrs must be one of" " True, False or 'default'." + "The global option keep_attrs must be one of True, False or 'default'." ) @@ -111,16 +114,18 @@ class set_options: >>> ds = xr.Dataset({"x": np.arange(1000)}) >>> with xr.set_options(display_width=40): ... print(ds) + ... Dimensions: (x: 1000) Coordinates: - * x (x) int64 0 1 2 3 4 5 6 ... + * x (x) int64 0 1 2 ... 998 999 Data variables: *empty* Or to set global options: - >>> xr.set_options(display_width=80) + >>> xr.set_options(display_width=80) # doctest: +ELLIPSIS + """ def __init__(self, **kwargs): @@ -132,7 +137,15 @@ def __init__(self, **kwargs): % (k, set(OPTIONS)) ) if k in _VALIDATORS and not _VALIDATORS[k](v): - raise ValueError(f"option {k!r} given an invalid value: {v!r}") + if k == ARITHMETIC_JOIN: + expected = f"Expected one of {_JOIN_OPTIONS!r}" + elif k == DISPLAY_STYLE: + expected = f"Expected one of {_DISPLAY_OPTIONS!r}" + else: + expected = "" + raise ValueError( + f"option {k!r} given an invalid value: {v!r}. " + expected + ) self.old[k] = OPTIONS[k] self._apply_update(kwargs) diff --git a/xarray/core/parallel.py b/xarray/core/parallel.py index 6f1668f698f..20b4b9f9eb3 100644 --- a/xarray/core/parallel.py +++ b/xarray/core/parallel.py @@ -1,8 +1,8 @@ try: import dask import dask.array + from dask.array.utils import meta_from_array from dask.highlevelgraph import HighLevelGraph - from .dask_array_compat import meta_from_array except ImportError: pass @@ -16,6 +16,8 @@ DefaultDict, Dict, Hashable, + Iterable, + List, Mapping, Sequence, Tuple, @@ -25,12 +27,50 @@ import numpy as np +from .alignment import align from .dataarray import DataArray from .dataset import Dataset T_DSorDA = TypeVar("T_DSorDA", DataArray, Dataset) +def unzip(iterable): + return zip(*iterable) + + +def assert_chunks_compatible(a: Dataset, b: Dataset): + a = a.unify_chunks() + b = b.unify_chunks() + + for dim in set(a.chunks).intersection(set(b.chunks)): + if a.chunks[dim] != b.chunks[dim]: + raise ValueError(f"Chunk sizes along dimension {dim!r} are not equal.") + + +def check_result_variables( + result: Union[DataArray, Dataset], expected: Mapping[str, Any], kind: str +): + + if kind == "coords": + nice_str = "coordinate" + elif kind == "data_vars": + nice_str = "data" + + # check that coords and data variables are as expected + missing = expected[kind] - set(getattr(result, kind)) + if missing: + raise ValueError( + "Result from applying user function does not contain " + f"{nice_str} variables {missing}." + ) + extra = set(getattr(result, kind)) - expected[kind] + if extra: + raise ValueError( + "Result from applying user function has unexpected " + f"{nice_str} variables {extra}." + ) + + def dataset_to_dataarray(obj: Dataset) -> DataArray: if not isinstance(obj, Dataset): raise TypeError("Expected Dataset, got %s" % type(obj)) @@ -43,6 +83,17 @@ def dataset_to_dataarray(obj: Dataset) -> DataArray: return next(iter(obj.data_vars.values())) +def dataarray_to_dataset(obj: DataArray) -> Dataset: + # only using _to_temp_dataset would break + # func = lambda x: x.to_dataset() + # since that relies on preserving name. 
+ if obj.name is None: + dataset = obj._to_temp_dataset() + else: + dataset = obj.to_dataset() + return dataset + + def make_meta(obj): """If obj is a DataArray or Dataset, return a new object of the same type and with the same variables and dtypes, but where all variables have size 0 and numpy @@ -72,15 +123,15 @@ def make_meta(obj): def infer_template( func: Callable[..., T_DSorDA], obj: Union[DataArray, Dataset], *args, **kwargs ) -> T_DSorDA: - """Infer return object by running the function on meta objects. - """ + """Infer return object by running the function on meta objects.""" meta_args = [make_meta(arg) for arg in (obj,) + args] try: template = func(*meta_args, **kwargs) except Exception as e: raise Exception( - "Cannot infer object returned from running user provided function." + "Cannot infer object returned from running user provided function. " + "Please supply the 'template' kwarg to map_blocks." ) from e if not isinstance(template, (Dataset, DataArray)): @@ -102,39 +153,54 @@ def make_dict(x: Union[DataArray, Dataset]) -> Dict[Hashable, Any]: return {k: v.data for k, v in x.variables.items()} +def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping): + if dim in chunk_index: + which_chunk = chunk_index[dim] + return slice(chunk_bounds[dim][which_chunk], chunk_bounds[dim][which_chunk + 1]) + return slice(None) + + def map_blocks( func: Callable[..., T_DSorDA], obj: Union[DataArray, Dataset], args: Sequence[Any] = (), kwargs: Mapping[str, Any] = None, + template: Union[DataArray, Dataset] = None, ) -> T_DSorDA: - """Apply a function to each chunk of a DataArray or Dataset. This function is - experimental and its signature may change. + """Apply a function to each block of a DataArray or Dataset. + + .. warning:: + This function is experimental and its signature may change. Parameters ---------- - func: callable + func : callable User-provided function that accepts a DataArray or Dataset as its first - parameter. The function will receive a subset of 'obj' (see below), + parameter ``obj``. The function will receive a subset or 'block' of ``obj`` (see below), corresponding to one chunk along each chunked dimension. ``func`` will be - executed as ``func(obj_subset, *args, **kwargs)``. - - The function will be first run on mocked-up data, that looks like 'obj' but - has sizes 0, to determine properties of the returned object such as dtype, - variable names, new dimensions and new indexes (if any). + executed as ``func(subset_obj, *subset_args, **kwargs)``. This function must return either a single DataArray or a single Dataset. - This function cannot change size of existing dimensions, or add new chunked - dimensions. - obj: DataArray, Dataset - Passed to the function as its first argument, one dask chunk at a time. - args: Sequence - Passed verbatim to func after unpacking, after the sliced obj. xarray objects, - if any, will not be split by chunks. Passing dask collections is not allowed. - kwargs: Mapping + This function cannot add a new chunked dimension. + + obj : DataArray, Dataset + Passed to the function as its first argument, one block at a time. + args : sequence + Passed to func after unpacking and subsetting any xarray objects by blocks. + xarray objects in args must be aligned with obj, otherwise an error is raised. + kwargs : mapping Passed verbatim to func after unpacking. xarray objects, if any, will not be - split by chunks. Passing dask collections is not allowed. + subset to blocks. Passing dask collections in kwargs is not allowed. 
+ template : DataArray or Dataset, optional + xarray object representing the final result after compute is called. If not provided, + the function will be first run on mocked-up data, that looks like ``obj`` but + has sizes 0, to determine properties of the returned object such as dtype, + variable names, attributes, new dimensions and new indexes (if any). + ``template`` must be provided if the function changes the size of existing dimensions. + When provided, ``attrs`` on variables in `template` are copied over to the result. Any + ``attrs`` set by ``func`` will be ignored. + Returns ------- @@ -143,11 +209,11 @@ def map_blocks( Notes ----- - This function is designed for when one needs to manipulate a whole xarray object - within each chunk. In the more common case where one can work on numpy arrays, it is - recommended to use apply_ufunc. + This function is designed for when ``func`` needs to manipulate a whole xarray object + subset to each block. In the more common case where ``func`` can work on numpy arrays, it is + recommended to use ``apply_ufunc``. - If none of the variables in obj is backed by dask, calling this function is + If none of the variables in ``obj`` is backed by dask arrays, calling this function is equivalent to calling ``func(obj, *args, **kwargs)``. See Also @@ -163,19 +229,19 @@ def map_blocks( its indices, and its methods like ``.groupby()``. >>> def calculate_anomaly(da, groupby_type="time.month"): - ... # Necessary workaround to xarray's check with zero dimensions - ... # https://github.com/pydata/xarray/issues/3575 - ... if sum(da.shape) == 0: - ... return da ... gb = da.groupby(groupby_type) ... clim = gb.mean(dim="time") ... return gb - clim + ... >>> time = xr.cftime_range("1990-01", "1992-01", freq="M") + >>> month = xr.DataArray(time.month, coords={"time": time}, dims=["time"]) >>> np.random.seed(123) >>> array = xr.DataArray( - ... np.random.rand(len(time)), dims="time", coords=[time] + ... np.random.rand(len(time)), + ... dims=["time"], + ... coords={"time": time, "month": month}, ... ).chunk() - >>> xr.map_blocks(calculate_anomaly, array).compute() + >>> array.map_blocks(calculate_anomaly, template=array).compute() array([ 0.12894847, 0.11323072, -0.0855964 , -0.09334032, 0.26848862, 0.12382735, 0.22460641, 0.07650108, -0.07673453, -0.22865714, @@ -184,39 +250,74 @@ def map_blocks( 0.07673453, 0.22865714, 0.19063865, -0.0590131 ]) Coordinates: * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 + month (time) int64 1 2 3 4 5 6 7 8 9 10 11 12 1 2 3 4 5 6 7 8 9 10 11 12 Note that one must explicitly use ``args=[]`` and ``kwargs={}`` to pass arguments to the function being applied in ``xr.map_blocks()``: - >>> xr.map_blocks( - ... calculate_anomaly, array, kwargs={"groupby_type": "time.year"}, - ... ) + >>> array.map_blocks( + ... calculate_anomaly, + ... kwargs={"groupby_type": "time.year"}, + ... template=array, + ... ) # doctest: +ELLIPSIS - array([ 0.15361741, -0.25671244, -0.31600032, 0.008463 , 0.1766172 , - -0.11974531, 0.43791243, 0.14197797, -0.06191987, -0.15073425, - -0.19967375, 0.18619794, -0.05100474, -0.42989909, -0.09153273, - 0.24841842, -0.30708526, -0.31412523, 0.04197439, 0.0422506 , - 0.14482397, 0.35985481, 0.23487834, 0.12144652]) + dask.array Coordinates: - * time (time) object 1990-01-31 00:00:00 ... 1991-12-31 00:00:00 + * time (time) object 1990-01-31 00:00:00 ... 
1991-12-31 00:00:00 + month (time) int64 dask.array """ - def _wrapper(func, obj, to_array, args, kwargs): - if to_array: - obj = dataset_to_dataarray(obj) - - result = func(obj, *args, **kwargs) + def _wrapper( + func: Callable, + args: List, + kwargs: dict, + arg_is_array: Iterable[bool], + expected: dict, + ): + """ + Wrapper function that receives datasets in args; converts to dataarrays when necessary; + passes these to the user function `func` and checks returned objects for expected shapes/sizes/etc. + """ + + converted_args = [ + dataset_to_dataarray(arg) if is_array else arg + for is_array, arg in zip(arg_is_array, args) + ] + + result = func(*converted_args, **kwargs) + + # check all dims are present + missing_dimensions = set(expected["shapes"]) - set(result.sizes) + if missing_dimensions: + raise ValueError( + f"Dimensions {missing_dimensions} missing on returned object." + ) + # check that index lengths and values are as expected for name, index in result.indexes.items(): - if name in obj.indexes: - if len(index) != len(obj.indexes[name]): + if name in expected["shapes"]: + if len(index) != expected["shapes"][name]: raise ValueError( - "Length of the %r dimension has changed. This is not allowed." - % name + f"Received dimension {name!r} of length {len(index)}. Expected length {expected['shapes'][name]}." ) + if name in expected["indexes"]: + expected_index = expected["indexes"][name] + if not index.equals(expected_index): + raise ValueError( + f"Expected index {name!r} to be {expected_index!r}. Received {index!r} instead." + ) + + # check that all expected variables were returned + check_result_variables(result, expected, "coords") + if isinstance(result, Dataset): + check_result_variables(result, expected, "data_vars") return make_dict(result) + if template is not None and not isinstance(template, (DataArray, Dataset)): + raise TypeError( + f"template must be a DataArray or Dataset. Received {type(template).__name__} instead." + ) if not isinstance(args, Sequence): raise TypeError("args must be a sequence (for example, a list or tuple).") if kwargs is None: @@ -224,32 +325,76 @@ def _wrapper(func, obj, to_array, args, kwargs): elif not isinstance(kwargs, Mapping): raise TypeError("kwargs must be a mapping (for example, a dict)") - for value in list(args) + list(kwargs.values()): + for value in kwargs.values(): if dask.is_dask_collection(value): raise TypeError( - "Cannot pass dask collections in args or kwargs yet. Please compute or " + "Cannot pass dask collections in kwargs yet. Please compute or " "load values before passing to map_blocks." ) if not dask.is_dask_collection(obj): return func(obj, *args, **kwargs) - if isinstance(obj, DataArray): - # only using _to_temp_dataset would break - # func = lambda x: x.to_dataset() - # since that relies on preserving name. - if obj.name is None: - dataset = obj._to_temp_dataset() - else: - dataset = obj.to_dataset() - input_is_array = True - else: - dataset = obj - input_is_array = False + all_args = [obj] + list(args) + is_xarray = [isinstance(arg, (Dataset, DataArray)) for arg in all_args] + is_array = [isinstance(arg, DataArray) for arg in all_args] + + # there should be a better way to group this. partition? + xarray_indices, xarray_objs = unzip( + (index, arg) for index, arg in enumerate(all_args) if is_xarray[index] + ) + others = [ + (index, arg) for index, arg in enumerate(all_args) if not is_xarray[index] + ] + + # all xarray objects must be aligned. This is consistent with apply_ufunc. 
+ aligned = align(*xarray_objs, join="exact") + xarray_objs = tuple( + dataarray_to_dataset(arg) if is_da else arg + for is_da, arg in zip(is_array, aligned) + ) + + _, npargs = unzip( + sorted(list(zip(xarray_indices, xarray_objs)) + others, key=lambda x: x[0]) + ) + + # check that chunk sizes are compatible + input_chunks = dict(npargs[0].chunks) + input_indexes = dict(npargs[0].indexes) + for arg in xarray_objs[1:]: + assert_chunks_compatible(npargs[0], arg) + input_chunks.update(arg.chunks) + input_indexes.update(arg.indexes) + + if template is None: + # infer template by providing zero-shaped arrays + template = infer_template(func, aligned[0], *args, **kwargs) + template_indexes = set(template.indexes) + preserved_indexes = template_indexes & set(input_indexes) + new_indexes = template_indexes - set(input_indexes) + indexes = {dim: input_indexes[dim] for dim in preserved_indexes} + indexes.update({k: template.indexes[k] for k in new_indexes}) + output_chunks = { + dim: input_chunks[dim] for dim in template.dims if dim in input_chunks + } - input_chunks = dataset.chunks + else: + # template xarray object has been provided with proper sizes and chunk shapes + indexes = dict(template.indexes) + if isinstance(template, DataArray): + output_chunks = dict(zip(template.dims, template.chunks)) # type: ignore + else: + output_chunks = dict(template.chunks) + + for dim in output_chunks: + if dim in input_chunks and len(input_chunks[dim]) != len(output_chunks[dim]): + raise ValueError( + "map_blocks requires that one block of the input maps to one block of output. " + f"Expected number of output chunks along dimension {dim!r} to be {len(input_chunks[dim])}. " + f"Received {len(output_chunks[dim])} instead. Please provide template if not provided, or " + "fix the provided template." + ) - template: Union[DataArray, Dataset] = infer_template(func, obj, *args, **kwargs) if isinstance(template, DataArray): result_is_array = True template_name = template.name @@ -261,13 +406,6 @@ def _wrapper(func, obj, to_array, args, kwargs): f"func output must be DataArray or Dataset; got {type(template)}" ) - template_indexes = set(template.indexes) - dataset_indexes = set(dataset.indexes) - preserved_indexes = template_indexes & dataset_indexes - new_indexes = template_indexes - dataset_indexes - indexes = {dim: dataset.indexes[dim] for dim in preserved_indexes} - indexes.update({k: template.indexes[k] for k in new_indexes}) - # We're building a new HighLevelGraph hlg. We'll have one new layer # for each variable in the dataset, which is the result of the # func applied to the values. 
@@ -275,19 +413,27 @@ def _wrapper(func, obj, to_array, args, kwargs): graph: Dict[Any, Any] = {} new_layers: DefaultDict[str, Dict[Any, Any]] = collections.defaultdict(dict) gname = "{}-{}".format( - dask.utils.funcname(func), dask.base.tokenize(dataset, args, kwargs) + dask.utils.funcname(func), dask.base.tokenize(npargs[0], args, kwargs) ) # map dims to list of chunk indexes ichunk = {dim: range(len(chunks_v)) for dim, chunks_v in input_chunks.items()} # mapping from chunk index to slice bounds - chunk_index_bounds = { + input_chunk_bounds = { dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in input_chunks.items() } + output_chunk_bounds = { + dim: np.cumsum((0,) + chunks_v) for dim, chunks_v in output_chunks.items() + } - # iterate over all possible chunk combinations - for v in itertools.product(*ichunk.values()): - chunk_index_dict = dict(zip(dataset.dims, v)) + def subset_dataset_to_block( + graph: dict, gname: str, dataset: Dataset, input_chunk_bounds, chunk_index + ): + """ + Creates a task that subsets an xarray dataset to a block determined by chunk_index. + Block extents are determined by input_chunk_bounds. + Also subtasks that subset the constituent variables of a dataset. + """ # this will become [[name1, variable1], # [name2, variable2], @@ -296,35 +442,31 @@ def _wrapper(func, obj, to_array, args, kwargs): data_vars = [] coords = [] + chunk_tuple = tuple(chunk_index.values()) for name, variable in dataset.variables.items(): # make a task that creates tuple of (dims, chunk) if dask.is_dask_collection(variable.data): # recursively index into dask_keys nested list to get chunk chunk = variable.__dask_keys__() for dim in variable.dims: - chunk = chunk[chunk_index_dict[dim]] + chunk = chunk[chunk_index[dim]] - chunk_variable_task = (f"{gname}-{chunk[0]}",) + v + chunk_variable_task = (f"{gname}-{name}-{chunk[0]}",) + chunk_tuple graph[chunk_variable_task] = ( tuple, [variable.dims, chunk, variable.attrs], ) else: - # non-dask array with possibly chunked dimensions + # non-dask array possibly with dimensions chunked on other variables # index into variable appropriately - subsetter = {} - for dim in variable.dims: - if dim in chunk_index_dict: - which_chunk = chunk_index_dict[dim] - subsetter[dim] = slice( - chunk_index_bounds[dim][which_chunk], - chunk_index_bounds[dim][which_chunk + 1], - ) - + subsetter = { + dim: _get_chunk_slicer(dim, chunk_index, input_chunk_bounds) + for dim in variable.dims + } subset = variable.isel(subsetter) chunk_variable_task = ( "{}-{}".format(gname, dask.base.tokenize(subset)), - ) + v + ) + chunk_tuple graph[chunk_variable_task] = ( tuple, [subset.dims, subset, subset.attrs], @@ -336,15 +478,37 @@ def _wrapper(func, obj, to_array, args, kwargs): else: data_vars.append([name, chunk_variable_task]) - from_wrapper = (gname,) + v - graph[from_wrapper] = ( - _wrapper, - func, - (Dataset, (dict, data_vars), (dict, coords), dataset.attrs), - input_is_array, - args, - kwargs, - ) + return (Dataset, (dict, data_vars), (dict, coords), dataset.attrs) + + # iterate over all possible chunk combinations + for chunk_tuple in itertools.product(*ichunk.values()): + # mapping from dimension name to chunk index + chunk_index = dict(zip(ichunk.keys(), chunk_tuple)) + + blocked_args = [ + subset_dataset_to_block(graph, gname, arg, input_chunk_bounds, chunk_index) + if isxr + else arg + for isxr, arg in zip(is_xarray, npargs) + ] + + # expected["shapes", "coords", "data_vars", "indexes"] are used to + # raise nice error messages in _wrapper + expected = {} + # input 
chunk 0 along a dimension maps to output chunk 0 along the same dimension + # even if length of dimension is changed by the applied function + expected["shapes"] = { + k: output_chunks[k][v] for k, v in chunk_index.items() if k in output_chunks + } + expected["data_vars"] = set(template.data_vars.keys()) # type: ignore + expected["coords"] = set(template.coords.keys()) # type: ignore + expected["indexes"] = { + dim: indexes[dim][_get_chunk_slicer(dim, chunk_index, output_chunk_bounds)] + for dim in indexes + } + + from_wrapper = (gname,) + chunk_tuple + graph[from_wrapper] = (_wrapper, func, blocked_args, kwargs, is_array, expected) # mapping from variable name to dask graph key var_key_map: Dict[Hashable, str] = {} @@ -356,10 +520,11 @@ def _wrapper(func, obj, to_array, args, kwargs): key: Tuple[Any, ...] = (gname_l,) for dim in variable.dims: - if dim in chunk_index_dict: - key += (chunk_index_dict[dim],) + if dim in chunk_index: + key += (chunk_index[dim],) else: # unchunked dimensions in the input have one chunk in the result + # output can have new dimensions with exactly one chunk key += (0,) # We're adding multiple new layers to the graph: @@ -370,7 +535,11 @@ def _wrapper(func, obj, to_array, args, kwargs): # layer. new_layers[gname_l][key] = (operator.getitem, from_wrapper, name) - hlg = HighLevelGraph.from_collections(gname, graph, dependencies=[dataset]) + hlg = HighLevelGraph.from_collections( + gname, + graph, + dependencies=[arg for arg in npargs if dask.is_dask_collection(arg)], + ) for gname_l, layer in new_layers.items(): # This adds in the getitems for each variable in the dataset. @@ -378,12 +547,16 @@ def _wrapper(func, obj, to_array, args, kwargs): hlg.layers[gname_l] = layer result = Dataset(coords=indexes, attrs=template.attrs) + for index in result.indexes: + result[index].attrs = template[index].attrs + result[index].encoding = template[index].encoding + for name, gname_l in var_key_map.items(): dims = template[name].dims var_chunks = [] for dim in dims: - if dim in input_chunks: - var_chunks.append(input_chunks[dim]) + if dim in output_chunks: + var_chunks.append(output_chunks[dim]) elif dim in indexes: var_chunks.append((len(indexes[dim]),)) elif dim in template.dims: @@ -394,6 +567,7 @@ def _wrapper(func, obj, to_array, args, kwargs): hlg, name=gname_l, chunks=var_chunks, dtype=template[name].dtype ) result[name] = (dims, data, template[name].attrs) + result[name].encoding = template[name].encoding result = result.set_coords(template._coord_names) diff --git a/xarray/core/pdcompat.py b/xarray/core/pdcompat.py index f2e4518e0dc..f2e22329fc8 100644 --- a/xarray/core/pdcompat.py +++ b/xarray/core/pdcompat.py @@ -55,4 +55,4 @@ def count_not_none(*args) -> int: Copied from pandas.core.common.count_not_none (not part of the public API) """ - return sum([arg is not None for arg in args]) + return sum(arg is not None for arg in args) diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py index aaf52b9f295..8d613038957 100644 --- a/xarray/core/pycompat.py +++ b/xarray/core/pycompat.py @@ -1,14 +1,24 @@ import numpy as np +from .utils import is_duck_array + integer_types = (int, np.integer) try: - # solely for isinstance checks import dask.array + from dask.base import is_dask_collection + # solely for isinstance checks dask_array_type = (dask.array.Array,) + + def is_duck_dask_array(x): + return is_duck_array(x) and is_dask_collection(x) + + except ImportError: # pragma: no cover dask_array_type = () + is_duck_dask_array = lambda _: False + is_dask_collection = 
lambda _: False try: # solely for isinstance checks @@ -17,3 +27,11 @@ sparse_array_type = (sparse.SparseArray,) except ImportError: # pragma: no cover sparse_array_type = () + +try: + # solely for isinstance checks + import cupy + + cupy_array_type = (cupy.ndarray,) +except ImportError: # pragma: no cover + cupy_array_type = () diff --git a/xarray/core/resample.py b/xarray/core/resample.py index 2b3b7da6217..0a20d918bf1 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -29,8 +29,8 @@ def _upsample(self, method, *args, **kwargs): Parameters ---------- - method : str {'asfreq', 'pad', 'ffill', 'backfill', 'bfill', 'nearest', - 'interpolate'} + method : {"asfreq", "pad", "ffill", "backfill", "bfill", "nearest", \ + "interpolate"} Method to use for up-sampling See Also @@ -130,8 +130,8 @@ def interpolate(self, kind="linear"): Parameters ---------- - kind : str {'linear', 'nearest', 'zero', 'slinear', - 'quadratic', 'cubic'} + kind : {"linear", "nearest", "zero", "slinear", \ + "quadratic", "cubic"}, default: "linear" Interpolation scheme to use See Also @@ -193,7 +193,7 @@ def map(self, func, shortcut=False, args=(), **kwargs): Parameters ---------- - func : function + func : callable Callable to apply to each array. shortcut : bool, optional Whether or not to shortcut evaluation under the assumptions that: @@ -253,8 +253,7 @@ def apply(self, func, args=(), shortcut=None, **kwargs): class DatasetResample(DatasetGroupBy, Resample): - """DatasetGroupBy object specialized to resampling a specified dimension - """ + """DatasetGroupBy object specialized to resampling a specified dimension""" def __init__(self, *args, dim=None, resample_dim=None, **kwargs): @@ -271,7 +270,7 @@ def __init__(self, *args, dim=None, resample_dim=None, **kwargs): def map(self, func, args=(), shortcut=None, **kwargs): """Apply a function over each Dataset in the groups generated for - resampling and concatenate them together into a new Dataset. + resampling and concatenate them together into a new Dataset. `func` is called like `func(ds, *args, **kwargs)` for each dataset `ds` in this group. @@ -287,7 +286,7 @@ def map(self, func, args=(), shortcut=None, **kwargs): Parameters ---------- - func : function + func : callable Callable to apply to each sub-dataset. args : tuple, optional Positional arguments passed on to `func`. @@ -327,7 +326,7 @@ def reduce(self, func, dim=None, keep_attrs=None, **kwargs): Parameters ---------- - func : function + func : callable Function which can be called in the form `func(x, axis=axis, **kwargs)` to return the result of collapsing an np.ndarray over an integer valued axis. diff --git a/xarray/core/resample_cftime.py b/xarray/core/resample_cftime.py index cfac224363d..882664cbb60 100644 --- a/xarray/core/resample_cftime.py +++ b/xarray/core/resample_cftime.py @@ -224,7 +224,7 @@ def _adjust_bin_edges(datetime_bins, offset, closed, index, labels): def _get_range_edges(first, last, offset, closed="left", base=0): - """ Get the correct starting and ending datetimes for the resampled + """Get the correct starting and ending datetimes for the resampled CFTimeIndex range. 
Parameters @@ -272,7 +272,7 @@ def _get_range_edges(first, last, offset, closed="left", base=0): def _adjust_dates_anchored(first, last, offset, closed="right", base=0): - """ First and last offsets should be calculated from the start day to fix + """First and last offsets should be calculated from the start day to fix an error cause by resampling across multiple days when a one day period is not a multiple of the frequency. See https://github.com/pandas-dev/pandas/issues/8683 diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index ecba5307680..39d889244dc 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -8,7 +8,7 @@ from .dask_array_ops import dask_rolling_wrapper from .ops import inject_reduce_methods from .options import _get_keep_attrs -from .pycompat import dask_array_type +from .pycompat import is_duck_dask_array try: import bottleneck @@ -22,6 +22,10 @@ Parameters ---------- +keep_attrs : bool, default: None + If True, the attributes (``attrs``) will be copied from the original + object to the new one. If False, the new object will be returned + without attributes. If None uses the global default. **kwargs : dict Additional keyword arguments passed on to `{name}`. @@ -37,10 +41,10 @@ class Rolling: See Also -------- - Dataset.groupby - DataArray.groupby - Dataset.rolling - DataArray.rolling + xarray.Dataset.groupby + xarray.DataArray.groupby + xarray.Dataset.rolling + xarray.DataArray.rolling """ __slots__ = ("obj", "window", "min_periods", "center", "dim", "keep_attrs") @@ -54,61 +58,51 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None ---------- obj : Dataset or DataArray Object to window. - windows : A mapping from a dimension name to window size - dim : str - Name of the dimension to create the rolling iterator - along (e.g., `time`). - window : int - Size of the moving window. - min_periods : int, default None + windows : mapping of hashable to int + A mapping from the name of the dimension to create the rolling + window along (e.g. `time`) to the size of the moving window. + min_periods : int, default: None Minimum number of observations in window required to have a value (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. - center : boolean, default False + center : bool, default: False Set the labels at the center of the window. - keep_attrs : bool, optional - If True, the object's attributes (`attrs`) will be copied from - the original object to the new one. If False (default), the new - object will be returned without attributes. 
Returns ------- rolling : type of input argument """ - if len(windows) != 1: - raise ValueError("exactly one dim/window should be provided") - - dim, window = next(iter(windows.items())) - - if window <= 0: - raise ValueError("window must be > 0") - + self.dim, self.window = [], [] + for d, w in windows.items(): + self.dim.append(d) + if w <= 0: + raise ValueError("window must be > 0") + self.window.append(w) + + self.center = self._mapping_to_list(center, default=False) self.obj = obj # attributes - self.window = window if min_periods is not None and min_periods <= 0: raise ValueError("min_periods must be greater than zero or None") - self.min_periods = min_periods - self.center = center - self.dim = dim + self.min_periods = np.prod(self.window) if min_periods is None else min_periods - if keep_attrs is None: - keep_attrs = _get_keep_attrs(default=False) + if keep_attrs is not None: + warnings.warn( + "Passing ``keep_attrs`` to ``rolling`` is deprecated and will raise an" + " error in xarray 0.18. Please pass ``keep_attrs`` directly to the" + " applied function. Note that keep_attrs is now True per default.", + FutureWarning, + ) self.keep_attrs = keep_attrs - @property - def _min_periods(self): - return self.min_periods if self.min_periods is not None else self.window - def __repr__(self): """provide a nice str repr of our rolling object""" attrs = [ - "{k}->{v}".format(k=k, v=getattr(self, k)) - for k in self._attributes - if getattr(self, k, None) is not None + "{k}->{v}{c}".format(k=k, v=w, c="(center)" if c else "") + for k, w, c in zip(self.dim, self.window, self.center) ] return "{klass} [{attrs}]".format( klass=self.__class__.__name__, attrs=",".join(attrs) @@ -121,9 +115,12 @@ def _reduce_method(name: str) -> Callable: # type: ignore array_agg_func = getattr(duck_array_ops, name) bottleneck_move_func = getattr(bottleneck, "move_" + name, None) - def method(self, **kwargs): + def method(self, keep_attrs=None, **kwargs): + + keep_attrs = self._get_keep_attrs(keep_attrs) + return self._numpy_or_bottleneck_reduce( - array_agg_func, bottleneck_move_func, **kwargs + array_agg_func, bottleneck_move_func, keep_attrs=keep_attrs, **kwargs ) method.__name__ = name @@ -141,13 +138,47 @@ def method(self, **kwargs): var = _reduce_method("var") median = _reduce_method("median") - def count(self): - rolling_count = self._counts() - enough_periods = rolling_count >= self._min_periods + def count(self, keep_attrs=None): + keep_attrs = self._get_keep_attrs(keep_attrs) + rolling_count = self._counts(keep_attrs=keep_attrs) + enough_periods = rolling_count >= self.min_periods return rolling_count.where(enough_periods) count.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name="count") + def _mapping_to_list( + self, arg, default=None, allow_default=True, allow_allsame=True + ): + if utils.is_dict_like(arg): + if allow_default: + return [arg.get(d, default) for d in self.dim] + else: + for d in self.dim: + if d not in arg: + raise KeyError(f"argument has no key {d}.") + return [arg[d] for d in self.dim] + elif allow_allsame: # for single argument + return [arg] * len(self.dim) + elif len(self.dim) == 1: + return [arg] + else: + raise ValueError( + "Mapping argument is necessary for {}d-rolling.".format(len(self.dim)) + ) + + def _get_keep_attrs(self, keep_attrs): + + if keep_attrs is None: + # TODO: uncomment the next line and remove the others after the deprecation + # keep_attrs = _get_keep_attrs(default=True) + + if self.keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + else: + 
keep_attrs = self.keep_attrs + + return keep_attrs + class DataArrayRolling(Rolling): __slots__ = ("window_labels",) @@ -162,22 +193,15 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None ---------- obj : DataArray Object to window. - windows : A mapping from a dimension name to window size - dim : str - Name of the dimension to create the rolling iterator - along (e.g., `time`). - window : int - Size of the moving window. - min_periods : int, default None + windows : mapping of hashable to int + A mapping from the name of the dimension to create the rolling + exponential window along (e.g. `time`) to the size of the moving window. + min_periods : int, default: None Minimum number of observations in window required to have a value (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. - center : boolean, default False + center : bool, default: False Set the labels at the center of the window. - keep_attrs : bool, optional - If True, the object's attributes (`attrs`) will be copied from - the original object to the new one. If False (default), the new - object will be returned without attributes. Returns ------- @@ -185,44 +209,58 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None See Also -------- - DataArray.rolling - DataArray.groupby - Dataset.rolling - Dataset.groupby + xarray.DataArray.rolling + xarray.DataArray.groupby + xarray.Dataset.rolling + xarray.Dataset.groupby """ - if keep_attrs is None: - keep_attrs = _get_keep_attrs(default=False) super().__init__( obj, windows, min_periods=min_periods, center=center, keep_attrs=keep_attrs ) - self.window_labels = self.obj[self.dim] + # TODO legacy attribute + self.window_labels = self.obj[self.dim[0]] def __iter__(self): + if len(self.dim) > 1: + raise ValueError("__iter__ is only supported for 1d-rolling") stops = np.arange(1, len(self.window_labels) + 1) - starts = stops - int(self.window) - starts[: int(self.window)] = 0 + starts = stops - int(self.window[0]) + starts[: int(self.window[0])] = 0 for (label, start, stop) in zip(self.window_labels, starts, stops): - window = self.obj.isel(**{self.dim: slice(start, stop)}) + window = self.obj.isel(**{self.dim[0]: slice(start, stop)}) - counts = window.count(dim=self.dim) - window = window.where(counts >= self._min_periods) + counts = window.count(dim=self.dim[0]) + window = window.where(counts >= self.min_periods) yield (label, window) - def construct(self, window_dim, stride=1, fill_value=dtypes.NA): + def construct( + self, + window_dim=None, + stride=1, + fill_value=dtypes.NA, + keep_attrs=None, + **window_dim_kwargs, + ): """ Convert this rolling object to xr.DataArray, where the window dimension is stacked as a new dimension Parameters ---------- - window_dim: str - New name of the window dimension. - stride: integer, optional + window_dim : str or mapping, optional + A mapping from dimension name to the new window dimension names. + stride : int or mapping of int, default: 1 Size of stride for the rolling window. - fill_value: optional. Default dtypes.NA + fill_value : default: dtypes.NA Filling value to match the dimension size. + keep_attrs : bool, default: None + If True, the attributes (``attrs``) will be copied from the original + object to the new one. If False, the new object will be returned + without attributes. If None uses the global default. + **window_dim_kwargs : {dim: new_name, ...}, optional + The keyword arguments form of ``window_dim``. 
Returns ------- @@ -236,39 +274,80 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA): >>> rolling = da.rolling(b=3) >>> rolling.construct("window_dim") - array([[[np.nan, np.nan, 0], [np.nan, 0, 1], [0, 1, 2], [1, 2, 3]], - [[np.nan, np.nan, 4], [np.nan, 4, 5], [4, 5, 6], [5, 6, 7]]]) + array([[[nan, nan, 0.], + [nan, 0., 1.], + [ 0., 1., 2.], + [ 1., 2., 3.]], + + [[nan, nan, 4.], + [nan, 4., 5.], + [ 4., 5., 6.], + [ 5., 6., 7.]]]) Dimensions without coordinates: a, b, window_dim >>> rolling = da.rolling(b=3, center=True) >>> rolling.construct("window_dim") - array([[[np.nan, 0, 1], [0, 1, 2], [1, 2, 3], [2, 3, np.nan]], - [[np.nan, 4, 5], [4, 5, 6], [5, 6, 7], [6, 7, np.nan]]]) + array([[[nan, 0., 1.], + [ 0., 1., 2.], + [ 1., 2., 3.], + [ 2., 3., nan]], + + [[nan, 4., 5.], + [ 4., 5., 6.], + [ 5., 6., 7.], + [ 6., 7., nan]]]) Dimensions without coordinates: a, b, window_dim """ from .dataarray import DataArray + keep_attrs = self._get_keep_attrs(keep_attrs) + + if window_dim is None: + if len(window_dim_kwargs) == 0: + raise ValueError( + "Either window_dim or window_dim_kwargs need to be specified." + ) + window_dim = {d: window_dim_kwargs[d] for d in self.dim} + + window_dim = self._mapping_to_list( + window_dim, allow_default=False, allow_allsame=False + ) + stride = self._mapping_to_list(stride, default=1) + window = self.obj.variable.rolling_window( self.dim, self.window, window_dim, self.center, fill_value=fill_value ) + + attrs = self.obj.attrs if keep_attrs else {} + result = DataArray( - window, dims=self.obj.dims + (window_dim,), coords=self.obj.coords + window, + dims=self.obj.dims + tuple(window_dim), + coords=self.obj.coords, + attrs=attrs, + name=self.obj.name, + ) + return result.isel( + **{d: slice(None, None, s) for d, s in zip(self.dim, stride)} ) - return result.isel(**{self.dim: slice(None, None, stride)}) - def reduce(self, func, **kwargs): + def reduce(self, func, keep_attrs=None, **kwargs): """Reduce the items in this group by applying `func` along some dimension(s). Parameters ---------- - func : function + func : callable Function which can be called in the form `func(x, **kwargs)` to return the result of collapsing an np.ndarray over an the rolling dimension. + keep_attrs : bool, default: None + If True, the attributes (``attrs``) will be copied from the original + object to the new one. If False, the new object will be returned + without attributes. If None uses the global default. **kwargs : dict Additional keyword arguments passed on to `func`. 
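A usage sketch of the reworked rolling API in the hunks above: multi-dimensional windows, per-dimension `center`, `construct` with a `window_dim` mapping and per-dimension `stride`, and `keep_attrs` passed to the reduction instead of to `.rolling()`. This is illustrative only; the data, dimension names and attribute values are made up and the printed results assume the post-change behaviour shown in the diff.

```python
import numpy as np
import xarray as xr

# Small 2D DataArray with attributes, to exercise the nd-rolling behaviour.
da = xr.DataArray(
    np.arange(12, dtype=float).reshape(3, 4),
    dims=("x", "y"),
    attrs={"units": "m"},
)

# Windows are a mapping of dimension -> window size; center may also be a mapping.
r = da.rolling(x=2, y=3, center={"x": True, "y": False}, min_periods=1)

# keep_attrs is passed to the applied reduction, not to .rolling() itself.
mean = r.mean(keep_attrs=True)

# construct() accepts a mapping of dimension -> new window-dimension name
# plus an optional per-dimension stride.
windows = r.construct(window_dim={"x": "wx", "y": "wy"}, stride={"x": 1, "y": 2})

print(mean.attrs)    # attrs preserved because keep_attrs=True
print(windows.dims)  # ('x', 'y', 'wx', 'wy')
```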
@@ -283,8 +362,15 @@ def reduce(self, func, **kwargs): >>> rolling = da.rolling(b=3) >>> rolling.construct("window_dim") - array([[[np.nan, np.nan, 0], [np.nan, 0, 1], [0, 1, 2], [1, 2, 3]], - [[np.nan, np.nan, 4], [np.nan, 4, 5], [4, 5, 6], [5, 6, 7]]]) + array([[[nan, nan, 0.], + [nan, 0., 1.], + [ 0., 1., 2.], + [ 1., 2., 3.]], + + [[nan, nan, 4.], + [nan, 4., 5.], + [ 4., 5., 6.], + [ 5., 6., 7.]]]) Dimensions without coordinates: a, b, window_dim >>> rolling.reduce(np.sum) @@ -298,95 +384,117 @@ def reduce(self, func, **kwargs): array([[ 0., 1., 3., 6.], [ 4., 9., 15., 18.]]) - + Dimensions without coordinates: a, b """ - rolling_dim = utils.get_temp_dimname(self.obj.dims, "_rolling_dim") - windows = self.construct(rolling_dim) - result = windows.reduce(func, dim=rolling_dim, **kwargs) + + keep_attrs = self._get_keep_attrs(keep_attrs) + + rolling_dim = { + d: utils.get_temp_dimname(self.obj.dims, f"_rolling_dim_{d}") + for d in self.dim + } + windows = self.construct(rolling_dim, keep_attrs=keep_attrs) + result = windows.reduce( + func, dim=list(rolling_dim.values()), keep_attrs=keep_attrs, **kwargs + ) # Find valid windows based on count. - counts = self._counts() - return result.where(counts >= self._min_periods) + counts = self._counts(keep_attrs=False) + return result.where(counts >= self.min_periods) - def _counts(self): - """ Number of non-nan entries in each rolling window. """ + def _counts(self, keep_attrs): + """Number of non-nan entries in each rolling window.""" - rolling_dim = utils.get_temp_dimname(self.obj.dims, "_rolling_dim") + rolling_dim = { + d: utils.get_temp_dimname(self.obj.dims, f"_rolling_dim_{d}") + for d in self.dim + } # We use False as the fill_value instead of np.nan, since boolean # array is faster to be reduced than object array. # The use of skipna==False is also faster since it does not need to # copy the strided array. counts = ( - self.obj.notnull() - .rolling(center=self.center, **{self.dim: self.window}) - .construct(rolling_dim, fill_value=False) - .sum(dim=rolling_dim, skipna=False) + self.obj.notnull(keep_attrs=keep_attrs) + .rolling( + center={d: self.center[i] for i, d in enumerate(self.dim)}, + **{d: w for d, w in zip(self.dim, self.window)}, + ) + .construct(rolling_dim, fill_value=False, keep_attrs=keep_attrs) + .sum(dim=list(rolling_dim.values()), skipna=False, keep_attrs=keep_attrs) ) return counts - def _bottleneck_reduce(self, func, **kwargs): + def _bottleneck_reduce(self, func, keep_attrs, **kwargs): from .dataarray import DataArray # bottleneck doesn't allow min_count to be 0, although it should # work the same as if min_count = 1 + # Note bottleneck only works with 1d-rolling. 
if self.min_periods is not None and self.min_periods == 0: min_count = 1 else: min_count = self.min_periods - axis = self.obj.get_axis_num(self.dim) + axis = self.obj.get_axis_num(self.dim[0]) padded = self.obj.variable - if self.center: - if isinstance(padded.data, dask_array_type): - # Workaround to make the padded chunk size is larger than - # self.window-1 - shift = -(self.window + 1) // 2 - offset = (self.window - 1) // 2 + if self.center[0]: + if is_duck_dask_array(padded.data): + # workaround to make the padded chunk size larger than + # self.window - 1 + shift = -(self.window[0] + 1) // 2 + offset = (self.window[0] - 1) // 2 valid = (slice(None),) * axis + ( slice(offset, offset + self.obj.shape[axis]), ) else: - shift = (-self.window // 2) + 1 + shift = (-self.window[0] // 2) + 1 valid = (slice(None),) * axis + (slice(-shift, None),) - padded = padded.pad({self.dim: (0, -shift)}, mode="constant") + padded = padded.pad({self.dim[0]: (0, -shift)}, mode="constant") - if isinstance(padded.data, dask_array_type): + if is_duck_dask_array(padded.data): raise AssertionError("should not be reachable") values = dask_rolling_wrapper( - func, padded.data, window=self.window, min_count=min_count, axis=axis + func, padded.data, window=self.window[0], min_count=min_count, axis=axis ) else: values = func( - padded.data, window=self.window, min_count=min_count, axis=axis + padded.data, window=self.window[0], min_count=min_count, axis=axis ) - if self.center: + if self.center[0]: values = values[valid] - result = DataArray(values, self.obj.coords) - return result + attrs = self.obj.attrs if keep_attrs else {} + + return DataArray(values, self.obj.coords, attrs=attrs, name=self.obj.name) def _numpy_or_bottleneck_reduce( - self, array_agg_func, bottleneck_move_func, **kwargs + self, array_agg_func, bottleneck_move_func, keep_attrs, **kwargs ): if "dim" in kwargs: warnings.warn( - f"Reductions will be applied along the rolling dimension '{self.dim}'. Passing the 'dim' kwarg to reduction operations has no effect and will raise an error in xarray 0.16.0.", + f"Reductions are applied along the rolling dimension(s) " + f"'{self.dim}'. Passing the 'dim' kwarg to reduction " + f"operations has no effect.", DeprecationWarning, stacklevel=3, ) del kwargs["dim"] - if bottleneck_move_func is not None and not isinstance( - self.obj.data, dask_array_type + if ( + bottleneck_move_func is not None + and not is_duck_dask_array(self.obj.data) + and len(self.dim) == 1 ): # TODO: renable bottleneck with dask after the issues # underlying https://github.com/pydata/xarray/issues/2940 are # fixed. - return self._bottleneck_reduce(bottleneck_move_func, **kwargs) + return self._bottleneck_reduce( + bottleneck_move_func, keep_attrs=keep_attrs, **kwargs + ) else: - return self.reduce(array_agg_func, **kwargs) + return self.reduce(array_agg_func, keep_attrs=keep_attrs, **kwargs) class DatasetRolling(Rolling): @@ -402,22 +510,15 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None ---------- obj : Dataset Object to window. - windows : A mapping from a dimension name to window size - dim : str - Name of the dimension to create the rolling iterator - along (e.g., `time`). - window : int - Size of the moving window. - min_periods : int, default None + windows : mapping of hashable to int + A mapping from the name of the dimension to create the rolling + exponential window along (e.g. `time`) to the size of the moving window. 
+ min_periods : int, default: None Minimum number of observations in window required to have a value (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. - center : boolean, default False + center : bool or mapping of hashable to bool, default: False Set the labels at the center of the window. - keep_attrs : bool, optional - If True, the object's attributes (`attrs`) will be copied from - the original object to the new one. If False (default), the new - object will be returned without attributes. Returns ------- @@ -425,45 +526,60 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None See Also -------- - Dataset.rolling - DataArray.rolling - Dataset.groupby - DataArray.groupby + xarray.Dataset.rolling + xarray.DataArray.rolling + xarray.Dataset.groupby + xarray.DataArray.groupby """ super().__init__(obj, windows, min_periods, center, keep_attrs) - if self.dim not in self.obj.dims: + if any(d not in self.obj.dims for d in self.dim): raise KeyError(self.dim) # Keep each Rolling object as a dictionary self.rollings = {} for key, da in self.obj.data_vars.items(): - # keeps rollings only for the dataset depending on slf.dim - if self.dim in da.dims: - self.rollings[key] = DataArrayRolling( - da, windows, min_periods, center, keep_attrs - ) - - def _dataset_implementation(self, func, **kwargs): + # keeps rollings only for the dataset depending on self.dim + dims, center = [], {} + for i, d in enumerate(self.dim): + if d in da.dims: + dims.append(d) + center[d] = self.center[i] + + if len(dims) > 0: + w = {d: windows[d] for d in dims} + self.rollings[key] = DataArrayRolling(da, w, min_periods, center) + + def _dataset_implementation(self, func, keep_attrs, **kwargs): from .dataset import Dataset + keep_attrs = self._get_keep_attrs(keep_attrs) + reduced = {} for key, da in self.obj.data_vars.items(): - if self.dim in da.dims: - reduced[key] = func(self.rollings[key], **kwargs) + if any(d in da.dims for d in self.dim): + reduced[key] = func(self.rollings[key], keep_attrs=keep_attrs, **kwargs) else: - reduced[key] = self.obj[key] - attrs = self.obj.attrs if self.keep_attrs else {} + reduced[key] = self.obj[key].copy() + # we need to delete the attrs of the copied DataArray + if not keep_attrs: + reduced[key].attrs = {} + + attrs = self.obj.attrs if keep_attrs else {} return Dataset(reduced, coords=self.obj.coords, attrs=attrs) - def reduce(self, func, **kwargs): + def reduce(self, func, keep_attrs=None, **kwargs): """Reduce the items in this group by applying `func` along some dimension(s). Parameters ---------- - func : function + func : callable Function which can be called in the form `func(x, **kwargs)` to return the result of collapsing an np.ndarray over an the rolling dimension. + keep_attrs : bool, default: None + If True, the attributes (``attrs``) will be copied from the original + object to the new one. If False, the new object will be returned + without attributes. If None uses the global default. **kwargs : dict Additional keyword arguments passed on to `func`. @@ -473,14 +589,18 @@ def reduce(self, func, **kwargs): Array with summarized data. 
""" return self._dataset_implementation( - functools.partial(DataArrayRolling.reduce, func=func), **kwargs + functools.partial(DataArrayRolling.reduce, func=func), + keep_attrs=keep_attrs, + **kwargs, ) - def _counts(self): - return self._dataset_implementation(DataArrayRolling._counts) + def _counts(self, keep_attrs): + return self._dataset_implementation( + DataArrayRolling._counts, keep_attrs=keep_attrs + ) def _numpy_or_bottleneck_reduce( - self, array_agg_func, bottleneck_move_func, **kwargs + self, array_agg_func, bottleneck_move_func, keep_attrs, **kwargs ): return self._dataset_implementation( functools.partial( @@ -488,22 +608,33 @@ def _numpy_or_bottleneck_reduce( array_agg_func=array_agg_func, bottleneck_move_func=bottleneck_move_func, ), + keep_attrs=keep_attrs, **kwargs, ) - def construct(self, window_dim, stride=1, fill_value=dtypes.NA, keep_attrs=None): + def construct( + self, + window_dim=None, + stride=1, + fill_value=dtypes.NA, + keep_attrs=None, + **window_dim_kwargs, + ): """ Convert this rolling object to xr.Dataset, where the window dimension is stacked as a new dimension Parameters ---------- - window_dim: str - New name of the window dimension. - stride: integer, optional + window_dim : str or mapping, optional + A mapping from dimension name to the new window dimension names. + Just a string can be used for 1d-rolling. + stride : int, optional size of stride for the rolling window. - fill_value: optional. Default dtypes.NA + fill_value : Any, default: dtypes.NA Filling value to match the dimension size. + **window_dim_kwargs : {dim: new_name, ...}, optional + The keyword arguments form of ``window_dim``. Returns ------- @@ -512,19 +643,45 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA, keep_attrs=None) from .dataset import Dataset - if keep_attrs is None: - keep_attrs = _get_keep_attrs(default=True) + keep_attrs = self._get_keep_attrs(keep_attrs) + + if window_dim is None: + if len(window_dim_kwargs) == 0: + raise ValueError( + "Either window_dim or window_dim_kwargs need to be specified." + ) + window_dim = {d: window_dim_kwargs[d] for d in self.dim} + + window_dim = self._mapping_to_list( + window_dim, allow_default=False, allow_allsame=False + ) + stride = self._mapping_to_list(stride, default=1) dataset = {} for key, da in self.obj.data_vars.items(): - if self.dim in da.dims: + # keeps rollings only for the dataset depending on self.dim + dims = [d for d in self.dim if d in da.dims] + if len(dims) > 0: + wi = {d: window_dim[i] for i, d in enumerate(self.dim) if d in da.dims} + st = {d: stride[i] for i, d in enumerate(self.dim) if d in da.dims} + dataset[key] = self.rollings[key].construct( - window_dim, fill_value=fill_value + window_dim=wi, + fill_value=fill_value, + stride=st, + keep_attrs=keep_attrs, ) else: - dataset[key] = da - return Dataset(dataset, coords=self.obj.coords).isel( - **{self.dim: slice(None, None, stride)} + dataset[key] = da.copy() + + # as the DataArrays can be copied we need to delete the attrs + if not keep_attrs: + dataset[key].attrs = {} + + attrs = self.obj.attrs if keep_attrs else {} + + return Dataset(dataset, coords=self.obj.coords, attrs=attrs).isel( + **{d: slice(None, None, s) for d, s in zip(self.dim, stride)} ) @@ -556,12 +713,9 @@ def __init__(self, obj, windows, boundary, side, coord_func, keep_attrs): ---------- obj : Dataset or DataArray Object to window. - windows : A mapping from a dimension name to window size - dim : str - Name of the dimension to create the rolling iterator - along (e.g., `time`). 
- window : int - Size of the moving window. + windows : mapping of hashable to int + A mapping from the name of the dimension to create the rolling + exponential window along (e.g. `time`) to the size of the moving window. boundary : 'exact' | 'trim' | 'pad' If 'exact', a ValueError will be raised if dimension size is not a multiple of window size. If 'trim', the excess indexes are trimed. @@ -623,7 +777,7 @@ def wrapped_func(self, **kwargs): from .dataarray import DataArray reduced = self.obj.variable.coarsen( - self.windows, func, self.boundary, self.side, **kwargs + self.windows, func, self.boundary, self.side, self.keep_attrs, **kwargs ) coords = {} for c, v in self.obj.coords.items(): @@ -636,6 +790,7 @@ def wrapped_func(self, **kwargs): self.coord_func[c], self.boundary, self.side, + self.keep_attrs, **kwargs, ) else: diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py index 6ef63e42291..0ae85a870e8 100644 --- a/xarray/core/rolling_exp.py +++ b/xarray/core/rolling_exp.py @@ -1,7 +1,16 @@ +from typing import TYPE_CHECKING, Generic, Hashable, Mapping, Optional, TypeVar + import numpy as np +from .options import _get_keep_attrs from .pdcompat import count_not_none -from .pycompat import dask_array_type +from .pycompat import is_duck_dask_array + +if TYPE_CHECKING: + from .dataarray import DataArray # noqa: F401 + from .dataset import Dataset # noqa: F401 + +T_DSorDA = TypeVar("T_DSorDA", "DataArray", "Dataset") def _get_alpha(com=None, span=None, halflife=None, alpha=None): @@ -13,8 +22,8 @@ def _get_alpha(com=None, span=None, halflife=None, alpha=None): def move_exp_nanmean(array, *, axis, alpha): - if isinstance(array, dask_array_type): - raise TypeError("rolling_exp is not currently support for dask arrays") + if is_duck_dask_array(array): + raise TypeError("rolling_exp is not currently support for dask-like arrays") import numbagg if axis == (): @@ -31,7 +40,7 @@ def _get_center_of_mass(comass, span, halflife, alpha): """ valid_count = count_not_none(comass, span, halflife, alpha) if valid_count > 1: - raise ValueError("comass, span, halflife, and alpha " "are mutually exclusive") + raise ValueError("comass, span, halflife, and alpha are mutually exclusive") # Convert to center of mass; domain checks ensure 0 < alpha <= 1 if comass is not None: @@ -56,7 +65,7 @@ def _get_center_of_mass(comass, span, halflife, alpha): return float(comass) -class RollingExp: +class RollingExp(Generic[T_DSorDA]): """ Exponentially-weighted moving window object. Similar to EWM in pandas @@ -65,40 +74,53 @@ class RollingExp: ---------- obj : Dataset or DataArray Object to window. - windows : A single mapping from a single dimension name to window value - dim : str - Name of the dimension to create the rolling exponential window - along (e.g., `time`). - window : int - Size of the moving window. The type of this is specified in - `window_type` - window_type : str, one of ['span', 'com', 'halflife', 'alpha'], default 'span' + windows : mapping of hashable to int + A mapping from the name of the dimension to create the rolling + exponential window along (e.g. `time`) to the size of the moving window. + window_type : {"span", "com", "halflife", "alpha"}, default: "span" The format of the previously supplied window. Each is a simple numerical transformation of the others. 
Described in detail: - https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.ewm.html + https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.ewm.html Returns ------- RollingExp : type of input argument """ - def __init__(self, obj, windows, window_type="span"): - self.obj = obj + def __init__( + self, + obj: T_DSorDA, + windows: Mapping[Hashable, int], + window_type: str = "span", + ): + self.obj: T_DSorDA = obj dim, window = next(iter(windows.items())) self.dim = dim self.alpha = _get_alpha(**{window_type: window}) - def mean(self): + def mean(self, keep_attrs: Optional[bool] = None) -> T_DSorDA: """ Exponentially weighted moving average + Parameters + ---------- + keep_attrs : bool, default: None + If True, the attributes (``attrs``) will be copied from the original + object to the new one. If False, the new object will be returned + without attributes. If None uses the global default. + Examples -------- >>> da = xr.DataArray([1, 1, 2, 2, 2], dims="x") >>> da.rolling_exp(x=2, window_type="span").mean() - array([1. , 1. , 1.692308, 1.9 , 1.966942]) + array([1. , 1. , 1.69230769, 1.9 , 1.96694215]) Dimensions without coordinates: x """ - return self.obj.reduce(move_exp_nanmean, dim=self.dim, alpha=self.alpha) + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) + + return self.obj.reduce( + move_exp_nanmean, dim=self.dim, alpha=self.alpha, keep_attrs=keep_attrs + ) diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 1126cf3037f..ced688f32dd 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -2,13 +2,13 @@ """ import contextlib import functools +import io import itertools import os.path import re import warnings from enum import Enum from typing import ( - AbstractSet, Any, Callable, Collection, @@ -31,19 +31,13 @@ import numpy as np import pandas as pd +from . import dtypes + K = TypeVar("K") V = TypeVar("V") T = TypeVar("T") -def _check_inplace(inplace: Optional[bool]) -> None: - if inplace is not None: - raise TypeError( - "The `inplace` argument has been removed from xarray. " - "You can achieve an identical effect with python's standard assignment." - ) - - def alias_message(old_name: str, new_name: str) -> str: return f"{old_name} has been deprecated. Use {new_name} instead." @@ -84,6 +78,23 @@ def maybe_cast_to_coords_dtype(label, coords_dtype): return label +def maybe_coerce_to_str(index, original_coords): + """maybe coerce a pandas Index back to a nunpy array of type str + + pd.Index uses object-dtype to store str - try to avoid this for coords + """ + + try: + result_type = dtypes.result_type(*original_coords) + except TypeError: + pass + else: + if result_type.kind in "SU": + index = np.asarray(index, dtype=result_type.type) + + return index + + def safe_cast_to_index(array: Any) -> pd.Index: """Given an array, safely cast it to a pandas.Index. @@ -116,7 +127,7 @@ def multiindex_from_product_levels( ---------- levels : sequence of pd.Index Values for each MultiIndex level. - names : optional sequence of objects + names : sequence of str, optional Names for each level. Returns @@ -133,7 +144,7 @@ def multiindex_from_product_levels( def maybe_wrap_array(original, new_array): - """Wrap a transformed array with __array_wrap__ is it can be done safely. + """Wrap a transformed array with __array_wrap__ if it can be done safely. This lets us treat arbitrary functions that take and return ndarray objects like ufuncs, as long as they return an array with the same shape. 
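A rough sketch of the `rolling_exp` changes above (the `Generic`-typed `RollingExp` and `keep_attrs` support on `mean`). It assumes the optional `numbagg` dependency is installed; the attribute values are made up.

```python
import xarray as xr

da = xr.DataArray(
    [1.0, 1.0, 2.0, 2.0, 2.0],
    dims="x",
    attrs={"long_name": "demo signal"},
)

# Exponentially weighted moving mean (requires the optional numbagg dependency).
# keep_attrs falls back to the global option, which keeps attributes by default.
smoothed = da.rolling_exp(x=2, window_type="span").mean()
print(smoothed.attrs)  # {'long_name': 'demo signal'}

# Attributes can be dropped explicitly.
bare = da.rolling_exp(x=2, window_type="span").mean(keep_attrs=False)
print(bare.attrs)  # {}
```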
@@ -247,6 +258,18 @@ def is_list_like(value: Any) -> bool: return isinstance(value, list) or isinstance(value, tuple) +def is_duck_array(value: Any) -> bool: + if isinstance(value, np.ndarray): + return True + return ( + hasattr(value, "ndim") + and hasattr(value, "shape") + and hasattr(value, "dtype") + and hasattr(value, "__array_function__") + and hasattr(value, "__array_ufunc__") + ) + + def either_dict_or_kwargs( pos_kwargs: Optional[Mapping[Hashable, T]], kw_kwargs: Mapping[str, T], @@ -298,16 +321,14 @@ def is_valid_numpy_dtype(dtype: Any) -> bool: def to_0d_object_array(value: Any) -> np.ndarray: - """Given a value, wrap it in a 0-D numpy.ndarray with dtype=object. - """ + """Given a value, wrap it in a 0-D numpy.ndarray with dtype=object.""" result = np.empty((), dtype=object) result[()] = value return result def to_0d_array(value: Any) -> np.ndarray: - """Given a value, wrap it in a 0-D numpy.ndarray. - """ + """Given a value, wrap it in a 0-D numpy.ndarray.""" if np.isscalar(value) or (isinstance(value, np.ndarray) and value.ndim == 0): return np.array(value) else: @@ -432,6 +453,35 @@ def FrozenDict(*args, **kwargs) -> Frozen: return Frozen(dict(*args, **kwargs)) +class HybridMappingProxy(Mapping[K, V]): + """Implements the Mapping interface. Uses the wrapped mapping for item lookup + and a separate wrapped keys collection for iteration. + + Can be used to construct a mapping object from another dict-like object without + eagerly accessing its items or when a mapping object is expected but only + iteration over keys is actually used. + + Note: HybridMappingProxy does not validate consistency of the provided `keys` + and `mapping`. It is the caller's responsibility to ensure that they are + suitable for the task at hand. + """ + + __slots__ = ("_keys", "mapping") + + def __init__(self, keys: Collection[K], mapping: Mapping[K, V]): + self._keys = keys + self.mapping = mapping + + def __getitem__(self, key: K) -> V: + return self.mapping[key] + + def __iter__(self) -> Iterator[K]: + return iter(self._keys) + + def __len__(self) -> int: + return len(self._keys) + + class SortedKeysDict(MutableMapping[K, V]): """An wrapper for dictionary-like objects that always iterates over its items in sorted order by key but is otherwise equivalent to the underlying @@ -453,7 +503,8 @@ def __delitem__(self, key: K) -> None: del self.mapping[key] def __iter__(self) -> Iterator[K]: - return iter(sorted(self.mapping)) + # see #4571 for the reason of the type ignore + return iter(sorted(self.mapping)) # type: ignore def __len__(self) -> int: return len(self.mapping) @@ -476,17 +527,14 @@ class OrderedSet(MutableSet[T]): __slots__ = ("_d",) - def __init__(self, values: AbstractSet[T] = None): + def __init__(self, values: Iterable[T] = None): self._d = {} if values is not None: - # Disable type checking - both mypy and PyCharm believe that - # we're altering the type of self in place (see signature of - # MutableSet.__ior__) - self |= values # type: ignore + self.update(values) # Required methods for MutableSet - def __contains__(self, value: object) -> bool: + def __contains__(self, value: Hashable) -> bool: return value in self._d def __iter__(self) -> Iterator[T]: @@ -503,9 +551,9 @@ def discard(self, value: T) -> None: # Additional methods - def update(self, values: AbstractSet[T]) -> None: - # See comment on __init__ re. 
type checking - self |= values # type: ignore + def update(self, values: Iterable[T]) -> None: + for v in values: + self._d[v] = None def __repr__(self) -> str: return "{}({!r})".format(type(self).__name__, list(self)) @@ -560,8 +608,7 @@ def __repr__(self: Any) -> str: class ReprObject: - """Object that prints as the given value, for use with sentinel values. - """ + """Object that prints as the given value, for use with sentinel values.""" __slots__ = ("_value",) @@ -601,6 +648,24 @@ def is_remote_uri(path: str) -> bool: return bool(re.search(r"^https?\://", path)) +def read_magic_number(filename_or_obj, count=8): + # check byte header to determine file type + if isinstance(filename_or_obj, bytes): + magic_number = filename_or_obj[:count] + elif isinstance(filename_or_obj, io.IOBase): + if filename_or_obj.tell() != 0: + raise ValueError( + "cannot guess the engine, " + "file-like object read/write pointer not at the start of the file, " + "please close and reopen, or use a context manager" + ) + magic_number = filename_or_obj.read(count) + filename_or_obj.seek(0) + else: + raise TypeError(f"cannot read the magic number form {type(filename_or_obj)}") + return magic_number + + def is_grib_path(path: str) -> bool: _, ext = os.path.splitext(path) return ext in [".grib", ".grb", ".grib2", ".grb2"] @@ -622,8 +687,7 @@ def is_uniform_spaced(arr, **kwargs) -> bool: def hashable(v: Any) -> bool: - """Determine whether `v` can be hashed. - """ + """Determine whether `v` can be hashed.""" try: hash(v) except TypeError: @@ -659,8 +723,7 @@ def ensure_us_time_resolution(val): class HiddenKeyDict(MutableMapping[K, V]): - """Acts like a normal dictionary, but hides certain keys. - """ + """Acts like a normal dictionary, but hides certain keys.""" __slots__ = ("_data", "_hidden_keys") @@ -697,32 +760,36 @@ def __len__(self) -> int: return len(self._data) - num_hidden -def infix_dims(dims_supplied: Collection, dims_all: Collection) -> Iterator: +def infix_dims( + dims_supplied: Collection, dims_all: Collection, missing_dims: str = "raise" +) -> Iterator: """ - Resolves a supplied list containing an ellispsis representing other items, to + Resolves a supplied list containing an ellipsis representing other items, to a generator with the 'realized' list of all items """ if ... in dims_supplied: if len(set(dims_all)) != len(dims_all): raise ValueError("Cannot use ellipsis with repeated dims") - if len([d for d in dims_supplied if d == ...]) > 1: + if list(dims_supplied).count(...) > 1: raise ValueError("More than one ellipsis supplied") other_dims = [d for d in dims_all if d not in dims_supplied] - for d in dims_supplied: - if d == ...: + existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims) + for d in existing_dims: + if d is ...: yield from other_dims else: yield d else: - if set(dims_supplied) ^ set(dims_all): + existing_dims = drop_missing_dims(dims_supplied, dims_all, missing_dims) + if set(existing_dims) ^ set(dims_all): raise ValueError( f"{dims_supplied} must be a permuted list of {dims_all}, unless `...` is included" ) - yield from dims_supplied + yield from existing_dims def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) -> Hashable: - """ Get an new dimension name based on new_dim, that is not used in dims. + """Get an new dimension name based on new_dim, that is not used in dims. If the same name exists, we add an underscore(s) in the head. 
Example1: @@ -744,7 +811,7 @@ def drop_dims_from_indexers( dims: Union[list, Mapping[Hashable, int]], missing_dims: str, ) -> Mapping[Hashable, Any]: - """ Depending on the setting of missing_dims, drop any dimensions from indexers that + """Depending on the setting of missing_dims, drop any dimensions from indexers that are not present in dims. Parameters @@ -758,7 +825,7 @@ def drop_dims_from_indexers( invalid = indexers.keys() - set(dims) if invalid: raise ValueError( - f"dimensions {invalid} do not exist. Expected one or more of {dims}" + f"Dimensions {invalid} do not exist. Expected one or more of {dims}" ) return indexers @@ -771,7 +838,7 @@ def drop_dims_from_indexers( invalid = indexers.keys() - set(dims) if invalid: warnings.warn( - f"dimensions {invalid} do not exist. Expected one or more of {dims}" + f"Dimensions {invalid} do not exist. Expected one or more of {dims}" ) for key in invalid: indexers.pop(key) @@ -787,6 +854,66 @@ def drop_dims_from_indexers( ) +def drop_missing_dims( + supplied_dims: Collection, dims: Collection, missing_dims: str +) -> Collection: + """Depending on the setting of missing_dims, drop any dimensions from supplied_dims that + are not present in dims. + + Parameters + ---------- + supplied_dims : dict + dims : sequence + missing_dims : {"raise", "warn", "ignore"} + """ + + if missing_dims == "raise": + supplied_dims_set = set(val for val in supplied_dims if val is not ...) + invalid = supplied_dims_set - set(dims) + if invalid: + raise ValueError( + f"Dimensions {invalid} do not exist. Expected one or more of {dims}" + ) + + return supplied_dims + + elif missing_dims == "warn": + + invalid = set(supplied_dims) - set(dims) + if invalid: + warnings.warn( + f"Dimensions {invalid} do not exist. Expected one or more of {dims}" + ) + + return [val for val in supplied_dims if val in dims or val is ...] + + elif missing_dims == "ignore": + return [val for val in supplied_dims if val in dims or val is ...] + + else: + raise ValueError( + f"Unrecognised option {missing_dims} for missing_dims argument" + ) + + +class UncachedAccessor: + """Acts like a property, but on both classes and class instances + + This class is necessary because some tools (e.g. pydoc and sphinx) + inspect classes for which property returns itself and not the + accessor. 
+ """ + + def __init__(self, accessor): + self._accessor = accessor + + def __get__(self, obj, cls): + if obj is None: + return self._accessor + + return self._accessor(obj) + + # Singleton type, as per https://github.com/python/typing/pull/240 class Default(Enum): token = 0 diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 68e823ca426..797de65bbcf 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -6,7 +6,17 @@ from collections import defaultdict from datetime import timedelta from distutils.version import LooseVersion -from typing import Any, Dict, Hashable, Mapping, Tuple, TypeVar, Union +from typing import ( + Any, + Dict, + Hashable, + Mapping, + Optional, + Sequence, + Tuple, + TypeVar, + Union, +) import numpy as np import pandas as pd @@ -23,7 +33,12 @@ ) from .npcompat import IS_NEP18_ACTIVE from .options import _get_keep_attrs -from .pycompat import dask_array_type, integer_types +from .pycompat import ( + cupy_array_type, + dask_array_type, + integer_types, + is_duck_dask_array, +) from .utils import ( OrderedSet, _default, @@ -32,12 +47,18 @@ either_dict_or_kwargs, ensure_us_time_resolution, infix_dims, + is_duck_array, + maybe_coerce_to_str, ) NON_NUMPY_SUPPORTED_ARRAY_TYPES = ( - indexing.ExplicitlyIndexed, - pd.Index, -) + dask_array_type + ( + indexing.ExplicitlyIndexed, + pd.Index, + ) + + dask_array_type + + cupy_array_type +) # https://github.com/python/mypy/issues/224 BASIC_INDEXING_TYPES = integer_types + (slice,) # type: ignore @@ -55,8 +76,7 @@ def f(self: VariableType, ...) -> VariableType: class MissingDimensionsError(ValueError): - """Error class used when we can't safely guess a dimension name. - """ + """Error class used when we can't safely guess a dimension name.""" # inherits from ValueError for backward compatibility # TODO: move this to an xarray.exceptions module? @@ -158,7 +178,9 @@ def _maybe_wrap_data(data): def _possibly_convert_objects(values): """Convert arrays of datetime.datetime and datetime.timedelta objects into - datetime64 and timedelta64, according to the pandas convention. + datetime64 and timedelta64, according to the pandas convention. Also used for + validating that datetime64 and timedelta64 objects are within the valid date + range for ns precision, as pandas will raise an error if they are not. """ return np.asarray(pd.Series(values.ravel())).reshape(values.shape) @@ -219,16 +241,16 @@ def as_compatible_data(data, fastpath=False): '"1"' ) - # validate whether the data is valid data types + # validate whether the data is valid data types. 
data = np.asarray(data) if isinstance(data, np.ndarray): if data.dtype.kind == "O": data = _possibly_convert_objects(data) elif data.dtype.kind == "M": - data = np.asarray(data, "datetime64[ns]") + data = _possibly_convert_objects(data) elif data.dtype.kind == "m": - data = np.asarray(data, "timedelta64[ns]") + data = _possibly_convert_objects(data) return _maybe_wrap_data(data) @@ -247,7 +269,10 @@ def _as_array_or_item(data): TODO: remove this (replace with np.asarray) once these issues are fixed """ - data = np.asarray(data) + if isinstance(data, cupy_array_type): + data = data.get() + else: + data = np.asarray(data) if data.ndim == 0: if data.dtype.kind == "M": data = np.datetime64(data, "ns") @@ -331,9 +356,7 @@ def _in_memory(self): @property def data(self): - if hasattr(self._data, "__array_function__") or isinstance( - self._data, dask_array_type - ): + if is_duck_array(self._data): return self._data else: return self.values @@ -348,6 +371,82 @@ def data(self, data): ) self._data = data + def astype( + self: VariableType, + dtype, + *, + order=None, + casting=None, + subok=None, + copy=None, + keep_attrs=True, + ) -> VariableType: + """ + Copy of the Variable object, with data cast to a specified type. + + Parameters + ---------- + dtype : str or dtype + Typecode or data-type to which the array is cast. + order : {'C', 'F', 'A', 'K'}, optional + Controls the memory layout order of the result. ‘C’ means C order, + ‘F’ means Fortran order, ‘A’ means ‘F’ order if all the arrays are + Fortran contiguous, ‘C’ order otherwise, and ‘K’ means as close to + the order the array elements appear in memory as possible. + casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional + Controls what kind of data casting may occur. + + * 'no' means the data types should not be cast at all. + * 'equiv' means only byte-order changes are allowed. + * 'safe' means only casts which can preserve values are allowed. + * 'same_kind' means only safe casts or casts within a kind, + like float64 to float32, are allowed. + * 'unsafe' means any data conversions may be done. + + subok : bool, optional + If True, then sub-classes will be passed-through, otherwise the + returned array will be forced to be a base-class array. + copy : bool, optional + By default, astype always returns a newly allocated array. If this + is set to False and the `dtype` requirement is satisfied, the input + array is returned instead of a copy. + keep_attrs : bool, optional + By default, astype keeps attributes. Set to False to remove + attributes in the returned object. + + Returns + ------- + out : same as object + New object with data cast to the specified type. + + Notes + ----- + The ``order``, ``casting``, ``subok`` and ``copy`` arguments are only passed + through to the ``astype`` method of the underlying array when a value + different than ``None`` is supplied. + Make sure to only supply these arguments if the underlying array class + supports them. + + See also + -------- + numpy.ndarray.astype + dask.array.Array.astype + sparse.COO.astype + """ + from .computation import apply_ufunc + + kwargs = dict(order=order, casting=casting, subok=subok, copy=copy) + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + return apply_ufunc( + duck_array_ops.astype, + self, + dtype, + kwargs=kwargs, + keep_attrs=keep_attrs, + dask="allowed", + ) + def load(self, **kwargs): """Manually trigger loading of this variable's data from disk or a remote source into memory and return this variable. 
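A minimal sketch of the intended behaviour of the `Variable.astype` hunk above; the variable data and attribute names are illustrative.

```python
import numpy as np
import xarray as xr

var = xr.Variable("x", np.array([1, 2, 3]), attrs={"units": "m"})

# astype is routed through apply_ufunc; the numpy-style keyword arguments
# (order, casting, subok, copy) are only forwarded when explicitly supplied.
as_float = var.astype(np.float32)
print(as_float.dtype, as_float.attrs)  # float32 {'units': 'm'} -- attrs kept by default

# Attributes can be dropped on request.
bare = var.astype(np.float32, keep_attrs=False)
print(bare.attrs)  # {}
```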
@@ -365,9 +464,9 @@ def load(self, **kwargs): -------- dask.array.compute """ - if isinstance(self._data, dask_array_type): + if is_duck_dask_array(self._data): self._data = as_compatible_data(self._data.compute(**kwargs)) - elif not hasattr(self._data, "__array_function__"): + elif not is_duck_array(self._data): self._data = np.asarray(self._data) return self @@ -400,7 +499,7 @@ def __dask_tokenize__(self): return normalize_token((type(self), self._dims, self.data, self._attrs)) def __dask_graph__(self): - if isinstance(self._data, dask_array_type): + if is_duck_dask_array(self._data): return self._data.__dask_graph__() else: return None @@ -435,9 +534,6 @@ def __dask_postpersist__(self): @staticmethod def _dask_finalize(results, array_func, array_args, dims, attrs, encoding): - if isinstance(results, dict): # persist case - name = array_args[0] - results = {k: v for k, v in results.items() if k[0] == name} data = array_func(results, *array_args) return Variable(dims, data, attrs=attrs, encoding=encoding) @@ -481,8 +577,7 @@ def to_dict(self, data=True): @property def dims(self): - """Tuple of dimension names with which this variable is associated. - """ + """Tuple of dimension names with which this variable is associated.""" return self._dims @dims.setter @@ -511,14 +606,14 @@ def _broadcast_indexes(self, key): Parameters ----------- - key: int, slice, array, dict or tuple of integer, slices and arrays + key: int, slice, array-like, dict or tuple of integer, slice and array-like Any valid input for indexing. Returns ------- - dims: tuple + dims : tuple Dimension of the resultant variable. - indexers: IndexingTuple subclass + indexers : IndexingTuple subclass Tuple of integer, array-like, or slices to use when indexing self._data. The type of this argument indicates the type of indexing to perform, either basic, outer or vectorized. @@ -708,8 +803,7 @@ def __getitem__(self: VariableType, key) -> VariableType: return self._finalize_indexing_result(dims, data) def _finalize_indexing_result(self: VariableType, dims, data) -> VariableType: - """Used by IndexVariable to return IndexVariable objects when possible. - """ + """Used by IndexVariable to return IndexVariable objects when possible.""" return type(self)(dims, data, self._attrs, self._encoding, fastpath=True) def _getitem_with_mask(self, key, fill_value=dtypes.NA): @@ -728,7 +822,7 @@ def _getitem_with_mask(self, key, fill_value=dtypes.NA): dims, indexer, new_order = self._broadcast_indexes(key) if self.size: - if isinstance(self._data, dask_array_type): + if is_duck_dask_array(self._data): # dask's indexing is faster this way; also vindex does not # support negative indices yet: # https://github.com/dask/dask/pull/2967 @@ -785,8 +879,7 @@ def __setitem__(self, key, value): @property def attrs(self) -> Dict[Hashable, Any]: - """Dictionary of local attributes on this variable. - """ + """Dictionary of local attributes on this variable.""" if self._attrs is None: self._attrs = {} return self._attrs @@ -797,8 +890,7 @@ def attrs(self, value: Mapping[Hashable, Any]) -> None: @property def encoding(self): - """Dictionary of encodings on this variable. 
- """ + """Dictionary of encodings on this variable.""" if self._encoding is None: self._encoding = {} return self._encoding @@ -858,7 +950,7 @@ def copy(self, deep=True, data=None): >>> var.copy(data=[0.1, 0.2, 0.3]) - array([ 0.1, 0.2, 0.3]) + array([0.1, 0.2, 0.3]) >>> var array([7, 2, 3]) @@ -875,13 +967,8 @@ def copy(self, deep=True, data=None): data = indexing.MemoryCachedArray(data.array) if deep: - if hasattr(data, "__array_function__") or isinstance( - data, dask_array_type - ): - data = data.copy() - elif not isinstance(data, PandasIndexAdapter): - # pandas.Index is immutable - data = np.array(data) + data = copy.deepcopy(data) + else: data = as_compatible_data(data) if self.shape != data.shape: @@ -930,7 +1017,7 @@ def chunks(self): _array_counter = itertools.count() - def chunk(self, chunks=None, name=None, lock=False): + def chunk(self, chunks={}, name=None, lock=False): """Coerce this array's data into a dask arrays with the given chunks. If this variable is a non-dask array, it will be converted to dask @@ -960,14 +1047,19 @@ def chunk(self, chunks=None, name=None, lock=False): import dask import dask.array as da + if chunks is None: + warnings.warn( + "None value for 'chunks' is deprecated. " + "It will raise an error in the future. Use instead '{}'", + category=FutureWarning, + ) + chunks = {} + if utils.is_dict_like(chunks): chunks = {self.get_axis_num(dim): chunk for dim, chunk in chunks.items()} - if chunks is None: - chunks = self.chunks or self.shape - data = self._data - if isinstance(data, da.Array): + if is_duck_dask_array(data): data = data.rechunk(chunks) else: if isinstance(data, indexing.ExplicitlyIndexed): @@ -1004,7 +1096,7 @@ def _as_sparse(self, sparse_format=_default, fill_value=dtypes.NA): """ import sparse - # TODO what to do if dask-backended? + # TODO: what to do if dask-backended? if fill_value is dtypes.NA: dtype, fill_value = dtypes.maybe_promote(self.dtype) else: @@ -1013,9 +1105,9 @@ def _as_sparse(self, sparse_format=_default, fill_value=dtypes.NA): if sparse_format is _default: sparse_format = "coo" try: - as_sparse = getattr(sparse, "as_{}".format(sparse_format.lower())) + as_sparse = getattr(sparse, f"as_{sparse_format.lower()}") except AttributeError: - raise ValueError("{} is not a valid sparse format".format(sparse_format)) + raise ValueError(f"{sparse_format} is not a valid sparse format") data = as_sparse(self.data.astype(dtype), fill_value=fill_value) return self._replace(data=data) @@ -1041,10 +1133,10 @@ def isel( **indexers : {dim: indexer, ...} Keyword arguments with names matching dimensions and values given by integers, slice objects or arrays. - missing_dims : {"raise", "warn", "ignore"}, default "raise" + missing_dims : {"raise", "warn", "ignore"}, default: "raise" What to do if dimensions that should be selected from are not present in the DataArray: - - "exception": raise an exception + - "raise": raise an exception - "warning": raise a warning, and ignore the missing dimensions - "ignore": ignore the missing dimensions @@ -1114,7 +1206,7 @@ def _shift_one_dim(self, dim, count, fill_value=dtypes.NA): constant_values=fill_value, ) - if isinstance(data, dask_array_type): + if is_duck_dask_array(data): # chunked data should come out with the same chunks; this makes # it feasible to combine shifted and unshifted data # TODO: remove this once dask.array automatically aligns chunks @@ -1134,7 +1226,7 @@ def shift(self, shifts=None, fill_value=dtypes.NA, **shifts_kwargs): left. 
fill_value: scalar, optional Value to use for newly missing values - **shifts_kwargs: + **shifts_kwargs The keyword arguments form of ``shifts``. One of shifts or shifts_kwargs must be provided. @@ -1182,26 +1274,27 @@ def pad( Parameters ---------- - pad_width: Mapping with the form of {dim: (pad_before, pad_after)} - Number of values padded along each dimension. + pad_width : mapping of hashable to tuple of int + Mapping with the form of {dim: (pad_before, pad_after)} + describing the number of values padded along each dimension. {dim: pad} is a shortcut for pad_before = pad_after = pad - mode: (str) + mode : str, default: "constant" See numpy / Dask docs - stat_length : int, tuple or mapping of the form {dim: tuple} + stat_length : int, tuple or mapping of hashable to tuple Used in 'maximum', 'mean', 'median', and 'minimum'. Number of values at edge of each axis used to calculate the statistic value. - constant_values : scalar, tuple or mapping of the form {dim: tuple} + constant_values : scalar, tuple or mapping of hashable to tuple Used in 'constant'. The values to set the padded values for each axis. - end_values : scalar, tuple or mapping of the form {dim: tuple} + end_values : scalar, tuple or mapping of hashable to tuple Used in 'linear_ramp'. The values used for the ending value of the linear_ramp and that will form the edge of the padded array. - reflect_type : {'even', 'odd'}, optional - Used in 'reflect', and 'symmetric'. The 'even' style is the + reflect_type : {"even", "odd"}, optional + Used in "reflect", and "symmetric". The "even" style is the default with an unaltered reflection around the edge value. For - the 'odd' style, the extended part of the array is created by + the "odd" style, the extended part of the array is created by subtracting the reflected values from two times the edge value. - **pad_width_kwargs: + **pad_width_kwargs One of pad_width or pad_width_kwargs must be provided. Returns @@ -1229,7 +1322,7 @@ def pad( if isinstance(end_values, dict): end_values = self._pad_options_dim_to_index(end_values) - # workaround for bug in Dask's default value of stat_length https://github.com/dask/dask/issues/5303 + # workaround for bug in Dask's default value of stat_length https://github.com/dask/dask/issues/5303 if stat_length is None and mode in ["maximum", "mean", "median", "minimum"]: stat_length = [(n, n) for n in self.data.shape] # type: ignore @@ -1272,7 +1365,7 @@ def _roll_one_dim(self, dim, count): data = duck_array_ops.concatenate(arrays, axis) - if isinstance(data, dask_array_type): + if is_duck_dask_array(data): # chunked data should come out with the same chunks; this makes # it feasible to combine shifted and unshifted data # TODO: remove this once dask.array automatically aligns chunks @@ -1286,11 +1379,11 @@ def roll(self, shifts=None, **shifts_kwargs): Parameters ---------- - shifts : mapping of the form {dim: offset} + shifts : mapping of hashable to int Integer offset to roll along each of the given dimensions. Positive offsets roll to the right; negative offsets roll to the left. - **shifts_kwargs: + **shifts_kwargs The keyword arguments form of ``shifts``. One of shifts or shifts_kwargs must be provided. @@ -1428,10 +1521,11 @@ def stack(self, dimensions=None, **dimensions_kwargs): Parameters ---------- - dimensions : Mapping of form new_name=(dim1, dim2, ...) - Names of new dimensions, and the existing dimensions that they - replace. 
- **dimensions_kwargs: + dimensions : mapping of hashable to tuple of hashable + Mapping of form new_name=(dim1, dim2, ...) describing the + names of new dimensions, and the existing dimensions that + they replace. + **dimensions_kwargs The keyword arguments form of ``dimensions``. One of dimensions or dimensions_kwargs must be provided. @@ -1488,10 +1582,11 @@ def unstack(self, dimensions=None, **dimensions_kwargs): Parameters ---------- - dimensions : mapping of the form old_dim={dim1: size1, ...} - Names of existing dimensions, and the new dimensions and sizes + dimensions : mapping of hashable to mapping of hashable to int + Mapping of the form old_dim={dim1: size1, ...} describing the + names of existing dimensions, and the new dimensions and sizes that they map to. - **dimensions_kwargs: + **dimensions_kwargs The keyword arguments form of ``dimensions``. One of dimensions or dimensions_kwargs must be provided. @@ -1523,14 +1618,13 @@ def reduce( axis=None, keep_attrs=None, keepdims=False, - allow_lazy=None, **kwargs, ): """Reduce this array by applying `func` along some dimension(s). Parameters ---------- - func : function + func : callable Function which can be called in the form `func(x, axis=axis, **kwargs)` to return the result of reducing an np.ndarray over an integer valued axis. @@ -1545,7 +1639,7 @@ def reduce( If True, the variable's attributes (`attrs`) will be copied from the original object to the new one. If False (default), the new object will be returned without attributes. - keepdims : bool, default False + keepdims : bool, default: False If True, the dimensions which are reduced are left in the result as dimensions of size one **kwargs : dict @@ -1565,20 +1659,14 @@ def reduce( if dim is not None: axis = self.get_axis_num(dim) - if allow_lazy is not None: - warnings.warn( - "allow_lazy is deprecated and will be removed in version 0.16.0. It is now True by default.", - DeprecationWarning, + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", r"Mean of empty slice", category=RuntimeWarning ) - else: - allow_lazy = True - - input_data = self.data if allow_lazy else self.values - - if axis is not None: - data = func(input_data, axis=axis, **kwargs) - else: - data = func(input_data, **kwargs) + if axis is not None: + data = func(self.data, axis=axis, **kwargs) + else: + data = func(self.data, **kwargs) if getattr(data, "shape", ()) == self.shape: dims = self.dims @@ -1615,7 +1703,7 @@ def concat(cls, variables, dim="concat_dim", positions=None, shortcut=False): Parameters ---------- - variables : iterable of Array + variables : iterable of Variable Arrays to stack together. Each variable is expected to have matching dimensions and shape except for along the stacked dimension. @@ -1625,7 +1713,7 @@ def concat(cls, variables, dim="concat_dim", positions=None, shortcut=False): existing dimension name, in which case the location of the dimension is unchanged. Where to insert the new dimension is determined by the first variable. - positions : None or list of integer arrays, optional + positions : None or list of array-like, optional List of integer arrays which specifies the integer positions to which to assign each dataset along the concatenated dimension. If not supplied, objects are concatenated in the provided order. @@ -1707,8 +1795,7 @@ def broadcast_equals(self, other, equiv=duck_array_ops.array_equiv): return self.equals(other, equiv=equiv) def identical(self, other, equiv=duck_array_ops.array_equiv): - """Like equals, but also checks attributes. 
- """ + """Like equals, but also checks attributes.""" try: return utils.dict_equiv(self.attrs, other.attrs) and self.equals( other, equiv=equiv @@ -1734,12 +1821,12 @@ def quantile( Parameters ---------- - q : float in range of [0,1] (or sequence of floats) + q : float or sequence of float Quantile to compute, which must be between 0 and 1 inclusive. dim : str or sequence of str, optional Dimension(s) over which to apply quantile. - interpolation : {'linear', 'lower', 'higher', 'midpoint', 'nearest'} + interpolation : {"linear", "lower", "higher", "midpoint", "nearest"}, default: "linear" This optional parameter specifies the interpolation method to use when the desired quantile lies between two data points ``i < j``: @@ -1800,7 +1887,7 @@ def _wrapper(npa, **kwargs): exclude_dims=set(dim), output_core_dims=[["quantile"]], output_dtypes=[np.float64], - output_sizes={"quantile": len(q)}, + dask_gufunc_kwargs=dict(output_sizes={"quantile": len(q)}), dask="parallelized", kwargs={"q": q, "axis": axis, "interpolation": interpolation}, ) @@ -1843,7 +1930,7 @@ def rank(self, dim, pct=False): data = self.data - if isinstance(data, dask_array_type): + if is_duck_dask_array(data): raise TypeError( "rank does not work for arrays stored as dask " "arrays. Load the data via .compute() or .load() " @@ -1870,16 +1957,19 @@ def rolling_window( Parameters ---------- - dim: str - Dimension over which to compute rolling_window - window: int + dim : str + Dimension over which to compute rolling_window. + For nd-rolling, should be list of dimensions. + window : int Window size of the rolling - window_dim: str + For nd-rolling, should be list of integers. + window_dim : str New name of the window dimension. - center: boolean. default False. + For nd-rolling, should be list of integers. + center : bool, default: False If True, pad fill_value for both ends. Otherwise, pad in the head of the axis. - fill_value: + fill_value value to be filled. 
Returns @@ -1892,15 +1982,29 @@ def rolling_window( Examples -------- >>> v = Variable(("a", "b"), np.arange(8).reshape((2, 4))) - >>> v.rolling_window(x, "b", 3, "window_dim") + >>> v.rolling_window("b", 3, "window_dim") - array([[[nan, nan, 0], [nan, 0, 1], [0, 1, 2], [1, 2, 3]], - [[nan, nan, 4], [nan, 4, 5], [4, 5, 6], [5, 6, 7]]]) - - >>> v.rolling_window(x, "b", 3, "window_dim", center=True) + array([[[nan, nan, 0.], + [nan, 0., 1.], + [ 0., 1., 2.], + [ 1., 2., 3.]], + + [[nan, nan, 4.], + [nan, 4., 5.], + [ 4., 5., 6.], + [ 5., 6., 7.]]]) + + >>> v.rolling_window("b", 3, "window_dim", center=True) - array([[[nan, 0, 1], [0, 1, 2], [1, 2, 3], [2, 3, nan]], - [[nan, 4, 5], [4, 5, 6], [5, 6, 7], [6, 7, nan]]]) + array([[[nan, 0., 1.], + [ 0., 1., 2.], + [ 1., 2., 3.], + [ 2., 3., nan]], + + [[nan, 4., 5.], + [ 4., 5., 6.], + [ 5., 6., 7.], + [ 6., 7., nan]]]) """ if fill_value is dtypes.NA: # np.nan is passed dtype, fill_value = dtypes.maybe_promote(self.dtype) @@ -1909,19 +2013,27 @@ def rolling_window( dtype = self.dtype array = self.data - new_dims = self.dims + (window_dim,) + if isinstance(dim, list): + assert len(dim) == len(window) + assert len(dim) == len(window_dim) + assert len(dim) == len(center) + else: + dim = [dim] + window = [window] + window_dim = [window_dim] + center = [center] + axis = [self.get_axis_num(d) for d in dim] + new_dims = self.dims + tuple(window_dim) return Variable( new_dims, duck_array_ops.rolling_window( - array, - axis=self.get_axis_num(dim), - window=window, - center=center, - fill_value=fill_value, + array, axis=axis, window=window, center=center, fill_value=fill_value ), ) - def coarsen(self, windows, func, boundary="exact", side="left", **kwargs): + def coarsen( + self, windows, func, boundary="exact", side="left", keep_attrs=None, **kwargs + ): """ Apply reduction function. """ @@ -1929,13 +2041,22 @@ def coarsen(self, windows, func, boundary="exact", side="left", **kwargs): if not windows: return self.copy() + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + + if keep_attrs: + _attrs = self.attrs + else: + _attrs = None + reshaped, axes = self._coarsen_reshape(windows, boundary, side) if isinstance(func, str): name = func func = getattr(duck_array_ops, name, None) if func is None: raise NameError(f"{name} is not a valid method.") - return self._replace(data=func(reshaped, axis=axes, **kwargs)) + + return self._replace(data=func(reshaped, axis=axes, **kwargs), attrs=_attrs) def _coarsen_reshape(self, windows, boundary, side): """ @@ -2000,11 +2121,76 @@ def _coarsen_reshape(self, windows, boundary, side): else: shape.append(variable.shape[i]) - keep_attrs = _get_keep_attrs(default=False) - variable.attrs = variable._attrs if keep_attrs else {} - return variable.data.reshape(shape), tuple(axes) + def isnull(self, keep_attrs: bool = None): + """Test each value in the array for whether it is a missing value. + + Returns + ------- + isnull : Variable + Same type and shape as object, but the dtype of the data is bool. 
+ + See Also + -------- + pandas.isnull + + Examples + -------- + >>> var = xr.Variable("x", [1, np.nan, 3]) + >>> var + + array([ 1., nan, 3.]) + >>> var.isnull() + + array([False, True, False]) + """ + from .computation import apply_ufunc + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + + return apply_ufunc( + duck_array_ops.isnull, + self, + dask="allowed", + keep_attrs=keep_attrs, + ) + + def notnull(self, keep_attrs: bool = None): + """Test each value in the array for whether it is not a missing value. + + Returns + ------- + notnull : Variable + Same type and shape as object, but the dtype of the data is bool. + + See Also + -------- + pandas.notnull + + Examples + -------- + >>> var = xr.Variable("x", [1, np.nan, 3]) + >>> var + + array([ 1., nan, 3.]) + >>> var.notnull() + + array([ True, False, True]) + """ + from .computation import apply_ufunc + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + + return apply_ufunc( + duck_array_ops.notnull, + self, + dask="allowed", + keep_attrs=keep_attrs, + ) + @property def real(self): return type(self)(self.dims, self.data.real, self._attrs) @@ -2020,8 +2206,14 @@ def __array_wrap__(self, obj, context=None): def _unary_op(f): @functools.wraps(f) def func(self, *args, **kwargs): + keep_attrs = kwargs.pop("keep_attrs", None) + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=True) with np.errstate(all="ignore"): - return self.__array_wrap__(f(self.data, *args, **kwargs)) + result = self.__array_wrap__(f(self.data, *args, **kwargs)) + if keep_attrs: + result.attrs = self.attrs + return result return func @@ -2053,7 +2245,7 @@ def func(self, other): raise TypeError("cannot add a Dataset to a Variable in-place") self_data, other_data, dims = _broadcast_compat_data(self, other) if dims != self.dims: - raise ValueError("dimensions cannot change for in-place " "operations") + raise ValueError("dimensions cannot change for in-place operations") with np.errstate(all="ignore"): self.values = f(self_data, other_data) return self @@ -2061,7 +2253,7 @@ def func(self, other): return func def _to_numeric(self, offset=None, datetime_unit=None, dtype=float): - """ A (private) method to convert datetime array to numeric dtype + """A (private) method to convert datetime array to numeric dtype See duck_array_ops.datetime_to_numeric """ numeric_array = duck_array_ops.datetime_to_numeric( @@ -2069,6 +2261,166 @@ def _to_numeric(self, offset=None, datetime_unit=None, dtype=float): ) return type(self)(self.dims, numeric_array, self._attrs) + def _unravel_argminmax( + self, + argminmax: str, + dim: Union[Hashable, Sequence[Hashable], None], + axis: Union[int, None], + keep_attrs: Optional[bool], + skipna: Optional[bool], + ) -> Union["Variable", Dict[Hashable, "Variable"]]: + """Apply argmin or argmax over one or more dimensions, returning the result as a + dict of DataArray that can be passed directly to isel. + """ + if dim is None and axis is None: + warnings.warn( + "Behaviour of argmin/argmax with neither dim nor axis argument will " + "change to return a dict of indices of each dimension. 
To get a " + "single, flat index, please use np.argmin(da.data) or " + "np.argmax(da.data) instead of da.argmin() or da.argmax().", + DeprecationWarning, + stacklevel=3, + ) + + argminmax_func = getattr(duck_array_ops, argminmax) + + if dim is ...: + # In future, should do this also when (dim is None and axis is None) + dim = self.dims + if ( + dim is None + or axis is not None + or not isinstance(dim, Sequence) + or isinstance(dim, str) + ): + # Return int index if single dimension is passed, and is not part of a + # sequence + return self.reduce( + argminmax_func, dim=dim, axis=axis, keep_attrs=keep_attrs, skipna=skipna + ) + + # Get a name for the new dimension that does not conflict with any existing + # dimension + newdimname = "_unravel_argminmax_dim_0" + count = 1 + while newdimname in self.dims: + newdimname = f"_unravel_argminmax_dim_{count}" + count += 1 + + stacked = self.stack({newdimname: dim}) + + result_dims = stacked.dims[:-1] + reduce_shape = tuple(self.sizes[d] for d in dim) + + result_flat_indices = stacked.reduce(argminmax_func, axis=-1, skipna=skipna) + + result_unravelled_indices = duck_array_ops.unravel_index( + result_flat_indices.data, reduce_shape + ) + + result = { + d: Variable(dims=result_dims, data=i) + for d, i in zip(dim, result_unravelled_indices) + } + + if keep_attrs is None: + keep_attrs = _get_keep_attrs(default=False) + if keep_attrs: + for v in result.values(): + v.attrs = self.attrs + + return result + + def argmin( + self, + dim: Union[Hashable, Sequence[Hashable]] = None, + axis: int = None, + keep_attrs: bool = None, + skipna: bool = None, + ) -> Union["Variable", Dict[Hashable, "Variable"]]: + """Index or indices of the minimum of the Variable over one or more dimensions. + If a sequence is passed to 'dim', then result returned as dict of Variables, + which can be passed directly to isel(). If a single str is passed to 'dim' then + returns a Variable with dtype int. + + If there are multiple minima, the indices of the first one found will be + returned. + + Parameters + ---------- + dim : hashable, sequence of hashable or ..., optional + The dimensions over which to find the minimum. By default, finds minimum over + all dimensions - for now returning an int for backward compatibility, but + this is deprecated, in future will return a dict with indices for all + dimensions; to return a dict with all dimensions now, pass '...'. + axis : int, optional + Axis over which to apply `argmin`. Only one of the 'dim' and 'axis' arguments + can be supplied. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + + Returns + ------- + result : Variable or dict of Variable + + See also + -------- + DataArray.argmin, DataArray.idxmin + """ + return self._unravel_argminmax("argmin", dim, axis, keep_attrs, skipna) + + def argmax( + self, + dim: Union[Hashable, Sequence[Hashable]] = None, + axis: int = None, + keep_attrs: bool = None, + skipna: bool = None, + ) -> Union["Variable", Dict[Hashable, "Variable"]]: + """Index or indices of the maximum of the Variable over one or more dimensions. 
+ If a sequence is passed to 'dim', then result returned as dict of Variables, + which can be passed directly to isel(). If a single str is passed to 'dim' then + returns a Variable with dtype int. + + If there are multiple maxima, the indices of the first one found will be + returned. + + Parameters + ---------- + dim : hashable, sequence of hashable or ..., optional + The dimensions over which to find the maximum. By default, finds maximum over + all dimensions - for now returning an int for backward compatibility, but + this is deprecated, in future will return a dict with indices for all + dimensions; to return a dict with all dimensions now, pass '...'. + axis : int, optional + Axis over which to apply `argmin`. Only one of the 'dim' and 'axis' arguments + can be supplied. + keep_attrs : bool, optional + If True, the attributes (`attrs`) will be copied from the original + object to the new one. If False (default), the new object will be + returned without attributes. + skipna : bool, optional + If True, skip missing values (as marked by NaN). By default, only + skips missing values for float dtypes; other dtypes either do not + have a sentinel missing value (int) or skipna=True has not been + implemented (object, datetime64 or timedelta64). + + Returns + ------- + result : Variable or dict of Variable + + See also + -------- + DataArray.argmax, DataArray.idxmax + """ + return self._unravel_argminmax("argmax", dim, axis, keep_attrs, skipna) + ops.inject_all_ops_and_reduce_methods(Variable) @@ -2120,7 +2472,7 @@ def values(self, values): f"Please use DataArray.assign_coords, Dataset.assign_coords or Dataset.assign as appropriate." ) - def chunk(self, chunks=None, name=None, lock=False): + def chunk(self, chunks={}, name=None, lock=False): # Dummy - do not chunk. This method is invoked e.g. by Dataset.chunk() return self.copy(deep=False) @@ -2172,6 +2524,9 @@ def concat(cls, variables, dim="concat_dim", positions=None, shortcut=False): indices = nputils.inverse_permutation(np.concatenate(positions)) data = data.take(indices) + # keep as str if possible as pandas.Index uses object (converts to numpy array) + data = maybe_coerce_to_str(data, variables) + attrs = dict(first_var.attrs) if not shortcut: for var in variables: @@ -2356,7 +2711,7 @@ def concat(variables, dim="concat_dim", positions=None, shortcut=False): Parameters ---------- - variables : iterable of Array + variables : iterable of Variable Arrays to stack together. Each variable is expected to have matching dimensions and shape except for along the stacked dimension. @@ -2366,7 +2721,7 @@ def concat(variables, dim="concat_dim", positions=None, shortcut=False): existing dimension name, in which case the location of the dimension is unchanged. Where to insert the new dimension is determined by the first variable. - positions : None or list of integer arrays, optional + positions : None or list of array-like, optional List of integer arrays which specifies the integer positions to which to assign each dataset along the concatenated dimension. If not supplied, objects are concatenated in the provided order. 
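To make the new `dim` semantics of `argmin`/`argmax` concrete, here is a small sketch (the data values are made up): a single dimension yields integer indices, while a sequence yields a dict that `isel` accepts directly, as the docstrings above describe.

```python
import numpy as np
from xarray import Variable

v = Variable(("x", "y"), np.array([[3, 1, 2], [0, 5, 4]]))

# Single dimension: a Variable of integer positions along "y".
print(v.argmin(dim="y"))

# Sequence of dimensions: a dict {dim: indices} that can be passed to isel().
indices = v.argmin(dim=["x", "y"])
print(v.isel(**indices))  # the overall minimum, 0
```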
@@ -2412,7 +2767,7 @@ def assert_unique_multiindex_level_names(variables): duplicate_names = [v for v in level_names.values() if len(v) > 1] if duplicate_names: - conflict_str = "\n".join([", ".join(v) for v in duplicate_names]) + conflict_str = "\n".join(", ".join(v) for v in duplicate_names) raise ValueError("conflicting MultiIndex level name(s):\n%s" % conflict_str) # Check confliction between level names and dimensions GH:2299 for k, v in variables.items(): diff --git a/xarray/core/weighted.py b/xarray/core/weighted.py index 996d2e4c43e..dbd4e1ad103 100644 --- a/xarray/core/weighted.py +++ b/xarray/core/weighted.py @@ -1,7 +1,9 @@ from typing import TYPE_CHECKING, Hashable, Iterable, Optional, Union, overload +from . import duck_array_ops from .computation import dot from .options import _get_keep_attrs +from .pycompat import is_duck_dask_array if TYPE_CHECKING: from .dataarray import DataArray, Dataset @@ -72,11 +74,11 @@ class Weighted: def __init__(self, obj: "DataArray", weights: "DataArray") -> None: ... - @overload # noqa: F811 - def __init__(self, obj: "Dataset", weights: "DataArray") -> None: # noqa: F811 + @overload + def __init__(self, obj: "Dataset", weights: "DataArray") -> None: ... - def __init__(self, obj, weights): # noqa: F811 + def __init__(self, obj, weights): """ Create a Weighted object @@ -100,12 +102,25 @@ def __init__(self, obj, weights): # noqa: F811 if not isinstance(weights, DataArray): raise ValueError("`weights` must be a DataArray") - if weights.isnull().any(): - raise ValueError( - "`weights` cannot contain missing values. " - "Missing values can be replaced by `weights.fillna(0)`." + def _weight_check(w): + # Ref https://github.com/pydata/xarray/pull/4559/files#r515968670 + if duck_array_ops.isnull(w).any(): + raise ValueError( + "`weights` cannot contain missing values. " + "Missing values can be replaced by `weights.fillna(0)`." 
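For dask-backed weights, the continuation of this hunk below wraps `_weight_check` with `map_blocks`, so constructing the `Weighted` object stays lazy and the missing-value error only surfaces at compute time. A short sketch of that behaviour, assuming dask is installed:

```python
import numpy as np
import xarray as xr

da = xr.DataArray([1.0, 2.0, 3.0], dims="x").chunk({"x": 1})
weights = xr.DataArray([1.0, np.nan, 2.0], dims="x").chunk({"x": 1})

weighted = da.weighted(weights)  # no error yet: the NaN check is attached lazily
try:
    weighted.mean("x").compute()
except ValueError as err:
    print(err)  # "`weights` cannot contain missing values. ..."
```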
+ ) + return w + + if is_duck_dask_array(weights.data): + # assign to copy - else the check is not triggered + weights = weights.copy( + data=weights.data.map_blocks(_weight_check, dtype=weights.dtype), + deep=False, ) + else: + _weight_check(weights.data) + self.obj = obj self.weights = weights @@ -118,7 +133,7 @@ def _reduce( ) -> "DataArray": """reduce using dot; equivalent to (da * weights).sum(dim, skipna) - for internal use only + for internal use only """ # need to infer dims as we use `dot` @@ -142,7 +157,14 @@ def _sum_of_weights( # we need to mask data values that are nan; else the weights are wrong mask = da.notnull() - sum_of_weights = self._reduce(mask, self.weights, dim=dim, skipna=False) + # bool -> int, because ``xr.dot([True, True], [True, True])`` -> True + # (and not 2); GH4074 + if self.weights.dtype == bool: + sum_of_weights = self._reduce( + mask, self.weights.astype(int), dim=dim, skipna=False + ) + else: + sum_of_weights = self._reduce(mask, self.weights, dim=dim, skipna=False) # 0-weights are not valid valid_weights = sum_of_weights != 0.0 diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index ea037c1a2c2..6d942e1b0fa 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -38,7 +38,7 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide): if not hue_is_numeric and (hue_style == "continuous"): raise ValueError( - "Cannot create a colorbar for a non numeric" " coordinate: " + hue + f"Cannot create a colorbar for a non numeric coordinate: {hue}" ) if add_guide is None or add_guide is True: @@ -54,9 +54,7 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide): add_colorbar = False if hue_style is not None and hue_style not in ["discrete", "continuous"]: - raise ValueError( - "hue_style must be either None, 'discrete' " "or 'continuous'." - ) + raise ValueError("hue_style must be either None, 'discrete' or 'continuous'.") if hue: hue_label = label_from_attrs(ds[hue]) @@ -131,7 +129,7 @@ def _parse_size(data, norm): elif isinstance(norm, tuple): norm = mpl.colors.Normalize(*norm) elif not isinstance(norm, mpl.colors.Normalize): - err = "``size_norm`` must be None, tuple, " "or Normalize object." + err = "``size_norm`` must be None, tuple, or Normalize object." raise ValueError(err) norm.clip = True @@ -170,14 +168,14 @@ def _dsplot(plotfunc): ---------- ds : Dataset - x, y : string + x, y : str Variable names for x, y axis. hue: str, optional Variable by which to color scattered points hue_style: str, optional Can be either 'discrete' (legend) or 'continuous' (color bar). - markersize: str, optional (scatter only) - Variably by which to vary size of scattered points + markersize: str, optional + scatter only. Variable by which to vary size of scattered points. size_norm: optional Either None or 'Norm' instance to normalize the 'markersize' variable. add_guide: bool, optional @@ -185,13 +183,13 @@ def _dsplot(plotfunc): - for "discrete", build a legend. This is the default for non-numeric `hue` variables. - for "continuous", build a colorbar - row : string, optional + row : str, optional If passed, make row faceted plots on this dimension name - col : string, optional + col : str, optional If passed, make column faceted plots on this dimension name - col_wrap : integer, optional + col_wrap : int, optional Use together with ``col`` to wrap faceted plots - ax : matplotlib axes, optional + ax : matplotlib axes object, optional If None, uses the current axis. Not applicable when using facets. 
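The bool-to-int cast in `_sum_of_weights` above addresses GH4074, where all-boolean weights made the sum of weights collapse to `True`. A quick illustration of the corrected result:

```python
import xarray as xr

da = xr.DataArray([1.0, 2.0, 3.0], dims="x")
weights = xr.DataArray([True, True, False], dims="x")

# The sum of weights is now 2 rather than True, so the weighted mean
# of the first two values is returned.
print(da.weighted(weights).mean("x").item())  # 1.5
```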
subplot_kws : dict, optional Dictionary of keyword arguments for matplotlib subplots. Only applies @@ -205,21 +203,23 @@ def _dsplot(plotfunc): norm : ``matplotlib.colors.Normalize`` instance, optional If the ``norm`` has vmin or vmax specified, the corresponding kwarg must be None. - vmin, vmax : floats, optional + vmin, vmax : float, optional Values to anchor the colormap, otherwise they are inferred from the data and other keyword arguments. When a diverging dataset is inferred, setting one of these values will fix the other by symmetry around ``center``. Setting both values prevents use of a diverging colormap. If discrete levels are provided as an explicit list, both of these values are ignored. - cmap : matplotlib colormap name or object, optional - The mapping from data values to color space. If not provided, this - will be either be ``viridis`` (if the function infers a sequential - dataset) or ``RdBu_r`` (if the function infers a diverging dataset). - When `Seaborn` is installed, ``cmap`` may also be a `seaborn` - color palette. If ``cmap`` is seaborn color palette and the plot type - is not ``contour`` or ``contourf``, ``levels`` must also be specified. - colors : discrete colors to plot, optional + cmap : str or colormap, optional + The mapping from data values to color space. Either a + matplotlib colormap name or object. If not provided, this will + be either ``viridis`` (if the function infers a sequential + dataset) or ``RdBu_r`` (if the function infers a diverging + dataset). When `Seaborn` is installed, ``cmap`` may also be a + `seaborn` color palette. If ``cmap`` is seaborn color palette + and the plot type is not ``contour`` or ``contourf``, ``levels`` + must also be specified. + colors : color-like or list of color-like, optional A single color or a list of colors. If the plot type is not ``contour`` or ``contourf``, the ``levels`` argument is required. center : float, optional @@ -229,7 +229,7 @@ def _dsplot(plotfunc): robust : bool, optional If True and ``vmin`` or ``vmax`` are absent, the colormap range is computed with 2nd and 98th percentiles instead of the extreme values. - extend : {'neither', 'both', 'min', 'max'}, optional + extend : {"neither", "both", "min", "max"}, optional How to draw arrows extending the colorbar beyond its limits. If not provided, extend is inferred from vmin, vmax and the data limits. levels : int or list-like object, optional @@ -291,7 +291,7 @@ def newplotfunc( allargs = locals().copy() allargs["plotfunc"] = globals()[plotfunc.__name__] allargs["data"] = ds - # TODO dcherian: why do I need to remove kwargs? 
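As a usage note for the Dataset plotting parameters documented above, a minimal sketch (requires matplotlib; the variable and coordinate names are illustrative): `hue` names a variable in the Dataset and `hue_style` chooses between a legend ("discrete") and a colorbar ("continuous").

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"a": ("x", np.random.randn(20)), "b": ("x", np.random.randn(20))},
    coords={"w": ("x", np.random.rand(20))},
)

# Color the scattered points by the numeric coordinate "w" and add a colorbar.
ds.plot.scatter(x="a", y="b", hue="w", hue_style="continuous")
```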
+ # remove kwargs to avoid passing the information twice for arg in ["meta_data", "kwargs", "ds"]: del allargs[arg] @@ -337,11 +337,7 @@ def newplotfunc( ax.set_ylabel(meta_data.get("ylabel")) if meta_data["add_legend"]: - ax.legend( - handles=primitive, - labels=list(meta_data["hue"].values), - title=meta_data.get("hue_label", None), - ) + ax.legend(handles=primitive, title=meta_data.get("hue_label", None)) if meta_data["add_colorbar"]: cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs if "label" not in cbar_kwargs: @@ -426,7 +422,10 @@ def scatter(ds, x, y, ax, **kwargs): if hue_style == "discrete": primitive = [] - for label in np.unique(data["hue"].values): + # use pd.unique instead of np.unique because that keeps the order of the labels, + # which is important to keep them in sync with the ones used in + # FacetGrid.add_legend + for label in pd.unique(data["hue"].values.ravel()): mask = data["hue"] == label if data["sizes"] is not None: kwargs.update(s=data["sizes"].where(mask, drop=True).values.flatten()) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 819eded694e..58b38251352 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -131,7 +131,7 @@ def __init__( ncol = len(data[col]) nfacet = nrow * ncol if col_wrap is not None: - warnings.warn("Ignoring col_wrap since both col and row " "were passed") + warnings.warn("Ignoring col_wrap since both col and row were passed") elif row and not col: single_group = row elif not row and col: @@ -306,9 +306,11 @@ def map_dataarray_line( ) self._mappables.append(mappable) - _, _, hueplt, xlabel, ylabel, huelabel = _infer_line_data( + xplt, yplt, hueplt, huelabel = _infer_line_data( darray=self.data.loc[self.name_dicts.flat[0]], x=x, y=y, hue=hue ) + xlabel = label_from_attrs(xplt) + ylabel = label_from_attrs(yplt) self._hue_var = hueplt self._hue_label = huelabel @@ -410,11 +412,13 @@ def add_legend(self, **kwargs): self.fig.subplots_adjust(right=right) def add_colorbar(self, **kwargs): - """Draw a colorbar - """ + """Draw a colorbar""" kwargs = kwargs.copy() if self._cmap_extend is not None: kwargs.setdefault("extend", self._cmap_extend) + # dont pass extend as kwarg if it is in the mappable + if hasattr(self._mappables[-1], "extend"): + kwargs.pop("extend", None) if "label" not in kwargs: kwargs.setdefault("label", label_from_attrs(self.data)) self.cbar = self.fig.colorbar( diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 4657bee9415..8a57e17e5e8 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -14,6 +14,7 @@ from .facetgrid import _easy_facetgrid from .utils import ( _add_colorbar, + _assert_valid_xy, _ensure_plottable, _infer_interval_breaks, _infer_xy_labels, @@ -29,19 +30,17 @@ def _infer_line_data(darray, x, y, hue): - error_msg = "must be either None or one of ({:s})".format( - ", ".join([repr(dd) for dd in darray.dims]) - ) + ndims = len(darray.dims) - if x is not None and x not in darray.dims and x not in darray.coords: - raise ValueError("x " + error_msg) + if x is not None and y is not None: + raise ValueError("Cannot specify both x and y kwargs for line plots.") - if y is not None and y not in darray.dims and y not in darray.coords: - raise ValueError("y " + error_msg) + if x is not None: + _assert_valid_xy(darray, x, "x") - if x is not None and y is not None: - raise ValueError("You cannot specify both x and y kwargs" "for line plots.") + if y is not None: + _assert_valid_xy(darray, y, "y") if ndims == 1: huename = None @@ -63,7 +62,7 @@ def 
_infer_line_data(darray, x, y, hue): else: if x is None and y is None and hue is None: - raise ValueError("For 2D inputs, please" "specify either hue, x or y.") + raise ValueError("For 2D inputs, please specify either hue, x or y.") if y is None: xname, huename = _infer_xy_labels(darray=darray, x=x, y=hue) @@ -108,10 +107,7 @@ def _infer_line_data(darray, x, y, hue): huelabel = label_from_attrs(darray[huename]) hueplt = darray[huename] - xlabel = label_from_attrs(xplt) - ylabel = label_from_attrs(yplt) - - return xplt, yplt, hueplt, xlabel, ylabel, huelabel + return xplt, yplt, hueplt, huelabel def plot( @@ -142,22 +138,21 @@ def plot( Parameters ---------- darray : DataArray - row : string, optional + row : str, optional If passed, make row faceted plots on this dimension name - col : string, optional + col : str, optional If passed, make column faceted plots on this dimension name - hue : string, optional + hue : str, optional If passed, make faceted line plots with hue on this dimension name - col_wrap : integer, optional + col_wrap : int, optional Use together with ``col`` to wrap faceted plots - ax : matplotlib axes, optional + ax : matplotlib.axes.Axes, optional If None, uses the current axis. Not applicable when using facets. - rtol : number, optional + rtol : float, optional Relative tolerance used to determine if the indexes are uniformly spaced. Usually a small positive number. subplot_kws : dict, optional - Dictionary of keyword arguments for matplotlib subplots. Only applies - to FacetGrid plotting. + Dictionary of keyword arguments for matplotlib subplots. **kwargs : optional Additional keyword arguments to matplotlib @@ -178,10 +173,10 @@ def plot( if ndims in [1, 2]: if row or col: + kwargs["subplot_kws"] = subplot_kws kwargs["row"] = row kwargs["col"] = col kwargs["col_wrap"] = col_wrap - kwargs["subplot_kws"] = subplot_kws if ndims == 1: plotfunc = line kwargs["hue"] = hue @@ -191,6 +186,7 @@ def plot( kwargs["hue"] = hue else: plotfunc = pcolormesh + kwargs["subplot_kws"] = subplot_kws else: if row or col or hue: raise ValueError(error_msg) @@ -252,7 +248,7 @@ def line( Dimension or coordinate for which you want multiple lines plotted. If plotting against a 2D coordinate, ``hue`` must be a dimension. x, y : string, optional - Dimensions or coordinates for x, y axis. + Dimension, coordinate or MultiIndex level for x, y axis. Only one of these may be specified. The other coordinate plots values from the DataArray on which this plot method is called. @@ -266,9 +262,9 @@ def line( yincrease : None, True, or False, optional Should the values on the y axes be increasing from top to bottom? if None, use the default for the matplotlib function. - add_legend : boolean, optional + add_legend : bool, optional Add legend with y axis coordinates (2D inputs only). - ``*args``, ``**kwargs`` : optional + *args, **kwargs : optional Additional arguments to matplotlib.pyplot.plot """ # Handle facetgrids first @@ -293,12 +289,14 @@ def line( assert "args" not in kwargs ax = get_axis(figsize, size, aspect, ax) - xplt, yplt, hueplt, xlabel, ylabel, hue_label = _infer_line_data(darray, x, y, hue) + xplt, yplt, hueplt, hue_label = _infer_line_data(darray, x, y, hue) # Remove pd.Intervals if contained in xplt.values and/or yplt.values. 
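The new `_assert_valid_xy` helper used above centralises validation of the `x`/`y` arguments and produces a uniform error listing the valid names. A small sketch, using a deliberately invalid (hypothetical) coordinate name to trigger the error:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(4.0), dims="x", coords={"x": [0, 1, 2, 3]})

try:
    da.plot.line(x="not_a_coord")
except ValueError as err:
    print(err)  # x must be one of None, 'x'
```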
- xplt_val, yplt_val, xlabel, ylabel, kwargs = _resolve_intervals_1dplot( - xplt.values, yplt.values, xlabel, ylabel, kwargs + xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot( + xplt.values, yplt.values, kwargs ) + xlabel = label_from_attrs(xplt, extra=x_suffix) + ylabel = label_from_attrs(yplt, extra=y_suffix) _ensure_plottable(xplt_val, yplt_val) @@ -338,27 +336,27 @@ def step(darray, *args, where="pre", drawstyle=None, ds=None, **kwargs): Parameters ---------- - where : {'pre', 'post', 'mid'}, optional, default 'pre' + where : {"pre", "post", "mid"}, default: "pre" Define where the steps should be placed: - - 'pre': The y value is continued constantly to the left from + - "pre": The y value is continued constantly to the left from every *x* position, i.e. the interval ``(x[i-1], x[i]]`` has the value ``y[i]``. - - 'post': The y value is continued constantly to the right from + - "post": The y value is continued constantly to the right from every *x* position, i.e. the interval ``[x[i], x[i+1])`` has the value ``y[i]``. - - 'mid': Steps occur half-way between the *x* positions. + - "mid": Steps occur half-way between the *x* positions. Note that this parameter is ignored if one coordinate consists of :py:func:`pandas.Interval` values, e.g. as a result of :py:func:`xarray.Dataset.groupby_bins`. In this case, the actual boundaries of the interval are used. - ``*args``, ``**kwargs`` : optional + *args, **kwargs : optional Additional arguments following :py:func:`xarray.plot.line` """ if where not in {"pre", "post", "mid"}: - raise ValueError("'where' argument to step must be " "'pre', 'post' or 'mid'") + raise ValueError("'where' argument to step must be 'pre', 'post' or 'mid'") if ds is not None: if drawstyle is None: @@ -408,7 +406,7 @@ def hist( size : scalar, optional If provided, create a new figure for the plot with the given size. Height (in inches) of each plot. See also: ``aspect``. - ax : matplotlib axes object, optional + ax : matplotlib.axes.Axes, optional Axis on which to plot this figure. By default, use the current axis. Mutually exclusive with ``size`` and ``figsize``. **kwargs : optional @@ -446,6 +444,11 @@ def __init__(self, darray): def __call__(self, **kwargs): return plot(self._da, **kwargs) + # we can't use functools.wraps here since that also modifies the name / qualname + __doc__ = __call__.__doc__ = plot.__doc__ + __call__.__wrapped__ = plot # type: ignore + __call__.__annotations__ = plot.__annotations__ + @functools.wraps(hist) def hist(self, ax=None, **kwargs): return hist(self._da, ax=ax, **kwargs) @@ -459,6 +462,15 @@ def step(self, *args, **kwargs): return step(self._da, *args, **kwargs) +def override_signature(f): + def wrapper(func): + func.__wrapped__ = f + + return func + + return wrapper + + def _plot2d(plotfunc): """ Decorator for common 2d plotting logic @@ -490,7 +502,7 @@ def _plot2d(plotfunc): If passed, make row faceted plots on this dimension name col : string, optional If passed, make column faceted plots on this dimension name - col_wrap : integer, optional + col_wrap : int, optional Use together with ``col`` to wrap faceted plots xscale, yscale : 'linear', 'symlog', 'log', 'logit', optional Specifies scaling for the x- and y-axes respectively @@ -502,9 +514,9 @@ def _plot2d(plotfunc): yincrease : None, True, or False, optional Should the values on the y axes be increasing from top to bottom? if None, use the default for the matplotlib function. 
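As noted in the `step()` docstring above, when a coordinate holds `pandas.Interval` values (e.g. from `groupby_bins`) the actual bin boundaries are drawn and `where` is ignored. A brief sketch of that case, assuming matplotlib is installed:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(10.0), dims="x", coords={"x": np.arange(10)})
binned = da.groupby_bins("x", bins=[0, 3, 6, 9]).mean()

# The "x_bins" coordinate holds pd.Interval objects, so step() plots the
# actual bin boundaries instead of applying the `where` placement.
binned.plot.step()
```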
- add_colorbar : Boolean, optional + add_colorbar : bool, optional Adds colorbar to axis - add_labels : Boolean, optional + add_labels : bool, optional Use xarray metadata to label axes norm : ``matplotlib.colors.Normalize`` instance, optional If the ``norm`` has vmin or vmax specified, the corresponding kwarg @@ -533,7 +545,7 @@ def _plot2d(plotfunc): robust : bool, optional If True and ``vmin`` or ``vmax`` are absent, the colormap range is computed with 2nd and 98th percentiles instead of the extreme values. - extend : {'neither', 'both', 'min', 'max'}, optional + extend : {"neither", "both", "min", "max"}, optional How to draw arrows extending the colorbar beyond its limits. If not provided, extend is inferred from vmin, vmax and the data limits. levels : int or list-like object, optional @@ -549,8 +561,8 @@ def _plot2d(plotfunc): always infer intervals, unless the mesh is irregular and plotted on a map projection. subplot_kws : dict, optional - Dictionary of keyword arguments for matplotlib subplots. Only applies - to FacetGrid plotting. + Dictionary of keyword arguments for matplotlib subplots. Only used + for 2D and FacetGrid plots. cbar_ax : matplotlib Axes, optional Axes in which to draw the colorbar. cbar_kwargs : dict, optional @@ -568,6 +580,16 @@ def _plot2d(plotfunc): # Build on the original docstring plotfunc.__doc__ = f"{plotfunc.__doc__}\n{commondoc}" + # plotfunc and newplotfunc have different signatures: + # - plotfunc: (x, y, z, ax, **kwargs) + # - newplotfunc: (darray, x, y, **kwargs) + # where plotfunc accepts numpy arrays, while newplotfunc accepts a DataArray + # and variable names. newplotfunc also explicitly lists most kwargs, so we + # need to shorten it + def signature(darray, x, y, **kwargs): + pass + + @override_signature(signature) @functools.wraps(plotfunc) def newplotfunc( darray, @@ -716,11 +738,12 @@ def newplotfunc( if "imshow" == plotfunc.__name__ and isinstance(aspect, str): # forbid usage of mpl strings - raise ValueError( - "plt.imshow's `aspect` kwarg is not available " "in xarray" - ) + raise ValueError("plt.imshow's `aspect` kwarg is not available in xarray") + + if subplot_kws is None: + subplot_kws = dict() + ax = get_axis(figsize, size, aspect, ax, **subplot_kws) - ax = get_axis(figsize, size, aspect, ax) primitive = plotfunc( xplt, yplt, @@ -746,7 +769,7 @@ def newplotfunc( elif cbar_ax is not None or cbar_kwargs: # inform the user about keywords which aren't used raise ValueError( - "cbar_ax and cbar_kwargs can't be used with " "add_colorbar=False." + "cbar_ax and cbar_kwargs can't be used with add_colorbar=False." 
) # origin kwarg overrides yincrease @@ -852,7 +875,7 @@ def imshow(x, y, z, ax, **kwargs): if x.ndim != 1 or y.ndim != 1: raise ValueError( - "imshow requires 1D coordinates, try using " "pcolormesh or contour(f)" + "imshow requires 1D coordinates, try using pcolormesh or contour(f)" ) # Centering the pixels- Assumes uniform spacing diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index c3512828888..16c67e154fc 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -268,7 +268,7 @@ def _determine_cmap_params( cmap = OPTIONS["cmap_sequential"] # Handle discrete levels - if levels is not None and norm is None: + if levels is not None: if is_scalar(levels): if user_minmax: levels = np.linspace(vmin, vmax, levels) @@ -291,6 +291,12 @@ def _determine_cmap_params( cmap, newnorm = _build_discrete_cmap(cmap, levels, extend, filled) norm = newnorm if norm is None else norm + # vmin & vmax needs to be None if norm is passed + # TODO: always return a norm with vmin and vmax + if norm is not None: + vmin = None + vmax = None + return dict( vmin=vmin, vmax=vmax, cmap=cmap, extend=extend, levels=levels, norm=norm ) @@ -360,7 +366,9 @@ def _infer_xy_labels(darray, x, y, imshow=False, rgb=None): darray must be a 2 dimensional data array, or 3d for imshow only. """ - assert x is None or x != y + if (x is not None) and (x == y): + raise ValueError("x and y cannot be equal.") + if imshow and darray.ndim == 3: return _infer_xy_labels_3d(darray, x, y, rgb) @@ -369,27 +377,53 @@ def _infer_xy_labels(darray, x, y, imshow=False, rgb=None): raise ValueError("DataArray must be 2d") y, x = darray.dims elif x is None: - if y not in darray.dims and y not in darray.coords: - raise ValueError("y must be a dimension name if x is not supplied") + _assert_valid_xy(darray, y, "y") x = darray.dims[0] if y == darray.dims[1] else darray.dims[1] elif y is None: - if x not in darray.dims and x not in darray.coords: - raise ValueError("x must be a dimension name if y is not supplied") + _assert_valid_xy(darray, x, "x") y = darray.dims[0] if x == darray.dims[1] else darray.dims[1] - elif any(k not in darray.coords and k not in darray.dims for k in (x, y)): - raise ValueError("x and y must be coordinate variables") + else: + _assert_valid_xy(darray, x, "x") + _assert_valid_xy(darray, y, "y") + + if ( + all(k in darray._level_coords for k in (x, y)) + and darray._level_coords[x] == darray._level_coords[y] + ): + raise ValueError("x and y cannot be levels of the same MultiIndex") + return x, y -def get_axis(figsize, size, aspect, ax): - import matplotlib as mpl - import matplotlib.pyplot as plt +def _assert_valid_xy(darray, xy, name): + """ + make sure x and y passed to plotting functions are valid + """ + + # MultiIndex cannot be plotted; no point in allowing them here + multiindex = {darray._level_coords[lc] for lc in darray._level_coords} + + valid_xy = ( + set(darray.dims) | set(darray.coords) | set(darray._level_coords) + ) - multiindex + + if xy not in valid_xy: + valid_xy_str = "', '".join(sorted(valid_xy)) + raise ValueError(f"{name} must be one of None, '{valid_xy_str}'") + + +def get_axis(figsize=None, size=None, aspect=None, ax=None, **kwargs): + try: + import matplotlib as mpl + import matplotlib.pyplot as plt + except ImportError: + raise ImportError("matplotlib is required for plot.utils.get_axis") if figsize is not None: if ax is not None: - raise ValueError("cannot provide both `figsize` and " "`ax` arguments") + raise ValueError("cannot provide both `figsize` and `ax` arguments") if size is not 
None: - raise ValueError("cannot provide both `figsize` and " "`size` arguments") + raise ValueError("cannot provide both `figsize` and `size` arguments") _, ax = plt.subplots(figsize=figsize) elif size is not None: if ax is not None: @@ -402,15 +436,18 @@ def get_axis(figsize, size, aspect, ax): elif aspect is not None: raise ValueError("cannot provide `aspect` argument without `size`") + if kwargs and ax is not None: + raise ValueError("cannot use subplot_kws with existing ax") + if ax is None: - ax = plt.gca() + ax = plt.gca(**kwargs) return ax def label_from_attrs(da, extra=""): - """ Makes informative labels if variable metadata (attrs) follows - CF conventions. """ + """Makes informative labels if variable metadata (attrs) follows + CF conventions.""" if da.attrs.get("long_name"): name = da.attrs["long_name"] @@ -466,26 +503,32 @@ def _interval_to_double_bound_points(xarray, yarray): return xarray, yarray -def _resolve_intervals_1dplot(xval, yval, xlabel, ylabel, kwargs): +def _resolve_intervals_1dplot(xval, yval, kwargs): """ Helper function to replace the values of x and/or y coordinate arrays containing pd.Interval with their mid-points or - for step plots - double points which double the length. """ + x_suffix = "" + y_suffix = "" # Is it a step plot? (see matplotlib.Axes.step) if kwargs.get("drawstyle", "").startswith("steps-"): + remove_drawstyle = False # Convert intervals to double points if _valid_other_type(np.array([xval, yval]), [pd.Interval]): raise TypeError("Can't step plot intervals against intervals.") if _valid_other_type(xval, [pd.Interval]): xval, yval = _interval_to_double_bound_points(xval, yval) + remove_drawstyle = True if _valid_other_type(yval, [pd.Interval]): yval, xval = _interval_to_double_bound_points(yval, xval) + remove_drawstyle = True # Remove steps-* to be sure that matplotlib is not confused - del kwargs["drawstyle"] + if remove_drawstyle: + del kwargs["drawstyle"] # Is it another kind of plot? 
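Because `get_axis` now forwards extra keyword arguments (wired up from `subplot_kws` in the 2D plotting hunk earlier), non-facet plots can request a custom projection when xarray creates the axes itself. A sketch under that assumption, with made-up data and matplotlib installed:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(
    np.random.rand(4, 8),
    dims=("r", "theta"),
    coords={"r": np.linspace(0.1, 1, 4), "theta": np.linspace(0, 2 * np.pi, 8)},
)

# subplot_kws is passed through get_axis(..., **subplot_kws); it cannot be
# combined with an existing `ax`, since the axes must be created here.
da.plot.pcolormesh(x="theta", y="r", subplot_kws=dict(projection="polar"))
```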
else: @@ -493,13 +536,13 @@ def _resolve_intervals_1dplot(xval, yval, xlabel, ylabel, kwargs): # Convert intervals to mid points and adjust labels if _valid_other_type(xval, [pd.Interval]): xval = _interval_to_mid_points(xval) - xlabel += "_center" + x_suffix = "_center" if _valid_other_type(yval, [pd.Interval]): yval = _interval_to_mid_points(yval) - ylabel += "_center" + y_suffix = "_center" # return converted arguments - return xval, yval, xlabel, ylabel, kwargs + return xval, yval, x_suffix, y_suffix, kwargs def _resolve_intervals_2dplot(val, func_name): @@ -589,6 +632,10 @@ def _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params): else: cbar_kwargs.setdefault("cax", cbar_ax) + # dont pass extend as kwarg if it is in the mappable + if hasattr(primitive, "extend"): + cbar_kwargs.pop("extend") + fig = ax.get_figure() cbar = fig.colorbar(primitive, **cbar_kwargs) diff --git a/xarray/static/css/style.css b/xarray/static/css/style.css index 7e382de3b5b..373624b8a9d 100644 --- a/xarray/static/css/style.css +++ b/xarray/static/css/style.css @@ -13,11 +13,29 @@ --xr-background-color-row-odd: var(--jp-layout-color2, #eeeeee); } +html[theme=dark], +body.vscode-dark { + --xr-font-color0: rgba(255, 255, 255, 1); + --xr-font-color2: rgba(255, 255, 255, 0.54); + --xr-font-color3: rgba(255, 255, 255, 0.38); + --xr-border-color: #1F1F1F; + --xr-disabled-color: #515151; + --xr-background-color: #111111; + --xr-background-color-row-even: #111111; + --xr-background-color-row-odd: #313131; +} + .xr-wrap { + display: block; min-width: 300px; max-width: 700px; } +.xr-text-repr-fallback { + /* fallback to plain text repr when CSS is not injected (untrusted notebook) */ + display: none; +} + .xr-header { padding-top: 6px; padding-bottom: 6px; @@ -280,7 +298,8 @@ dl.xr-attrs { grid-template-columns: 125px auto; } -.xr-attrs dt, dd { +.xr-attrs dt, +.xr-attrs dd { padding: 0; margin: 0; float: left; diff --git a/xarray/static/html/icons-svg-inline.html b/xarray/static/html/icons-svg-inline.html index c44f89c4304..b0e837a26cd 100644 --- a/xarray/static/html/icons-svg-inline.html +++ b/xarray/static/html/icons-svg-inline.html @@ -1,13 +1,11 @@ -Show/Hide data repr -Show/Hide attributes diff --git a/xarray/testing.py b/xarray/testing.py index ac189f7e023..ca72a4bee8e 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -1,10 +1,11 @@ """Testing functions exposed to the user API""" +import functools from typing import Hashable, Set, Union import numpy as np import pandas as pd -from xarray.core import duck_array_ops, formatting +from xarray.core import duck_array_ops, formatting, utils from xarray.core.dataarray import DataArray from xarray.core.dataset import Dataset from xarray.core.indexes import default_indexes @@ -13,6 +14,8 @@ __all__ = ( "assert_allclose", "assert_chunks_equal", + "assert_duckarray_equal", + "assert_duckarray_allclose", "assert_equal", "assert_identical", ) @@ -123,31 +126,91 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True): """ __tracebackhide__ = True assert type(a) == type(b) - kwargs = dict(rtol=rtol, atol=atol, decode_bytes=decode_bytes) + + equiv = functools.partial( + _data_allclose_or_equiv, rtol=rtol, atol=atol, decode_bytes=decode_bytes + ) + equiv.__name__ = "allclose" + + def compat_variable(a, b): + a = getattr(a, "variable", a) + b = getattr(b, "variable", b) + + return a.dims == b.dims and (a._data is b._data or equiv(a.data, b.data)) + if isinstance(a, Variable): - assert a.dims == b.dims - allclose = _data_allclose_or_equiv(a.values, 
b.values, **kwargs) - assert allclose, f"{a.values}\n{b.values}" + allclose = compat_variable(a, b) + assert allclose, formatting.diff_array_repr(a, b, compat=equiv) elif isinstance(a, DataArray): - assert_allclose(a.variable, b.variable, **kwargs) - assert set(a.coords) == set(b.coords) - for v in a.coords.variables: - # can't recurse with this function as coord is sometimes a - # DataArray, so call into _data_allclose_or_equiv directly - allclose = _data_allclose_or_equiv( - a.coords[v].values, b.coords[v].values, **kwargs - ) - assert allclose, "{}\n{}".format(a.coords[v].values, b.coords[v].values) + allclose = utils.dict_equiv( + a.coords, b.coords, compat=compat_variable + ) and compat_variable(a.variable, b.variable) + assert allclose, formatting.diff_array_repr(a, b, compat=equiv) elif isinstance(a, Dataset): - assert set(a.data_vars) == set(b.data_vars) - assert set(a.coords) == set(b.coords) - for k in list(a.variables) + list(a.coords): - assert_allclose(a[k], b[k], **kwargs) - + allclose = a._coord_names == b._coord_names and utils.dict_equiv( + a.variables, b.variables, compat=compat_variable + ) + assert allclose, formatting.diff_dataset_repr(a, b, compat=equiv) else: raise TypeError("{} not supported by assertion comparison".format(type(a))) +def _format_message(x, y, err_msg, verbose): + diff = x - y + abs_diff = max(abs(diff)) + rel_diff = "not implemented" + + n_diff = int(np.count_nonzero(diff)) + n_total = diff.size + + fraction = f"{n_diff} / {n_total}" + percentage = float(n_diff / n_total * 100) + + parts = [ + "Arrays are not equal", + err_msg, + f"Mismatched elements: {fraction} ({percentage:.0f}%)", + f"Max absolute difference: {abs_diff}", + f"Max relative difference: {rel_diff}", + ] + if verbose: + parts += [ + f" x: {x!r}", + f" y: {y!r}", + ] + + return "\n".join(parts) + + +def assert_duckarray_allclose( + actual, desired, rtol=1e-07, atol=0, err_msg="", verbose=True +): + """ Like `np.testing.assert_allclose`, but for duckarrays. """ + __tracebackhide__ = True + + allclose = duck_array_ops.allclose_or_equiv(actual, desired, rtol=rtol, atol=atol) + assert allclose, _format_message(actual, desired, err_msg=err_msg, verbose=verbose) + + +def assert_duckarray_equal(x, y, err_msg="", verbose=True): + """ Like `np.testing.assert_array_equal`, but for duckarrays """ + __tracebackhide__ = True + + if not utils.is_duck_array(x) and not utils.is_scalar(x): + x = np.asarray(x) + + if not utils.is_duck_array(y) and not utils.is_scalar(y): + y = np.asarray(y) + + if (utils.is_duck_array(x) and utils.is_scalar(y)) or ( + utils.is_scalar(x) and utils.is_duck_array(y) + ): + equiv = (x == y).all() + else: + equiv = duck_array_ops.array_equiv(x, y) + assert equiv, _format_message(x, y, err_msg=err_msg, verbose=verbose) + + def assert_chunks_equal(a, b): """ Assert that chunksizes along chunked dimensions are equal. @@ -259,7 +322,7 @@ def _assert_dataset_invariants(ds: Dataset): assert isinstance(ds._attrs, (type(None), dict)) -def _assert_internal_invariants(xarray_obj: Union[DataArray, Dataset, Variable],): +def _assert_internal_invariants(xarray_obj: Union[DataArray, Dataset, Variable]): """Validate that an xarray object satisfies its own internal invariants. 
This exists for the benefit of xarray's own test suite, but may be useful diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 40c5cfa267c..7c18f1a8c8a 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -16,6 +16,11 @@ from xarray.core.duck_array_ops import allclose_or_equiv # noqa: F401 from xarray.core.indexing import ExplicitlyIndexed from xarray.core.options import set_options +from xarray.testing import ( # noqa: F401 + assert_chunks_equal, + assert_duckarray_allclose, + assert_duckarray_equal, +) # import mpl and change the backend before other mpl imports try: @@ -73,6 +78,9 @@ def LooseVersion(vstring): has_numbagg, requires_numbagg = _importorskip("numbagg") has_seaborn, requires_seaborn = _importorskip("seaborn") has_sparse, requires_sparse = _importorskip("sparse") +has_cartopy, requires_cartopy = _importorskip("cartopy") +# Need Pint 0.15 for __dask_tokenize__ tests for Quantity wrapped Dask Arrays +has_pint_0_15, requires_pint_0_15 = _importorskip("pint", minversion="0.15") # some special cases has_scipy_or_netCDF4 = has_scipy or has_netCDF4 @@ -88,6 +96,39 @@ def LooseVersion(vstring): dask.config.set(scheduler="single-threaded") + +class CountingScheduler: + """Simple dask scheduler counting the number of computes. + + Reference: https://stackoverflow.com/questions/53289286/""" + + def __init__(self, max_computes=0): + self.total_computes = 0 + self.max_computes = max_computes + + def __call__(self, dsk, keys, **kwargs): + self.total_computes += 1 + if self.total_computes > self.max_computes: + raise RuntimeError( + "Too many computes. Total: %d > max: %d." + % (self.total_computes, self.max_computes) + ) + return dask.get(dsk, keys, **kwargs) + + +@contextmanager +def dummy_context(): + yield None + + +def raise_if_dask_computes(max_computes=0): + # return a dummy context manager so that this can be used for non-dask objects + if not has_dask: + return dummy_context() + scheduler = CountingScheduler(max_computes) + return dask.config.set(scheduler=scheduler) + + flaky = pytest.mark.flaky network = pytest.mark.network diff --git a/xarray/tests/data/example.ict b/xarray/tests/data/example.ict index bc04888fb80..41bbfeb996c 100644 --- a/xarray/tests/data/example.ict +++ b/xarray/tests/data/example.ict @@ -28,4 +28,4 @@ Start_UTC, lat, lon, elev, TEST_ppbv, TESTM_ppbv 43200, 41.00000, -71.00000, 5, 1.2345, 2.220 46800, 42.00000, -72.00000, 15, 2.3456, -9999 50400, 42.00000, -73.00000, 20, 3.4567, -7777 -50400, 42.00000, -74.00000, 25, 4.5678, -8888 \ No newline at end of file +50400, 42.00000, -74.00000, 25, 4.5678, -8888 diff --git a/xarray/tests/test_accessor_dt.py b/xarray/tests/test_accessor_dt.py index b3640722106..984bfc763bc 100644 --- a/xarray/tests/test_accessor_dt.py +++ b/xarray/tests/test_accessor_dt.py @@ -1,3 +1,5 @@ +from distutils.version import LooseVersion + import numpy as np import pandas as pd import pytest @@ -6,13 +8,14 @@ from . 
import ( assert_array_equal, + assert_chunks_equal, assert_equal, assert_identical, + raise_if_dask_computes, raises_regex, requires_cftime, requires_dask, ) -from .test_dask import assert_chunks_equal, raise_if_dask_computes class TestDatetimeAccessor: @@ -66,10 +69,48 @@ def setup(self): ], ) def test_field_access(self, field): + + if LooseVersion(pd.__version__) >= "1.1.0" and field in ["week", "weekofyear"]: + data = self.times.isocalendar()["week"] + else: + data = getattr(self.times, field) + + expected = xr.DataArray(data, name=field, coords=[self.times], dims=["time"]) + + if field in ["week", "weekofyear"]: + with pytest.warns( + FutureWarning, match="dt.weekofyear and dt.week have been deprecated" + ): + actual = getattr(self.data.time.dt, field) + else: + actual = getattr(self.data.time.dt, field) + + assert_equal(expected, actual) + + @pytest.mark.parametrize( + "field, pandas_field", + [ + ("year", "year"), + ("week", "week"), + ("weekday", "day"), + ], + ) + def test_isocalendar(self, field, pandas_field): + + if LooseVersion(pd.__version__) < "1.1.0": + with raises_regex( + AttributeError, "'isocalendar' not available in pandas < 1.1.0" + ): + self.data.time.dt.isocalendar()[field] + return + + # pandas isocalendar has dtypy UInt32Dtype, convert to Int64 + expected = pd.Int64Index(getattr(self.times.isocalendar(), pandas_field)) expected = xr.DataArray( - getattr(self.times, field), name=field, coords=[self.times], dims=["time"] + expected, name=field, coords=[self.times], dims=["time"] ) - actual = getattr(self.data.time.dt, field) + + actual = self.data.time.dt.isocalendar()[field] assert_equal(expected, actual) def test_strftime(self): @@ -84,6 +125,7 @@ def test_not_datetime_type(self): with raises_regex(TypeError, "dt"): nontime_data.time.dt + @pytest.mark.filterwarnings("ignore:dt.weekofyear and dt.week have been deprecated") @requires_dask @pytest.mark.parametrize( "field", @@ -128,6 +170,39 @@ def test_dask_field_access(self, field): assert_chunks_equal(actual, dask_times_2d) assert_equal(actual.compute(), expected.compute()) + @requires_dask + @pytest.mark.parametrize( + "field", + [ + "year", + "week", + "weekday", + ], + ) + def test_isocalendar_dask(self, field): + import dask.array as da + + if LooseVersion(pd.__version__) < "1.1.0": + with raises_regex( + AttributeError, "'isocalendar' not available in pandas < 1.1.0" + ): + self.data.time.dt.isocalendar()[field] + return + + expected = getattr(self.times_data.dt.isocalendar(), field) + + dask_times_arr = da.from_array(self.times_arr, chunks=(5, 5, 50)) + dask_times_2d = xr.DataArray( + dask_times_arr, coords=self.data.coords, dims=self.data.dims, name="data" + ) + + with raise_if_dask_computes(): + actual = dask_times_2d.dt.isocalendar()[field] + + assert isinstance(actual.data, da.Array) + assert_chunks_equal(actual, dask_times_2d) + assert_equal(actual.compute(), expected.compute()) + @requires_dask @pytest.mark.parametrize( "method, parameters", @@ -346,6 +421,15 @@ def test_field_access(data, field): assert_equal(result, expected) +@requires_cftime +def test_isocalendar_cftime(data): + + with raises_regex( + AttributeError, "'CFTimeIndex' object has no attribute 'isocalendar'" + ): + data.time.dt.isocalendar() + + @requires_cftime @pytest.mark.filterwarnings("ignore::RuntimeWarning") def test_cftime_strftime_access(data): diff --git a/xarray/tests/test_accessor_str.py b/xarray/tests/test_accessor_str.py index a987d302202..e0cbdb7377a 100644 --- a/xarray/tests/test_accessor_str.py +++ 
b/xarray/tests/test_accessor_str.py @@ -596,7 +596,7 @@ def test_wrap(): ) # expected values - xp = xr.DataArray( + expected = xr.DataArray( [ "hello world", "hello world!", @@ -610,15 +610,29 @@ def test_wrap(): ] ) - rs = values.str.wrap(12, break_long_words=True) - assert_equal(rs, xp) + result = values.str.wrap(12, break_long_words=True) + assert_equal(result, expected) # test with pre and post whitespace (non-unicode), NaN, and non-ascii # Unicode values = xr.DataArray([" pre ", "\xac\u20ac\U00008000 abadcafe"]) - xp = xr.DataArray([" pre", "\xac\u20ac\U00008000 ab\nadcafe"]) - rs = values.str.wrap(6) - assert_equal(rs, xp) + expected = xr.DataArray([" pre", "\xac\u20ac\U00008000 ab\nadcafe"]) + result = values.str.wrap(6) + assert_equal(result, expected) + + +def test_wrap_kwargs_passed(): + # GH4334 + + values = xr.DataArray(" hello world ") + + result = values.str.wrap(7) + expected = xr.DataArray(" hello\nworld") + assert_equal(result, expected) + + result = values.str.wrap(7, drop_whitespace=False) + expected = xr.DataArray(" hello\n world\n ") + assert_equal(result, expected) def test_get(dtype): @@ -642,6 +656,15 @@ def test_get(dtype): assert_equal(result, expected) +def test_get_default(dtype): + # GH4334 + values = xr.DataArray(["a_b", "c", ""]).astype(dtype) + + result = values.str.get(2, "default") + expected = xr.DataArray(["b", "default", "default"]).astype(dtype) + assert_equal(result, expected) + + def test_encode_decode(): data = xr.DataArray(["a", "b", "a\xe4"]) encoded = data.str.encode("utf-8") diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 3fde292c04f..3750c0715ae 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -30,6 +30,7 @@ save_mfdataset, ) from xarray.backends.common import robust_getitem +from xarray.backends.netcdf3 import _nc3_dtype_coercions from xarray.backends.netCDF4_ import _extract_nc4_variable_encoding from xarray.backends.pydap_ import PydapDataStore from xarray.coding.variables import SerializationWarning @@ -86,6 +87,7 @@ dask_version = "10.0" ON_WINDOWS = sys.platform == "win32" +default_value = object() def open_example_dataset(name, *args, **kwargs): @@ -227,7 +229,27 @@ def __getitem__(self, key): class NetCDF3Only: - pass + netcdf3_formats = ("NETCDF3_CLASSIC", "NETCDF3_64BIT") + + @requires_scipy + def test_dtype_coercion_error(self): + """Failing dtype coercion should lead to an error""" + for dtype, format in itertools.product( + _nc3_dtype_coercions, self.netcdf3_formats + ): + if dtype == "bool": + # coerced upcast (bool to int8) ==> can never fail + continue + + # Using the largest representable value, create some data that will + # no longer compare equal after the coerced downcast + maxval = np.iinfo(dtype).max + x = np.array([0, 1, 2, maxval], dtype=dtype) + ds = Dataset({"x": ("t", x, {})}) + + with create_tmp_file(allow_cleanup_failure=False) as path: + with pytest.raises(ValueError, match="could not safely cast"): + ds.to_netcdf(path, format=format) class DatasetIOBase: @@ -296,9 +318,14 @@ def test_write_store(self): def check_dtypes_roundtripped(self, expected, actual): for k in expected.variables: expected_dtype = expected.variables[k].dtype - if isinstance(self, NetCDF3Only) and expected_dtype == "int64": - # downcast - expected_dtype = np.dtype("int32") + + # For NetCDF3, the backend should perform dtype coercion + if ( + isinstance(self, NetCDF3Only) + and str(expected_dtype) in _nc3_dtype_coercions + ): + expected_dtype = 
np.dtype(_nc3_dtype_coercions[str(expected_dtype)]) + actual_dtype = actual.variables[k].dtype # TODO: check expected behavior for string dtypes more carefully string_kinds = {"O", "S", "U"} @@ -581,10 +608,6 @@ def test_orthogonal_indexing(self): actual = on_disk.isel(**indexers) assert_identical(expected, actual) - @pytest.mark.xfail( - not has_dask, - reason="the code for indexing without dask handles negative steps in slices incorrectly", - ) def test_vectorized_indexing(self): in_memory = create_test_data() with self.roundtrip(in_memory) as on_disk: @@ -649,6 +672,29 @@ def multiple_indexing(indexers): ] multiple_indexing(indexers) + @pytest.mark.xfail( + reason="zarr without dask handles negative steps in slices incorrectly", + ) + def test_vectorized_indexing_negative_step(self): + # use dask explicitly when present + if has_dask: + open_kwargs = {"chunks": {}} + else: + open_kwargs = None + in_memory = create_test_data() + + def multiple_indexing(indexers): + # make sure a sequence of lazy indexings certainly works. + with self.roundtrip(in_memory, open_kwargs=open_kwargs) as on_disk: + actual = on_disk["var3"] + expected = in_memory["var3"] + for ind in indexers: + actual = actual.isel(**ind) + expected = expected.isel(**ind) + # make sure the array is not yet loaded into memory + assert not actual.variable._in_memory + assert_identical(expected, actual.load()) + # with negative step slice. indexers = [ { @@ -858,7 +904,7 @@ def test_roundtrip_endian(self): "x": np.arange(3, 10, dtype=">i2"), "y": np.arange(3, 20, dtype=" 0: + # check for initial spaces + assert s[:len_intro_str] == " " * len_intro_str + + +@requires_cftime_1_1_0 +@pytest.mark.parametrize("periods", [22, 50, 100]) +def test_cftimeindex_repr_101_shorter(periods): + index_101 = xr.cftime_range(start="2000", periods=101) + index_periods = xr.cftime_range(start="2000", periods=periods) + index_101_repr_str = index_101.__repr__() + index_periods_repr_str = index_periods.__repr__() + assert len(index_101_repr_str) < len(index_periods_repr_str) + + @requires_cftime def test_parse_array_of_cftime_strings(): from cftime import DatetimeNoLeap @@ -1046,3 +1185,73 @@ def test_asi8_distant_date(): result = index.asi8 expected = np.array([1000000 * 86400 * 400 * 8000 + 12345 * 1000000 + 123456]) np.testing.assert_array_equal(result, expected) + + +@requires_cftime_1_1_0 +def test_infer_freq_valid_types(): + cf_indx = xr.cftime_range("2000-01-01", periods=3, freq="D") + assert xr.infer_freq(cf_indx) == "D" + assert xr.infer_freq(xr.DataArray(cf_indx)) == "D" + + pd_indx = pd.date_range("2000-01-01", periods=3, freq="D") + assert xr.infer_freq(pd_indx) == "D" + assert xr.infer_freq(xr.DataArray(pd_indx)) == "D" + + pd_td_indx = pd.timedelta_range(start="1D", periods=3, freq="D") + assert xr.infer_freq(pd_td_indx) == "D" + assert xr.infer_freq(xr.DataArray(pd_td_indx)) == "D" + + +@requires_cftime_1_1_0 +def test_infer_freq_invalid_inputs(): + # Non-datetime DataArray + with pytest.raises(ValueError, match="must contain datetime-like objects"): + xr.infer_freq(xr.DataArray([0, 1, 2])) + + indx = xr.cftime_range("1990-02-03", periods=4, freq="MS") + # 2D DataArray + with pytest.raises(ValueError, match="must be 1D"): + xr.infer_freq(xr.DataArray([indx, indx])) + + # CFTimeIndex too short + with pytest.raises(ValueError, match="Need at least 3 dates to infer frequency"): + xr.infer_freq(indx[:2]) + + # Non-monotonic input + assert xr.infer_freq(indx[np.array([0, 2, 1, 3])]) is None + + # Non-unique input + assert 
xr.infer_freq(indx[np.array([0, 1, 1, 2])]) is None + + # No unique frequency (here 1st step is MS, second is 2MS) + assert xr.infer_freq(indx[np.array([0, 1, 3])]) is None + + # Same, but for QS + indx = xr.cftime_range("1990-02-03", periods=4, freq="QS") + assert xr.infer_freq(indx[np.array([0, 1, 3])]) is None + + +@requires_cftime_1_1_0 +@pytest.mark.parametrize( + "freq", + [ + "300AS-JAN", + "A-DEC", + "AS-JUL", + "2AS-FEB", + "Q-NOV", + "3QS-DEC", + "MS", + "4M", + "7D", + "D", + "30H", + "5T", + "40S", + ], +) +@pytest.mark.parametrize("calendar", _CFTIME_CALENDARS) +def test_infer_freq(freq, calendar): + indx = xr.cftime_range("2000-01-01", periods=3, freq=freq, calendar=calendar) + out = xr.infer_freq(indx) + assert out == freq diff --git a/xarray/tests/test_coding.py b/xarray/tests/test_coding.py index 0f191049284..e0df7782aa7 100644 --- a/xarray/tests/test_coding.py +++ b/xarray/tests/test_coding.py @@ -8,7 +8,7 @@ from xarray.coding import variables from xarray.conventions import decode_cf_variable, encode_cf_variable -from . import assert_equal, assert_identical, requires_dask +from . import assert_allclose, assert_equal, assert_identical, requires_dask with suppress(ImportError): import dask.array as da @@ -105,3 +105,15 @@ def test_scaling_converts_to_float32(dtype): roundtripped = coder.decode(encoded) assert_identical(original, roundtripped) assert roundtripped.dtype == np.float32 + + +@pytest.mark.parametrize("scale_factor", (10, [10])) +@pytest.mark.parametrize("add_offset", (0.1, [0.1])) +def test_scaling_offset_as_list(scale_factor, add_offset): + # test for #4631 + encoding = dict(scale_factor=scale_factor, add_offset=add_offset) + original = xr.Variable(("x",), np.arange(10.0), encoding=encoding) + coder = variables.CFScaleOffsetCoder() + encoded = coder.encode(original) + roundtripped = coder.decode(encoded) + assert_allclose(original, roundtripped) diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index 00c34940ce4..dfd558f737e 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -6,7 +6,7 @@ import pytest from pandas.errors import OutOfBoundsDatetime -from xarray import DataArray, Dataset, Variable, coding, decode_cf +from xarray import DataArray, Dataset, Variable, coding, conventions, decode_cf from xarray.coding.times import ( cftime_to_nptime, decode_cf_datetime, @@ -54,6 +54,7 @@ ([[0]], "days since 1000-01-01"), (np.arange(2), "days since 1000-01-01"), (np.arange(0, 100000, 20000), "days since 1900-01-01"), + (np.arange(0, 100000, 20000), "days since 1-01-01"), (17093352.0, "hours since 1-1-1 00:00:0.0"), ([0.5, 1.5], "hours since 1900-01-01T00:00:00"), (0, "milliseconds since 2000-01-01T00:00:00"), @@ -85,6 +86,7 @@ def _all_cftime_date_types(): @requires_cftime +@pytest.mark.filterwarnings("ignore:Ambiguous reference date string") @pytest.mark.parametrize(["num_dates", "units", "calendar"], _CF_DATETIME_TESTS) def test_cf_datetime(num_dates, units, calendar): import cftime @@ -109,20 +111,16 @@ def test_cf_datetime(num_dates, units, calendar): # https://github.com/Unidata/netcdf4-python/issues/355 assert (abs_diff <= np.timedelta64(1, "s")).all() encoded, _, _ = coding.times.encode_cf_datetime(actual, units, calendar) - if "1-1-1" not in units: - # pandas parses this date very strangely, so the original - # units/encoding cannot be preserved in this case: - # (Pdb) pd.to_datetime('1-1-1 00:00:0.0') - # Timestamp('2001-01-01 00:00:00') + + assert_array_equal(num_dates, np.around(encoded, 
1)) + if hasattr(num_dates, "ndim") and num_dates.ndim == 1 and "1000" not in units: + # verify that wrapping with a pandas.Index works + # note that it *does not* currently work to put + # non-datetime64 compatible dates into a pandas.Index + encoded, _, _ = coding.times.encode_cf_datetime( + pd.Index(actual), units, calendar + ) assert_array_equal(num_dates, np.around(encoded, 1)) - if hasattr(num_dates, "ndim") and num_dates.ndim == 1 and "1000" not in units: - # verify that wrapping with a pandas.Index works - # note that it *does not* currently work to even put - # non-datetime64 compatible dates into a pandas.Index - encoded, _, _ = coding.times.encode_cf_datetime( - pd.Index(actual), units, calendar - ) - assert_array_equal(num_dates, np.around(encoded, 1)) @requires_cftime @@ -222,9 +220,10 @@ def test_decode_non_standard_calendar_inside_timestamp_range(calendar): @requires_cftime @pytest.mark.parametrize("calendar", _ALL_CALENDARS) def test_decode_dates_outside_timestamp_range(calendar): - import cftime from datetime import datetime + import cftime + units = "days since 0001-01-01" times = [datetime(1, 4, 1, h) for h in range(1, 5)] time = cftime.date2num(times, units, calendar=calendar) @@ -358,9 +357,10 @@ def test_decode_nonstandard_calendar_multidim_time_inside_timestamp_range(calend @requires_cftime @pytest.mark.parametrize("calendar", _ALL_CALENDARS) def test_decode_multidim_time_outside_timestamp_range(calendar): - import cftime from datetime import datetime + import cftime + units = "days since 0001-01-01" times1 = [datetime(1, 4, day) for day in range(1, 6)] times2 = [datetime(1, 5, day) for day in range(1, 6)] @@ -389,15 +389,15 @@ def test_decode_multidim_time_outside_timestamp_range(calendar): @requires_cftime -@pytest.mark.parametrize("calendar", ["360_day", "all_leap", "366_day"]) -def test_decode_non_standard_calendar_single_element(calendar): +@pytest.mark.parametrize( + ("calendar", "num_time"), + [("360_day", 720058.0), ("all_leap", 732059.0), ("366_day", 732059.0)], +) +def test_decode_non_standard_calendar_single_element(calendar, num_time): import cftime units = "days since 0001-01-01" - dt = cftime.datetime(2001, 2, 29) - - num_time = cftime.date2num(dt, units, calendar) actual = coding.times.decode_cf_datetime(num_time, units, calendar=calendar) expected = np.asarray( @@ -432,6 +432,18 @@ def test_decode_360_day_calendar(): assert_array_equal(actual, expected) +@requires_cftime +def test_decode_abbreviation(): + """Test making sure we properly fall back to cftime on abbreviated units.""" + import cftime + + val = np.array([1586628000000.0]) + units = "msecs since 1970-01-01T00:00:00Z" + actual = coding.times.decode_cf_datetime(val, units) + expected = coding.times.cftime_to_nptime(cftime.num2date(val, units)) + assert_array_equal(actual, expected) + + @arm_xfail @requires_cftime @pytest.mark.parametrize( @@ -467,27 +479,36 @@ def test_decoded_cf_datetime_array_2d(): assert_array_equal(np.asarray(result), expected) +FREQUENCIES_TO_ENCODING_UNITS = { + "N": "nanoseconds", + "U": "microseconds", + "L": "milliseconds", + "S": "seconds", + "T": "minutes", + "H": "hours", + "D": "days", +} + + +@pytest.mark.parametrize(("freq", "units"), FREQUENCIES_TO_ENCODING_UNITS.items()) +def test_infer_datetime_units(freq, units): + dates = pd.date_range("2000", periods=2, freq=freq) + expected = f"{units} since 2000-01-01 00:00:00" + assert expected == coding.times.infer_datetime_units(dates) + + @pytest.mark.parametrize( ["dates", "expected"], [ - 
(pd.date_range("1900-01-01", periods=5), "days since 1900-01-01 00:00:00"), - ( - pd.date_range("1900-01-01 12:00:00", freq="H", periods=2), - "hours since 1900-01-01 12:00:00", - ), ( pd.to_datetime(["1900-01-01", "1900-01-02", "NaT"]), "days since 1900-01-01 00:00:00", ), - ( - pd.to_datetime(["1900-01-01", "1900-01-02T00:00:00.005"]), - "seconds since 1900-01-01 00:00:00", - ), (pd.to_datetime(["NaT", "1900-01-01"]), "days since 1900-01-01 00:00:00"), (pd.to_datetime(["NaT"]), "days since 1970-01-01 00:00:00"), ], ) -def test_infer_datetime_units(dates, expected): +def test_infer_datetime_units_with_NaT(dates, expected): assert expected == coding.times.infer_datetime_units(dates) @@ -523,6 +544,7 @@ def test_infer_cftime_datetime_units(calendar, date_args, expected): ("1h", "hours", np.int64(1)), ("1ms", "milliseconds", np.int64(1)), ("1us", "microseconds", np.int64(1)), + ("1ns", "nanoseconds", np.int64(1)), (["NaT", "0s", "1s"], None, [np.nan, 0, 1]), (["30m", "60m"], "hours", [0.5, 1.0]), ("NaT", "days", np.nan), @@ -914,3 +936,62 @@ def test_use_cftime_false_non_standard_calendar(calendar, units_year): units = f"days since {units_year}-01-01" with pytest.raises(OutOfBoundsDatetime): decode_cf_datetime(numerical_dates, units, calendar, use_cftime=False) + + +@requires_cftime +@pytest.mark.parametrize("calendar", _ALL_CALENDARS) +def test_decode_ambiguous_time_warns(calendar): + # GH 4422, 4506 + from cftime import num2date + + # we don't decode non-standard calendards with + # pandas so expect no warning will be emitted + is_standard_calendar = calendar in coding.times._STANDARD_CALENDARS + + dates = [1, 2, 3] + units = "days since 1-1-1" + expected = num2date(dates, units, calendar=calendar, only_use_cftime_datetimes=True) + + exp_warn_type = SerializationWarning if is_standard_calendar else None + + with pytest.warns(exp_warn_type) as record: + result = decode_cf_datetime(dates, units, calendar=calendar) + + if is_standard_calendar: + relevant_warnings = [ + r + for r in record.list + if str(r.message).startswith("Ambiguous reference date string: 1-1-1") + ] + assert len(relevant_warnings) == 1 + else: + assert not record + + np.testing.assert_array_equal(result, expected) + + +@pytest.mark.parametrize("encoding_units", FREQUENCIES_TO_ENCODING_UNITS.values()) +@pytest.mark.parametrize("freq", FREQUENCIES_TO_ENCODING_UNITS.keys()) +def test_encode_cf_datetime_defaults_to_correct_dtype(encoding_units, freq): + times = pd.date_range("2000", periods=3, freq=freq) + units = f"{encoding_units} since 2000-01-01" + encoded, _, _ = coding.times.encode_cf_datetime(times, units) + + numpy_timeunit = coding.times._netcdf_to_numpy_timeunit(encoding_units) + encoding_units_as_timedelta = np.timedelta64(1, numpy_timeunit) + if pd.to_timedelta(1, freq) >= encoding_units_as_timedelta: + assert encoded.dtype == np.int64 + else: + assert encoded.dtype == np.float64 + + +@pytest.mark.parametrize("freq", FREQUENCIES_TO_ENCODING_UNITS.keys()) +def test_encode_decode_roundtrip(freq): + # See GH 4045. Prior to GH 4684 this test would fail for frequencies of + # "S", "L", "U", and "N". 
+ initial_time = pd.date_range("1678-01-01", periods=1) + times = initial_time.append(pd.date_range("1968", periods=2, freq=freq)) + variable = Variable(["time"], times) + encoded = conventions.encode_cf_variable(variable) + decoded = conventions.decode_cf_variable("time", encoded) + assert_equal(variable, decoded) diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py index c3f981f10d1..109b78f05a9 100644 --- a/xarray/tests/test_combine.py +++ b/xarray/tests/test_combine.py @@ -4,14 +4,7 @@ import numpy as np import pytest -from xarray import ( - DataArray, - Dataset, - auto_combine, - combine_by_coords, - combine_nested, - concat, -) +from xarray import DataArray, Dataset, combine_by_coords, combine_nested, concat from xarray.core import dtypes from xarray.core.combine import ( _check_shape_tile_ids, @@ -176,7 +169,7 @@ def test_coord_not_monotonic(self): ds1 = Dataset({"x": [3, 2]}) with raises_regex( ValueError, - "Coordinate variable x is neither " "monotonically increasing nor", + "Coordinate variable x is neither monotonically increasing nor", ): _infer_concat_order_from_coords([ds1, ds0]) @@ -489,7 +482,7 @@ def test_concat_one_dim_merge_another(self): expected = data[["var1", "var2"]] actual = combine_nested(objs, concat_dim=[None, "dim2"]) - assert expected.identical(actual) + assert_identical(expected, actual) def test_auto_combine_2d(self): ds = create_test_data @@ -563,11 +556,11 @@ def test_invalid_hypercube_input(self): ds = create_test_data datasets = [[ds(0), ds(1), ds(2)], [ds(3), ds(4)]] - with raises_regex(ValueError, "sub-lists do not have " "consistent lengths"): + with raises_regex(ValueError, "sub-lists do not have consistent lengths"): combine_nested(datasets, concat_dim=["dim1", "dim2"]) datasets = [[ds(0), ds(1)], [[ds(3), ds(4)]]] - with raises_regex(ValueError, "sub-lists do not have " "consistent depths"): + with raises_regex(ValueError, "sub-lists do not have consistent depths"): combine_nested(datasets, concat_dim=["dim1", "dim2"]) datasets = [[ds(0), ds(1)], [ds(3), ds(4)]] @@ -608,18 +601,26 @@ def test_combine_concat_over_redundant_nesting(self): expected = Dataset({"x": [0]}) assert_identical(expected, actual) - @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0]) + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0, {"a": 2, "b": 1}]) def test_combine_nested_fill_value(self, fill_value): datasets = [ - Dataset({"a": ("x", [2, 3]), "x": [1, 2]}), - Dataset({"a": ("x", [1, 2]), "x": [0, 1]}), + Dataset({"a": ("x", [2, 3]), "b": ("x", [-2, 1]), "x": [1, 2]}), + Dataset({"a": ("x", [1, 2]), "b": ("x", [3, -1]), "x": [0, 1]}), ] if fill_value == dtypes.NA: # if we supply the default, we expect the missing value for a # float array - fill_value = np.nan + fill_value_a = fill_value_b = np.nan + elif isinstance(fill_value, dict): + fill_value_a = fill_value["a"] + fill_value_b = fill_value["b"] + else: + fill_value_a = fill_value_b = fill_value expected = Dataset( - {"a": (("t", "x"), [[fill_value, 2, 3], [1, 2, fill_value]])}, + { + "a": (("t", "x"), [[fill_value_a, 2, 3], [1, 2, fill_value_a]]), + "b": (("t", "x"), [[fill_value_b, -2, 1], [3, -1, fill_value_b]]), + }, {"x": [0, 1, 2]}, ) actual = combine_nested(datasets, concat_dim="t", fill_value=fill_value) @@ -797,7 +798,7 @@ def test_check_for_impossible_ordering(self): ds0 = Dataset({"x": [0, 1, 5]}) ds1 = Dataset({"x": [2, 3]}) with raises_regex( - ValueError, "does not have monotonic global indexes" " along dimension x" + ValueError, "does not have monotonic global 
indexes along dimension x" ): combine_by_coords([ds1, ds0]) @@ -818,173 +819,6 @@ def test_combine_by_coords_incomplete_hypercube(self): combine_by_coords([x1, x2, x3], fill_value=None) -@pytest.mark.filterwarnings( - "ignore:In xarray version 0.15 `auto_combine` " "will be deprecated" -) -@pytest.mark.filterwarnings("ignore:Also `open_mfdataset` will no longer") -@pytest.mark.filterwarnings("ignore:The datasets supplied") -class TestAutoCombineOldAPI: - """ - Set of tests which check that old 1-dimensional auto_combine behaviour is - still satisfied. #2616 - """ - - def test_auto_combine(self): - objs = [Dataset({"x": [0]}), Dataset({"x": [1]})] - actual = auto_combine(objs) - expected = Dataset({"x": [0, 1]}) - assert_identical(expected, actual) - - actual = auto_combine([actual]) - assert_identical(expected, actual) - - objs = [Dataset({"x": [0, 1]}), Dataset({"x": [2]})] - actual = auto_combine(objs) - expected = Dataset({"x": [0, 1, 2]}) - assert_identical(expected, actual) - - # ensure auto_combine handles non-sorted variables - objs = [ - Dataset({"x": ("a", [0]), "y": ("a", [0])}), - Dataset({"y": ("a", [1]), "x": ("a", [1])}), - ] - actual = auto_combine(objs) - expected = Dataset({"x": ("a", [0, 1]), "y": ("a", [0, 1])}) - assert_identical(expected, actual) - - objs = [Dataset({"x": [0], "y": [0]}), Dataset({"y": [1], "x": [1]})] - with raises_regex(ValueError, "too many .* dimensions"): - auto_combine(objs) - - objs = [Dataset({"x": 0}), Dataset({"x": 1})] - with raises_regex(ValueError, "cannot infer dimension"): - auto_combine(objs) - - objs = [Dataset({"x": [0], "y": [0]}), Dataset({"x": [0]})] - with raises_regex(ValueError, "'y' is not present in all datasets"): - auto_combine(objs) - - def test_auto_combine_previously_failed(self): - # In the above scenario, one file is missing, containing the data for - # one year's data for one variable. - datasets = [ - Dataset({"a": ("x", [0]), "x": [0]}), - Dataset({"b": ("x", [0]), "x": [0]}), - Dataset({"a": ("x", [1]), "x": [1]}), - ] - expected = Dataset({"a": ("x", [0, 1]), "b": ("x", [0, np.nan])}, {"x": [0, 1]}) - actual = auto_combine(datasets) - assert_identical(expected, actual) - - # Your data includes "time" and "station" dimensions, and each year's - # data has a different set of stations. 
- datasets = [ - Dataset({"a": ("x", [2, 3]), "x": [1, 2]}), - Dataset({"a": ("x", [1, 2]), "x": [0, 1]}), - ] - expected = Dataset( - {"a": (("t", "x"), [[np.nan, 2, 3], [1, 2, np.nan]])}, {"x": [0, 1, 2]} - ) - actual = auto_combine(datasets, concat_dim="t") - assert_identical(expected, actual) - - def test_auto_combine_with_new_variables(self): - datasets = [Dataset({"x": 0}, {"y": 0}), Dataset({"x": 1}, {"y": 1, "z": 1})] - actual = auto_combine(datasets, "y") - expected = Dataset({"x": ("y", [0, 1])}, {"y": [0, 1], "z": 1}) - assert_identical(expected, actual) - - def test_auto_combine_no_concat(self): - objs = [Dataset({"x": 0}), Dataset({"y": 1})] - actual = auto_combine(objs) - expected = Dataset({"x": 0, "y": 1}) - assert_identical(expected, actual) - - objs = [Dataset({"x": 0, "y": 1}), Dataset({"y": np.nan, "z": 2})] - actual = auto_combine(objs) - expected = Dataset({"x": 0, "y": 1, "z": 2}) - assert_identical(expected, actual) - - data = Dataset({"x": 0}) - actual = auto_combine([data, data, data], concat_dim=None) - assert_identical(data, actual) - - # Single object, with a concat_dim explicitly provided - # Test the issue reported in GH #1988 - objs = [Dataset({"x": 0, "y": 1})] - dim = DataArray([100], name="baz", dims="baz") - actual = auto_combine(objs, concat_dim=dim) - expected = Dataset({"x": ("baz", [0]), "y": ("baz", [1])}, {"baz": [100]}) - assert_identical(expected, actual) - - # Just making sure that auto_combine is doing what is - # expected for non-scalar values, too. - objs = [Dataset({"x": ("z", [0, 1]), "y": ("z", [1, 2])})] - dim = DataArray([100], name="baz", dims="baz") - actual = auto_combine(objs, concat_dim=dim) - expected = Dataset( - {"x": (("baz", "z"), [[0, 1]]), "y": (("baz", "z"), [[1, 2]])}, - {"baz": [100]}, - ) - assert_identical(expected, actual) - - def test_auto_combine_order_by_appearance_not_coords(self): - objs = [ - Dataset({"foo": ("x", [0])}, coords={"x": ("x", [1])}), - Dataset({"foo": ("x", [1])}, coords={"x": ("x", [0])}), - ] - actual = auto_combine(objs) - expected = Dataset({"foo": ("x", [0, 1])}, coords={"x": ("x", [1, 0])}) - assert_identical(expected, actual) - - @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0]) - def test_auto_combine_fill_value(self, fill_value): - datasets = [ - Dataset({"a": ("x", [2, 3]), "x": [1, 2]}), - Dataset({"a": ("x", [1, 2]), "x": [0, 1]}), - ] - if fill_value == dtypes.NA: - # if we supply the default, we expect the missing value for a - # float array - fill_value = np.nan - expected = Dataset( - {"a": (("t", "x"), [[fill_value, 2, 3], [1, 2, fill_value]])}, - {"x": [0, 1, 2]}, - ) - actual = auto_combine(datasets, concat_dim="t", fill_value=fill_value) - assert_identical(expected, actual) - - -class TestAutoCombineDeprecation: - """ - Set of tests to check that FutureWarnings are correctly raised until the - deprecation cycle is complete. 
#2616 - """ - - def test_auto_combine_with_concat_dim(self): - objs = [Dataset({"x": [0]}), Dataset({"x": [1]})] - with pytest.warns(FutureWarning, match="`concat_dim`"): - auto_combine(objs, concat_dim="x") - - def test_auto_combine_with_merge_and_concat(self): - objs = [Dataset({"x": [0]}), Dataset({"x": [1]}), Dataset({"z": ((), 99)})] - with pytest.warns(FutureWarning, match="require both concatenation"): - auto_combine(objs) - - def test_auto_combine_with_coords(self): - objs = [ - Dataset({"foo": ("x", [0])}, coords={"x": ("x", [0])}), - Dataset({"foo": ("x", [1])}, coords={"x": ("x", [1])}), - ] - with pytest.warns(FutureWarning, match="supplied have global"): - auto_combine(objs) - - def test_auto_combine_without_coords(self): - objs = [Dataset({"foo": ("x", [0])}), Dataset({"foo": ("x", [1])})] - with pytest.warns(FutureWarning, match="supplied do not have global"): - auto_combine(objs) - - @requires_cftime def test_combine_by_coords_distant_cftime_dates(): # Regression test for https://github.com/pydata/xarray/issues/3535 @@ -1005,3 +839,20 @@ def test_combine_by_coords_distant_cftime_dates(): [0, 1, 2], dims=["time"], coords=[expected_time], name="a" ).to_dataset() assert_identical(result, expected) + + +@requires_cftime +def test_combine_by_coords_raises_for_differing_calendars(): + # previously failed with uninformative StopIteration instead of TypeError + # https://github.com/pydata/xarray/issues/4495 + + import cftime + + time_1 = [cftime.DatetimeGregorian(2000, 1, 1)] + time_2 = [cftime.DatetimeProlepticGregorian(2001, 1, 1)] + + da_1 = DataArray([0], dims=["time"], coords=[time_1], name="a").to_dataset() + da_2 = DataArray([1], dims=["time"], coords=[time_2], name="a").to_dataset() + + with raises_regex(TypeError, r"cannot compare .* \(different calendars\)"): + combine_by_coords([da_1, da_2]) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 4eed464d2dc..4890536a5d7 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1,13 +1,15 @@ import functools import operator import pickle +from distutils.version import LooseVersion import numpy as np import pandas as pd import pytest -from numpy.testing import assert_array_equal +from numpy.testing import assert_allclose, assert_array_equal import xarray as xr +from xarray.core.alignment import broadcast from xarray.core.computation import ( _UFuncSignature, apply_ufunc, @@ -22,11 +24,15 @@ from . 
import has_dask, raises_regex, requires_dask +dask = pytest.importorskip("dask") + def assert_identical(a, b): + """ A version of this function which accepts numpy arrays """ + from xarray.testing import assert_identical as assert_identical_ + if hasattr(a, "identical"): - msg = f"not identical:\n{a!r}\n{b!r}" - assert a.identical(b), msg + assert_identical_(a, b) else: assert_array_equal(a, b) @@ -41,6 +47,9 @@ def test_signature_properties(): assert sig.num_outputs == 1 assert str(sig) == "(x),(x,y)->(z)" assert sig.to_gufunc_string() == "(dim0),(dim0,dim1)->(dim2)" + assert ( + sig.to_gufunc_string(exclude_dims=set("x")) == "(dim0_0),(dim0_1,dim1)->(dim2)" + ) # dimension names matter assert _UFuncSignature([["x"]]) != _UFuncSignature([["y"]]) @@ -244,6 +253,21 @@ def func(x): assert_identical(out1, dataset) +@requires_dask +def test_apply_dask_parallelized_two_outputs(): + data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) + + def twice(obj): + def func(x): + return (x, x) + + return apply_ufunc(func, obj, output_core_dims=[[], []], dask="parallelized") + + out0, out1 = twice(data_array.chunk({"x": 1})) + assert_identical(data_array, out0) + assert_identical(data_array, out1) + + def test_apply_input_core_dimension(): def first_element(obj, dim): def func(x): @@ -451,10 +475,13 @@ def test_unified_dim_sizes(): "x": 1, "y": 2, } - assert unified_dim_sizes( - [xr.Variable(("x", "z"), [[1]]), xr.Variable(("y", "z"), [[1, 2], [3, 4]])], - exclude_dims={"z"}, - ) == {"x": 1, "y": 2} + assert ( + unified_dim_sizes( + [xr.Variable(("x", "z"), [[1]]), xr.Variable(("y", "z"), [[1, 2], [3, 4]])], + exclude_dims={"z"}, + ) + == {"x": 1, "y": 2} + ) # duplicate dimensions with pytest.raises(ValueError): @@ -680,29 +707,11 @@ def test_apply_dask_parallelized_errors(): array = da.ones((2, 2), chunks=(1, 1)) data_array = xr.DataArray(array, dims=("x", "y")) - with pytest.raises(NotImplementedError): - apply_ufunc( - identity, data_array, output_core_dims=[["z"], ["z"]], dask="parallelized" - ) - with raises_regex(ValueError, "dtypes"): - apply_ufunc(identity, data_array, dask="parallelized") - with raises_regex(TypeError, "list"): - apply_ufunc(identity, data_array, dask="parallelized", output_dtypes=float) - with raises_regex(ValueError, "must have the same length"): - apply_ufunc( - identity, data_array, dask="parallelized", output_dtypes=[float, float] - ) - with raises_regex(ValueError, "output_sizes"): - apply_ufunc( - identity, - data_array, - output_core_dims=[["z"]], - output_dtypes=[float], - dask="parallelized", - ) + # from apply_array_ufunc with raises_regex(ValueError, "at least one input is an xarray object"): apply_ufunc(identity, array, dask="parallelized") + # formerly from _apply_blockwise, now from apply_variable_ufunc with raises_regex(ValueError, "consists of multiple chunks"): apply_ufunc( identity, @@ -777,7 +786,7 @@ def func(x): output_core_dims=[["sign"]], dask="parallelized", output_dtypes=[obj.dtype], - output_sizes={"sign": 2}, + dask_gufunc_kwargs=dict(output_sizes={"sign": 2}), ) expected = stack_negative(data_array.compute()) @@ -789,6 +798,32 @@ def func(x): assert_identical(expected, actual) +@requires_dask +def test_apply_dask_new_output_sizes(): + ds = xr.Dataset({"foo": (["lon", "lat"], np.arange(10 * 10).reshape((10, 10)))}) + ds["bar"] = ds["foo"] + newdims = {"lon_new": 3, "lat_new": 6} + + def extract(obj): + def func(da): + return da[1:4, 1:7] + + return apply_ufunc( + func, + obj, + dask="parallelized", + input_core_dims=[["lon", "lat"]], + 
output_core_dims=[["lon_new", "lat_new"]], + dask_gufunc_kwargs=dict(output_sizes=newdims), + ) + + expected = extract(ds) + + actual = extract(ds.chunk()) + assert actual.dims == {"lon_new": 3, "lat_new": 6} + assert_identical(expected.chunk(), actual) + + def pandas_median(x): return pd.Series(x).median() @@ -804,6 +839,7 @@ def test_vectorize(): @requires_dask def test_vectorize_dask(): + # run vectorization in dask.array.gufunc by using `dask='parallelized'` data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) expected = xr.DataArray([1, 2], dims=["x"]) actual = apply_ufunc( @@ -817,9 +853,299 @@ def test_vectorize_dask(): assert_identical(expected, actual) +@requires_dask +def test_vectorize_dask_dtype(): + # ensure output_dtypes is preserved with vectorize=True + # GH4015 + + # integer + data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) + expected = xr.DataArray([1, 2], dims=["x"]) + actual = apply_ufunc( + pandas_median, + data_array.chunk({"x": 1}), + input_core_dims=[["y"]], + vectorize=True, + dask="parallelized", + output_dtypes=[int], + ) + assert_identical(expected, actual) + assert expected.dtype == actual.dtype + + # complex + data_array = xr.DataArray([[0 + 0j, 1 + 2j, 2 + 1j]], dims=("x", "y")) + expected = data_array.copy() + actual = apply_ufunc( + identity, + data_array.chunk({"x": 1}), + vectorize=True, + dask="parallelized", + output_dtypes=[complex], + ) + assert_identical(expected, actual) + assert expected.dtype == actual.dtype + + +@requires_dask +@pytest.mark.parametrize( + "data_array", + [ + xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")), + xr.DataArray([[0 + 0j, 1 + 2j, 2 + 1j]], dims=("x", "y")), + ], +) +def test_vectorize_dask_dtype_without_output_dtypes(data_array): + # ensure output_dtypes is preserved with vectorize=True + # GH4015 + + expected = data_array.copy() + actual = apply_ufunc( + identity, + data_array.chunk({"x": 1}), + vectorize=True, + dask="parallelized", + ) + + assert_identical(expected, actual) + assert expected.dtype == actual.dtype + + +@pytest.mark.xfail(LooseVersion(dask.__version__) < "2.3", reason="dask GH5274") +@requires_dask +def test_vectorize_dask_dtype_meta(): + # meta dtype takes precedence + data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) + expected = xr.DataArray([1, 2], dims=["x"]) + + actual = apply_ufunc( + pandas_median, + data_array.chunk({"x": 1}), + input_core_dims=[["y"]], + vectorize=True, + dask="parallelized", + output_dtypes=[int], + dask_gufunc_kwargs=dict(meta=np.ndarray((0, 0), dtype=float)), + ) + + assert_identical(expected, actual) + assert float == actual.dtype + + +def pandas_median_add(x, y): + # function which can consume input of unequal length + return pd.Series(x).median() + pd.Series(y).median() + + +def test_vectorize_exclude_dims(): + # GH 3890 + data_array_a = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) + data_array_b = xr.DataArray([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]], dims=("x", "y")) + + expected = xr.DataArray([3, 5], dims=["x"]) + actual = apply_ufunc( + pandas_median_add, + data_array_a, + data_array_b, + input_core_dims=[["y"], ["y"]], + vectorize=True, + exclude_dims=set("y"), + ) + assert_identical(expected, actual) + + +@requires_dask +def test_vectorize_exclude_dims_dask(): + # GH 3890 + data_array_a = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) + data_array_b = xr.DataArray([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]], dims=("x", "y")) + + expected = xr.DataArray([3, 5], dims=["x"]) + actual = apply_ufunc( + 
pandas_median_add, + data_array_a.chunk({"x": 1}), + data_array_b.chunk({"x": 1}), + input_core_dims=[["y"], ["y"]], + exclude_dims=set("y"), + vectorize=True, + dask="parallelized", + output_dtypes=[float], + ) + assert_identical(expected, actual) + + +def test_corr_only_dataarray(): + with pytest.raises(TypeError, match="Only xr.DataArray is supported"): + xr.corr(xr.Dataset(), xr.Dataset()) + + +def arrays_w_tuples(): + da = xr.DataArray( + np.random.random((3, 21, 4)), + coords={"time": pd.date_range("2000-01-01", freq="1D", periods=21)}, + dims=("a", "time", "x"), + ) + + arrays = [ + da.isel(time=range(0, 18)), + da.isel(time=range(2, 20)).rolling(time=3, center=True).mean(), + xr.DataArray([[1, 2], [1, np.nan]], dims=["x", "time"]), + xr.DataArray([[1, 2], [np.nan, np.nan]], dims=["x", "time"]), + ] + + array_tuples = [ + (arrays[0], arrays[0]), + (arrays[0], arrays[1]), + (arrays[1], arrays[1]), + (arrays[2], arrays[2]), + (arrays[2], arrays[3]), + (arrays[3], arrays[3]), + ] + + return arrays, array_tuples + + +@pytest.mark.parametrize("ddof", [0, 1]) +@pytest.mark.parametrize( + "da_a, da_b", + [arrays_w_tuples()[1][0], arrays_w_tuples()[1][1], arrays_w_tuples()[1][2]], +) +@pytest.mark.parametrize("dim", [None, "time"]) +def test_cov(da_a, da_b, dim, ddof): + if dim is not None: + + def np_cov_ind(ts1, ts2, a, x): + # Ensure the ts are aligned and missing values ignored + ts1, ts2 = broadcast(ts1, ts2) + valid_values = ts1.notnull() & ts2.notnull() + + # While dropping isn't ideal here, numpy will return nan + # if any segment contains a NaN. + ts1 = ts1.where(valid_values) + ts2 = ts2.where(valid_values) + + return np.ma.cov( + np.ma.masked_invalid(ts1.sel(a=a, x=x).data.flatten()), + np.ma.masked_invalid(ts2.sel(a=a, x=x).data.flatten()), + ddof=ddof, + )[0, 1] + + expected = np.zeros((3, 4)) + for a in [0, 1, 2]: + for x in [0, 1, 2, 3]: + expected[a, x] = np_cov_ind(da_a, da_b, a=a, x=x) + actual = xr.cov(da_a, da_b, dim=dim, ddof=ddof) + assert_allclose(actual, expected) + + else: + + def np_cov(ts1, ts2): + # Ensure the ts are aligned and missing values ignored + ts1, ts2 = broadcast(ts1, ts2) + valid_values = ts1.notnull() & ts2.notnull() + + ts1 = ts1.where(valid_values) + ts2 = ts2.where(valid_values) + + return np.ma.cov( + np.ma.masked_invalid(ts1.data.flatten()), + np.ma.masked_invalid(ts2.data.flatten()), + ddof=ddof, + )[0, 1] + + expected = np_cov(da_a, da_b) + actual = xr.cov(da_a, da_b, dim=dim, ddof=ddof) + assert_allclose(actual, expected) + + +@pytest.mark.parametrize( + "da_a, da_b", + [arrays_w_tuples()[1][0], arrays_w_tuples()[1][1], arrays_w_tuples()[1][2]], +) +@pytest.mark.parametrize("dim", [None, "time"]) +def test_corr(da_a, da_b, dim): + if dim is not None: + + def np_corr_ind(ts1, ts2, a, x): + # Ensure the ts are aligned and missing values ignored + ts1, ts2 = broadcast(ts1, ts2) + valid_values = ts1.notnull() & ts2.notnull() + + ts1 = ts1.where(valid_values) + ts2 = ts2.where(valid_values) + + return np.ma.corrcoef( + np.ma.masked_invalid(ts1.sel(a=a, x=x).data.flatten()), + np.ma.masked_invalid(ts2.sel(a=a, x=x).data.flatten()), + )[0, 1] + + expected = np.zeros((3, 4)) + for a in [0, 1, 2]: + for x in [0, 1, 2, 3]: + expected[a, x] = np_corr_ind(da_a, da_b, a=a, x=x) + actual = xr.corr(da_a, da_b, dim) + assert_allclose(actual, expected) + + else: + + def np_corr(ts1, ts2): + # Ensure the ts are aligned and missing values ignored + ts1, ts2 = broadcast(ts1, ts2) + valid_values = ts1.notnull() & ts2.notnull() + + ts1 = ts1.where(valid_values) + 
ts2 = ts2.where(valid_values) + + return np.ma.corrcoef( + np.ma.masked_invalid(ts1.data.flatten()), + np.ma.masked_invalid(ts2.data.flatten()), + )[0, 1] + + expected = np_corr(da_a, da_b) + actual = xr.corr(da_a, da_b, dim) + assert_allclose(actual, expected) + + +@pytest.mark.parametrize( + "da_a, da_b", + arrays_w_tuples()[1], +) +@pytest.mark.parametrize("dim", [None, "time", "x"]) +def test_covcorr_consistency(da_a, da_b, dim): + # Testing that xr.corr and xr.cov are consistent with each other + # 1. Broadcast the two arrays + da_a, da_b = broadcast(da_a, da_b) + # 2. Ignore the nans + valid_values = da_a.notnull() & da_b.notnull() + da_a = da_a.where(valid_values) + da_b = da_b.where(valid_values) + + expected = xr.cov(da_a, da_b, dim=dim, ddof=0) / ( + da_a.std(dim=dim) * da_b.std(dim=dim) + ) + actual = xr.corr(da_a, da_b, dim=dim) + assert_allclose(actual, expected) + + +@pytest.mark.parametrize( + "da_a", + arrays_w_tuples()[0], +) +@pytest.mark.parametrize("dim", [None, "time", "x", ["time", "x"]]) +def test_autocov(da_a, dim): + # Testing that the autocovariance*(N-1) is ~=~ to the variance matrix + # 1. Ignore the nans + valid_values = da_a.notnull() + # Because we're using ddof=1, this requires > 1 value in each sample + da_a = da_a.where(valid_values.sum(dim=dim) > 1) + expected = ((da_a - da_a.mean(dim=dim)) ** 2).sum(dim=dim, skipna=True, min_count=1) + actual = xr.cov(da_a, da_a, dim=dim) * (valid_values.sum(dim) - 1) + assert_allclose(actual, expected) + + @requires_dask def test_vectorize_dask_new_output_dims(): # regression test for GH3574 + # run vectorization in dask.array.gufunc by using `dask='parallelized'` data_array = xr.DataArray([[0, 1, 2], [1, 2, 3]], dims=("x", "y")) func = lambda x: x[np.newaxis, ...] expected = data_array.expand_dims("z") @@ -830,10 +1156,33 @@ def test_vectorize_dask_new_output_dims(): vectorize=True, dask="parallelized", output_dtypes=[float], - output_sizes={"z": 1}, + dask_gufunc_kwargs=dict(output_sizes={"z": 1}), ).transpose(*expected.dims) assert_identical(expected, actual) + with raises_regex(ValueError, "dimension 'z1' in 'output_sizes' must correspond"): + apply_ufunc( + func, + data_array.chunk({"x": 1}), + output_core_dims=[["z"]], + vectorize=True, + dask="parallelized", + output_dtypes=[float], + dask_gufunc_kwargs=dict(output_sizes={"z1": 1}), + ) + + with raises_regex( + ValueError, "dimension 'z' in 'output_core_dims' needs corresponding" + ): + apply_ufunc( + func, + data_array.chunk({"x": 1}), + output_core_dims=[["z"]], + vectorize=True, + dask="parallelized", + output_dtypes=[float], + ) + def test_output_wrong_number(): variable = xr.Variable("x", np.arange(10)) @@ -946,7 +1295,6 @@ def test_dot(use_dask): da_a = da_a.chunk({"a": 3}) da_b = da_b.chunk({"a": 3}) da_c = da_c.chunk({"c": 3}) - actual = xr.dot(da_a, da_b, dims=["a", "b"]) assert actual.dims == ("c",) assert (actual.data == np.einsum("ij,ijk->k", a, b)).all() @@ -960,7 +1308,7 @@ def test_dot(use_dask): # for only a single array is passed without dims argument, just return # as is actual = xr.dot(da_a) - assert da_a.identical(actual) + assert_identical(da_a, actual) # test for variable actual = xr.dot(da_a.variable, da_b.variable) diff --git a/xarray/tests/test_concat.py b/xarray/tests/test_concat.py index e5038dd4af2..7416cab13ed 100644 --- a/xarray/tests/test_concat.py +++ b/xarray/tests/test_concat.py @@ -250,7 +250,9 @@ def test_concat_join_kwarg(self): assert_equal(actual, expected[join]) # regression test for #3681 - actual = concat([ds1.drop("x"), 
ds2.drop("x")], join="override", dim="y") + actual = concat( + [ds1.drop_vars("x"), ds2.drop_vars("x")], join="override", dim="y" + ) expected = Dataset( {"a": (("x", "y"), np.array([0, 0], ndmin=2))}, coords={"y": [0, 0.0001]} ) @@ -349,23 +351,55 @@ def test_concat_multiindex(self): assert expected.equals(actual) assert isinstance(actual.x.to_index(), pd.MultiIndex) - @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0]) + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0, {"a": 2, "b": 1}]) def test_concat_fill_value(self, fill_value): datasets = [ - Dataset({"a": ("x", [2, 3]), "x": [1, 2]}), - Dataset({"a": ("x", [1, 2]), "x": [0, 1]}), + Dataset({"a": ("x", [2, 3]), "b": ("x", [-2, 1]), "x": [1, 2]}), + Dataset({"a": ("x", [1, 2]), "b": ("x", [3, -1]), "x": [0, 1]}), ] if fill_value == dtypes.NA: # if we supply the default, we expect the missing value for a # float array - fill_value = np.nan + fill_value_a = fill_value_b = np.nan + elif isinstance(fill_value, dict): + fill_value_a = fill_value["a"] + fill_value_b = fill_value["b"] + else: + fill_value_a = fill_value_b = fill_value expected = Dataset( - {"a": (("t", "x"), [[fill_value, 2, 3], [1, 2, fill_value]])}, + { + "a": (("t", "x"), [[fill_value_a, 2, 3], [1, 2, fill_value_a]]), + "b": (("t", "x"), [[fill_value_b, -2, 1], [3, -1, fill_value_b]]), + }, {"x": [0, 1, 2]}, ) actual = concat(datasets, dim="t", fill_value=fill_value) assert_identical(actual, expected) + @pytest.mark.parametrize("dtype", [str, bytes]) + @pytest.mark.parametrize("dim", ["x1", "x2"]) + def test_concat_str_dtype(self, dtype, dim): + + data = np.arange(4).reshape([2, 2]) + + da1 = Dataset( + { + "data": (["x1", "x2"], data), + "x1": [0, 1], + "x2": np.array(["a", "b"], dtype=dtype), + } + ) + da2 = Dataset( + { + "data": (["x1", "x2"], data), + "x1": np.array([1, 2]), + "x2": np.array(["c", "d"], dtype=dtype), + } + ) + actual = concat([da1, da2], dim=dim) + + assert np.issubdtype(actual.x2.dtype, dtype) + class TestConcatDataArray: def test_concat(self): @@ -515,6 +549,26 @@ def test_concat_combine_attrs_kwarg(self): actual = concat([da1, da2], dim="x", combine_attrs=combine_attrs) assert_identical(actual, expected[combine_attrs]) + @pytest.mark.parametrize("dtype", [str, bytes]) + @pytest.mark.parametrize("dim", ["x1", "x2"]) + def test_concat_str_dtype(self, dtype, dim): + + data = np.arange(4).reshape([2, 2]) + + da1 = DataArray( + data=data, + dims=["x1", "x2"], + coords={"x1": [0, 1], "x2": np.array(["a", "b"], dtype=dtype)}, + ) + da2 = DataArray( + data=data, + dims=["x1", "x2"], + coords={"x1": np.array([1, 2]), "x2": np.array(["c", "d"], dtype=dtype)}, + ) + actual = concat([da1, da2], dim=dim) + + assert np.issubdtype(actual.x2.dtype, dtype) + @pytest.mark.parametrize("attr1", ({"a": {"meta": [10, 20, 30]}}, {"a": [1, 2, 3]}, {})) @pytest.mark.parametrize("attr2", ({"a": [1, 2, 3]}, {})) @@ -548,3 +602,36 @@ def test_concat_merge_single_non_dim_coord(): for coords in ["different", "all"]: with raises_regex(ValueError, "'y' not present in all datasets"): concat([da1, da2, da3], dim="x") + + +def test_concat_preserve_coordinate_order(): + x = np.arange(0, 5) + y = np.arange(0, 10) + time = np.arange(0, 4) + data = np.zeros((4, 10, 5), dtype=bool) + + ds1 = Dataset( + {"data": (["time", "y", "x"], data[0:2])}, + coords={"time": time[0:2], "y": y, "x": x}, + ) + ds2 = Dataset( + {"data": (["time", "y", "x"], data[2:4])}, + coords={"time": time[2:4], "y": y, "x": x}, + ) + + expected = Dataset( + {"data": (["time", "y", "x"], data)}, + 
coords={"time": time, "y": y, "x": x}, + ) + + actual = concat([ds1, ds2], dim="time") + + # check dimension order + for act, exp in zip(actual.dims, expected.dims): + assert act == exp + assert actual.dims[act] == expected.dims[exp] + + # check coordinate order + for act, exp in zip(actual.coords, expected.coords): + assert act == exp + assert_identical(actual.coords[act], expected.coords[exp]) diff --git a/xarray/tests/test_conventions.py b/xarray/tests/test_conventions.py index acb2400ea04..9abaa978651 100644 --- a/xarray/tests/test_conventions.py +++ b/xarray/tests/test_conventions.py @@ -32,10 +32,8 @@ class TestBoolTypeArray: def test_booltype_array(self): x = np.array([1, 0, 1, 1, 0], dtype="i1") bx = conventions.BoolTypeArray(x) - assert bx.dtype == np.bool - assert_array_equal( - bx, np.array([True, False, True, True, False], dtype=np.bool) - ) + assert bx.dtype == bool + assert_array_equal(bx, np.array([True, False, True, True, False], dtype=bool)) class TestNativeEndiannessArray: @@ -235,6 +233,7 @@ def test_decode_cf_with_drop_variables(self): assert_identical(expected, actual) assert_identical(expected, actual2) + @pytest.mark.filterwarnings("ignore:Ambiguous reference date string") def test_invalid_time_units_raises_eagerly(self): ds = Dataset({"time": ("time", [0, 1], {"units": "foobar since 123"})}) with raises_regex(ValueError, "unable to decode time"): @@ -311,6 +310,41 @@ def test_decode_dask_times(self): conventions.decode_cf(original).chunk(), ) + def test_decode_cf_time_kwargs(self): + ds = Dataset.from_dict( + { + "coords": { + "timedelta": { + "data": np.array([1, 2, 3], dtype="int64"), + "dims": "timedelta", + "attrs": {"units": "days"}, + }, + "time": { + "data": np.array([1, 2, 3], dtype="int64"), + "dims": "time", + "attrs": {"units": "days since 2000-01-01"}, + }, + }, + "dims": {"time": 3, "timedelta": 3}, + "data_vars": { + "a": {"dims": ("time", "timedelta"), "data": np.ones((3, 3))}, + }, + } + ) + + dsc = conventions.decode_cf(ds) + assert dsc.timedelta.dtype == np.dtype("m8[ns]") + assert dsc.time.dtype == np.dtype("M8[ns]") + dsc = conventions.decode_cf(ds, decode_times=False) + assert dsc.timedelta.dtype == np.dtype("int64") + assert dsc.time.dtype == np.dtype("int64") + dsc = conventions.decode_cf(ds, decode_times=True, decode_timedelta=False) + assert dsc.timedelta.dtype == np.dtype("int64") + assert dsc.time.dtype == np.dtype("M8[ns]") + dsc = conventions.decode_cf(ds, decode_times=False, decode_timedelta=True) + assert dsc.timedelta.dtype == np.dtype("m8[ns]") + assert dsc.time.dtype == np.dtype("int64") + class CFEncodedInMemoryStore(WritableCFDataStore, InMemoryDataStore): def encode_variable(self, var): @@ -328,13 +362,17 @@ def create_store(self): @contextlib.contextmanager def roundtrip( - self, data, save_kwargs={}, open_kwargs={}, allow_cleanup_failure=False + self, data, save_kwargs=None, open_kwargs=None, allow_cleanup_failure=False ): + if save_kwargs is None: + save_kwargs = {} + if open_kwargs is None: + open_kwargs = {} store = CFEncodedInMemoryStore() data.dump_to_store(store, **save_kwargs) yield open_dataset(store, **open_kwargs) - @pytest.mark.skip("cannot roundtrip coordinates yet for " "CFEncodedInMemoryStore") + @pytest.mark.skip("cannot roundtrip coordinates yet for CFEncodedInMemoryStore") def test_roundtrip_coordinates(self): pass diff --git a/xarray/tests/test_cupy.py b/xarray/tests/test_cupy.py new file mode 100644 index 00000000000..0276b8ebc08 --- /dev/null +++ b/xarray/tests/test_cupy.py @@ -0,0 +1,60 @@ +import numpy as 
np +import pandas as pd +import pytest + +import xarray as xr + +cp = pytest.importorskip("cupy") + + +@pytest.fixture +def toy_weather_data(): + """Construct the example DataSet from the Toy weather data example. + + http://xarray.pydata.org/en/stable/examples/weather-data.html + + Here we construct the DataSet exactly as shown in the example and then + convert the numpy arrays to cupy. + + """ + np.random.seed(123) + times = pd.date_range("2000-01-01", "2001-12-31", name="time") + annual_cycle = np.sin(2 * np.pi * (times.dayofyear.values / 365.25 - 0.28)) + + base = 10 + 15 * annual_cycle.reshape(-1, 1) + tmin_values = base + 3 * np.random.randn(annual_cycle.size, 3) + tmax_values = base + 10 + 3 * np.random.randn(annual_cycle.size, 3) + + ds = xr.Dataset( + { + "tmin": (("time", "location"), tmin_values), + "tmax": (("time", "location"), tmax_values), + }, + {"time": times, "location": ["IA", "IN", "IL"]}, + ) + + ds.tmax.data = cp.asarray(ds.tmax.data) + ds.tmin.data = cp.asarray(ds.tmin.data) + + return ds + + +def test_cupy_import(): + """Check the import worked.""" + assert cp + + +def test_check_data_stays_on_gpu(toy_weather_data): + """Perform some operations and check the data stays on the GPU.""" + freeze = (toy_weather_data["tmin"] <= 0).groupby("time.month").mean("time") + assert isinstance(freeze.data, cp.core.core.ndarray) + + +def test_where(): + from xarray.core.duck_array_ops import where + + data = cp.zeros(10) + + output = where(data < 1, 1, data).all() + assert output + assert isinstance(output, cp.ndarray) diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 538dbbfb58b..19a61c60577 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -23,7 +23,9 @@ assert_equal, assert_frame_equal, assert_identical, + raise_if_dask_computes, raises_regex, + requires_pint_0_15, requires_scipy_or_netCDF4, ) from .test_backends import create_tmp_file @@ -35,30 +37,6 @@ ON_WINDOWS = sys.platform == "win32" -class CountingScheduler: - """ Simple dask scheduler counting the number of computes. - - Reference: https://stackoverflow.com/questions/53289286/ """ - - def __init__(self, max_computes=0): - self.total_computes = 0 - self.max_computes = max_computes - - def __call__(self, dsk, keys, **kwargs): - self.total_computes += 1 - if self.total_computes > self.max_computes: - raise RuntimeError( - "Too many computes. Total: %d > max: %d." 
- % (self.total_computes, self.max_computes) - ) - return dask.get(dsk, keys, **kwargs) - - -def raise_if_dask_computes(max_computes=0): - scheduler = CountingScheduler(max_computes) - return dask.config.set(scheduler=scheduler) - - def test_raise_if_dask_computes(): data = da.from_array(np.random.RandomState(0).randn(4, 6), chunks=(2, 2)) with raises_regex(RuntimeError, "Too many computes"): @@ -292,6 +270,22 @@ def test_persist(self): self.assertLazyAndAllClose(u + 1, v) self.assertLazyAndAllClose(u + 1, v2) + @requires_pint_0_15(reason="Need __dask_tokenize__") + def test_tokenize_duck_dask_array(self): + import pint + + unit_registry = pint.UnitRegistry() + + q = unit_registry.Quantity(self.data, "meter") + variable = xr.Variable(("x", "y"), q) + + token = dask.base.tokenize(variable) + post_op = variable + 5 * unit_registry.meter + + assert dask.base.tokenize(variable) != dask.base.tokenize(post_op) + # Immutability check + assert dask.base.tokenize(variable) == token + class TestDataArrayAndDataset(DaskTestCase): def assertLazyAndIdentical(self, expected, actual): @@ -715,15 +709,35 @@ def test_from_dask_variable(self): a = DataArray(self.lazy_array.variable, coords={"x": range(4)}, name="foo") self.assertLazyAndIdentical(self.lazy_array, a) + @requires_pint_0_15(reason="Need __dask_tokenize__") + def test_tokenize_duck_dask_array(self): + import pint + + unit_registry = pint.UnitRegistry() + + q = unit_registry.Quantity(self.data, unit_registry.meter) + data_array = xr.DataArray( + data=q, coords={"x": range(4)}, dims=("x", "y"), name="foo" + ) + + token = dask.base.tokenize(data_array) + post_op = data_array + 5 * unit_registry.meter + + assert dask.base.tokenize(data_array) != dask.base.tokenize(post_op) + # Immutability check + assert dask.base.tokenize(data_array) == token + class TestToDaskDataFrame: def test_to_dask_dataframe(self): # Test conversion of Datasets to dask DataFrames - x = da.from_array(np.random.randn(10), chunks=4) + x = np.random.randn(10) y = np.arange(10, dtype="uint8") t = list("abcdefghij") - ds = Dataset({"a": ("t", x), "b": ("t", y), "t": ("t", t)}) + ds = Dataset( + {"a": ("t", da.from_array(x, chunks=4)), "b": ("t", y), "t": ("t", t)} + ) expected_pd = pd.DataFrame({"a": x, "b": y}, index=pd.Index(t, name="t")) @@ -746,8 +760,8 @@ def test_to_dask_dataframe(self): def test_to_dask_dataframe_2D(self): # Test if 2-D dataset is supplied - w = da.from_array(np.random.randn(2, 3), chunks=(1, 2)) - ds = Dataset({"w": (("x", "y"), w)}) + w = np.random.randn(2, 3) + ds = Dataset({"w": (("x", "y"), da.from_array(w, chunks=(1, 2)))}) ds["x"] = ("x", np.array([0, 1], np.int64)) ds["y"] = ("y", list("abc")) @@ -779,10 +793,15 @@ def test_to_dask_dataframe_2D_set_index(self): def test_to_dask_dataframe_coordinates(self): # Test if coordinate is also a dask array - x = da.from_array(np.random.randn(10), chunks=4) - t = da.from_array(np.arange(10) * 2, chunks=4) + x = np.random.randn(10) + t = np.arange(10) * 2 - ds = Dataset({"a": ("t", x), "t": ("t", t)}) + ds = Dataset( + { + "a": ("t", da.from_array(x, chunks=4)), + "t": ("t", da.from_array(t, chunks=4)), + } + ) expected_pd = pd.DataFrame({"a": x}, index=pd.Index(t, name="t")) expected = dd.from_pandas(expected_pd, chunksize=4) @@ -972,6 +991,7 @@ def make_da(): coords={"x": np.arange(10), "y": np.arange(100, 120)}, name="a", ).chunk({"x": 4, "y": 5}) + da.x.attrs["long_name"] = "x" da.attrs["test"] = "test" da.coords["c2"] = 0.5 da.coords["ndcoord"] = da.x * 2 @@ -995,6 +1015,9 @@ def make_ds(): 
map_ds.attrs["test"] = "test" map_ds.coords["xx"] = map_ds["a"] * map_ds.y + map_ds.x.attrs["long_name"] = "x" + map_ds.y.attrs["long_name"] = "y" + return map_ds @@ -1035,11 +1058,19 @@ def test_unify_chunks_shallow_copy(obj, transform): assert_identical(obj, unified) and obj is not obj.unify_chunks() +@pytest.mark.parametrize("obj", [make_da()]) +def test_auto_chunk_da(obj): + actual = obj.chunk("auto").data + expected = obj.data.rechunk("auto") + np.testing.assert_array_equal(actual, expected) + assert actual.chunks == expected.chunks + + def test_map_blocks_error(map_da, map_ds): def bad_func(darray): return (darray * darray.x + 5 * darray.y)[:1, :1] - with raises_regex(ValueError, "Length of the.* has changed."): + with raises_regex(ValueError, "Received dimension 'x' of length 1"): xr.map_blocks(bad_func, map_da).compute() def returns_numpy(darray): @@ -1066,9 +1097,6 @@ def really_bad_func(darray): with raises_regex(ValueError, "inconsistent chunks"): xr.map_blocks(bad_func, ds_copy) - with raises_regex(TypeError, "Cannot pass dask collections"): - xr.map_blocks(bad_func, map_da, args=[map_da.chunk()]) - with raises_regex(TypeError, "Cannot pass dask collections"): xr.map_blocks(bad_func, map_da, kwargs=dict(a=map_da.chunk())) @@ -1095,6 +1123,58 @@ def test_map_blocks_convert_args_to_list(obj): assert_identical(actual, expected) +def test_map_blocks_dask_args(): + da1 = xr.DataArray( + np.ones((10, 20)), + dims=["x", "y"], + coords={"x": np.arange(10), "y": np.arange(20)}, + ).chunk({"x": 5, "y": 4}) + + # check that block shapes are the same + def sumda(da1, da2): + assert da1.shape == da2.shape + return da1 + da2 + + da2 = da1 + 1 + with raise_if_dask_computes(): + mapped = xr.map_blocks(sumda, da1, args=[da2]) + xr.testing.assert_equal(da1 + da2, mapped) + + # one dimension in common + da2 = (da1 + 1).isel(x=1, drop=True) + with raise_if_dask_computes(): + mapped = xr.map_blocks(operator.add, da1, args=[da2]) + xr.testing.assert_equal(da1 + da2, mapped) + + # test that everything works when dimension names are different + da2 = (da1 + 1).isel(x=1, drop=True).rename({"y": "k"}) + with raise_if_dask_computes(): + mapped = xr.map_blocks(operator.add, da1, args=[da2]) + xr.testing.assert_equal(da1 + da2, mapped) + + with raises_regex(ValueError, "Chunk sizes along dimension 'x'"): + xr.map_blocks(operator.add, da1, args=[da1.chunk({"x": 1})]) + + with raises_regex(ValueError, "indexes along dimension 'x' are not equal"): + xr.map_blocks(operator.add, da1, args=[da1.reindex(x=np.arange(20))]) + + # reduction + da1 = da1.chunk({"x": -1}) + da2 = da1 + 1 + with raise_if_dask_computes(): + mapped = xr.map_blocks(lambda a, b: (a + b).sum("x"), da1, args=[da2]) + xr.testing.assert_equal((da1 + da2).sum("x"), mapped) + + # reduction with template + da1 = da1.chunk({"x": -1}) + da2 = da1 + 1 + with raise_if_dask_computes(): + mapped = xr.map_blocks( + lambda a, b: (a + b).sum("x"), da1, args=[da2], template=da1.sum("x") + ) + xr.testing.assert_equal((da1 + da2).sum("x"), mapped) + + @pytest.mark.parametrize("obj", [make_da(), make_ds()]) def test_map_blocks_add_attrs(obj): def add_attrs(obj): @@ -1109,6 +1189,11 @@ def add_attrs(obj): assert_identical(actual, expected) + # when template is specified, attrs are copied from template, not set by function + with raise_if_dask_computes(): + actual = xr.map_blocks(add_attrs, obj, template=obj) + assert_identical(actual, obj) + def test_map_blocks_change_name(map_da): def change_name(obj): @@ -1150,7 +1235,7 @@ def 
test_map_blocks_to_array(map_ds): lambda x: x.expand_dims(k=3), lambda x: x.assign_coords(new_coord=("y", x.y * 2)), lambda x: x.astype(np.int32), - # TODO: [lambda x: x.isel(x=1).drop_vars("x"), map_da], + lambda x: x.x, ], ) def test_map_blocks_da_transformations(func, map_da): @@ -1170,7 +1255,7 @@ def test_map_blocks_da_transformations(func, map_da): lambda x: x.expand_dims(k=[1, 2, 3]), lambda x: x.expand_dims(k=3), lambda x: x.rename({"a": "new1", "b": "new2"}), - # TODO: [lambda x: x.isel(x=1)], + lambda x: x.x, ], ) def test_map_blocks_ds_transformations(func, map_ds): @@ -1180,6 +1265,64 @@ def test_map_blocks_ds_transformations(func, map_ds): assert_identical(actual, func(map_ds)) +@pytest.mark.parametrize("obj", [make_da(), make_ds()]) +def test_map_blocks_da_ds_with_template(obj): + func = lambda x: x.isel(x=[1]) + template = obj.isel(x=[1, 5, 9]) + with raise_if_dask_computes(): + actual = xr.map_blocks(func, obj, template=template) + assert_identical(actual, template) + + with raise_if_dask_computes(): + actual = obj.map_blocks(func, template=template) + assert_identical(actual, template) + + +def test_map_blocks_template_convert_object(): + da = make_da() + func = lambda x: x.to_dataset().isel(x=[1]) + template = da.to_dataset().isel(x=[1, 5, 9]) + with raise_if_dask_computes(): + actual = xr.map_blocks(func, da, template=template) + assert_identical(actual, template) + + ds = da.to_dataset() + func = lambda x: x.to_array().isel(x=[1]) + template = ds.to_array().isel(x=[1, 5, 9]) + with raise_if_dask_computes(): + actual = xr.map_blocks(func, ds, template=template) + assert_identical(actual, template) + + +@pytest.mark.parametrize("obj", [make_da(), make_ds()]) +def test_map_blocks_errors_bad_template(obj): + with raises_regex(ValueError, "unexpected coordinate variables"): + xr.map_blocks(lambda x: x.assign_coords(a=10), obj, template=obj).compute() + with raises_regex(ValueError, "does not contain coordinate variables"): + xr.map_blocks(lambda x: x.drop_vars("cxy"), obj, template=obj).compute() + with raises_regex(ValueError, "Dimensions {'x'} missing"): + xr.map_blocks(lambda x: x.isel(x=1), obj, template=obj).compute() + with raises_regex(ValueError, "Received dimension 'x' of length 1"): + xr.map_blocks(lambda x: x.isel(x=[1]), obj, template=obj).compute() + with raises_regex(TypeError, "must be a DataArray"): + xr.map_blocks(lambda x: x.isel(x=[1]), obj, template=(obj,)).compute() + with raises_regex(ValueError, "map_blocks requires that one block"): + xr.map_blocks( + lambda x: x.isel(x=[1]).assign_coords(x=10), obj, template=obj.isel(x=[1]) + ).compute() + with raises_regex(ValueError, "Expected index 'x' to be"): + xr.map_blocks( + lambda a: a.isel(x=[1]).assign_coords(x=[120]), # assign bad index values + obj, + template=obj.isel(x=[1, 5, 9]), + ).compute() + + +def test_map_blocks_errors_bad_template_2(map_ds): + with raises_regex(ValueError, "unexpected data variables {'xyz'}"): + xr.map_blocks(lambda x: x.assign(xyz=1), map_ds, template=map_ds).compute() + + @pytest.mark.parametrize("obj", [make_da(), make_ds()]) def test_map_blocks_object_method(obj): def func(obj): @@ -1448,3 +1591,11 @@ def test_more_transforms_pass_lazy_array_equiv(map_da, map_ds): assert_equal(map_da._from_temp_dataset(map_da._to_temp_dataset()), map_da) assert_equal(map_da.astype(map_da.dtype), map_da) assert_equal(map_da.transpose("y", "x", transpose_coords=False).cxy, map_da.cxy) + + +def test_optimize(): + # https://github.com/pydata/xarray/issues/3698 + a = dask.array.ones((10, 4), 
chunks=(5, 2)) + arr = xr.DataArray(a).chunk(5) + (arr2,) = dask.optimize(arr) + arr2.compute() diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index c3e5aafabfe..3ead427e22e 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -9,7 +9,15 @@ import pytest import xarray as xr -from xarray import DataArray, Dataset, IndexVariable, Variable, align, broadcast +from xarray import ( + DataArray, + Dataset, + IndexVariable, + Variable, + align, + broadcast, + set_options, +) from xarray.coding.times import CFDatetimeCoder from xarray.convert import from_cdms2 from xarray.core import dtypes @@ -24,6 +32,7 @@ assert_equal, assert_identical, has_dask, + raise_if_dask_computes, raises_regex, requires_bottleneck, requires_dask, @@ -34,6 +43,11 @@ source_ndarray, ) +pytestmark = [ + pytest.mark.filterwarnings("error:Mean of empty slice"), + pytest.mark.filterwarnings("error:All-NaN (slice|axis) encountered"), +] + class TestDataArray: @pytest.fixture(autouse=True) @@ -783,13 +797,13 @@ def test_isel(self): assert_identical(self.dv[:3, :5], self.dv.isel(x=slice(3), y=slice(5))) with raises_regex( ValueError, - r"dimensions {'not_a_dim'} do not exist. Expected " + r"Dimensions {'not_a_dim'} do not exist. Expected " r"one or more of \('x', 'y'\)", ): self.dv.isel(not_a_dim=0) with pytest.warns( UserWarning, - match=r"dimensions {'not_a_dim'} do not exist. " + match=r"Dimensions {'not_a_dim'} do not exist. " r"Expected one or more of \('x', 'y'\)", ): self.dv.isel(not_a_dim=0, missing_dims="warn") @@ -936,7 +950,7 @@ def test_sel_invalid_slice(self): with raises_regex(ValueError, "cannot use non-scalar arrays"): array.sel(x=slice(array.x)) - def test_sel_dataarray_datetime(self): + def test_sel_dataarray_datetime_slice(self): # regression test for GH1240 times = pd.date_range("2000-01-01", freq="D", periods=365) array = DataArray(np.arange(365), [("time", times)]) @@ -1076,6 +1090,12 @@ def test_loc(self): assert_identical(da[:3, :4], da.loc[["a", "b", "c"], np.arange(4)]) assert_identical(da[:, :4], da.loc[:, self.ds["y"] < 4]) + def test_loc_datetime64_value(self): + # regression test for https://github.com/pydata/xarray/issues/4283 + t = np.array(["2017-09-05T12", "2017-09-05T15"], dtype="datetime64[ns]") + array = DataArray(np.ones(t.shape), dims=("time",), coords=(t,)) + assert_identical(array.loc[{"time": t[0]}], array[0]) + def test_loc_assign(self): self.ds["x"] = ("x", np.array(list("abcdefghij"))) da = self.ds["foo"] @@ -1150,6 +1170,16 @@ def test_loc_single_boolean(self): assert data.loc[True] == 0 assert data.loc[False] == 1 + def test_loc_dim_name_collision_with_sel_params(self): + da = xr.DataArray( + [[0, 0], [1, 1]], + dims=["dim1", "method"], + coords={"dim1": ["x", "y"], "method": ["a", "b"]}, + ) + np.testing.assert_array_equal( + da.loc[dict(dim1=["x", "y"], method=["a"])], [[0], [1]] + ) + def test_selection_multiindex(self): mindex = pd.MultiIndex.from_product( [["a", "b"], [1, 2], [-1, -2]], names=("one", "two", "three") @@ -1385,8 +1415,6 @@ def test_reset_coords(self): ) assert_identical(actual, expected) - with pytest.raises(TypeError): - data = data.reset_coords(inplace=True) with raises_regex(ValueError, "cannot be found"): data.reset_coords("foo", drop=True) with raises_regex(ValueError, "cannot be found"): @@ -1466,8 +1494,8 @@ def test_broadcast_like(self): new1 = arr1.broadcast_like(arr2) new2 = arr2.broadcast_like(arr1) - assert orig1.identical(new1) - assert orig2.identical(new2) + assert_identical(orig1, new1) + 
assert_identical(orig2, new2) orig3 = DataArray(np.random.randn(5), [("x", range(5))]) orig4 = DataArray(np.random.randn(6), [("y", range(6))]) @@ -1501,7 +1529,7 @@ def test_reindex_regressions(self): da.reindex(time=time2) # regression test for #736, reindex can not change complex nums dtype - x = np.array([1, 2, 3], dtype=np.complex) + x = np.array([1, 2, 3], dtype=complex) x = DataArray(x, coords=[[0.1, 0.2, 0.3]]) y = DataArray([2, 5, 6, 7, 8], coords=[[-1.1, 0.21, 0.31, 0.41, 0.51]]) re_dtype = x.reindex_like(y, method="pad").dtype @@ -1519,18 +1547,40 @@ def test_reindex_method(self): expected = DataArray([10, 20, np.nan], coords=[("y", y)]) assert_identical(expected, actual) - @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0]) + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0, {None: 2, "u": 1}]) def test_reindex_fill_value(self, fill_value): - x = DataArray([10, 20], dims="y", coords={"y": [0, 1]}) + x = DataArray([10, 20], dims="y", coords={"y": [0, 1], "u": ("y", [1, 2])}) y = [0, 1, 2] if fill_value == dtypes.NA: # if we supply the default, we expect the missing value for a # float array - fill_value = np.nan + fill_value_var = fill_value_u = np.nan + elif isinstance(fill_value, dict): + fill_value_var = fill_value[None] + fill_value_u = fill_value["u"] + else: + fill_value_var = fill_value_u = fill_value actual = x.reindex(y=y, fill_value=fill_value) - expected = DataArray([10, 20, fill_value], coords=[("y", y)]) + expected = DataArray( + [10, 20, fill_value_var], + dims="y", + coords={"y": y, "u": ("y", [1, 2, fill_value_u])}, + ) assert_identical(expected, actual) + @pytest.mark.parametrize("dtype", [str, bytes]) + def test_reindex_str_dtype(self, dtype): + + data = DataArray( + [1, 2], dims="x", coords={"x": np.array(["a", "b"], dtype=dtype)} + ) + + actual = data.reindex(x=data.x) + expected = data + + assert_identical(expected, actual) + assert actual.dtype == expected.dtype + def test_rename(self): renamed = self.dv.rename("bar") assert_identical(renamed.to_dataset(), self.ds.rename({"foo": "bar"})) @@ -1828,6 +1878,13 @@ def test_reset_index(self): expected = DataArray([1, 2], coords={"x_": ("x", ["a", "b"])}, dims="x") assert_identical(array.reset_index("x"), expected) + def test_reset_index_keep_attrs(self): + coord_1 = DataArray([1, 2], dims=["coord_1"], attrs={"attrs": True}) + da = DataArray([1, 0], [coord_1]) + expected = DataArray([1, 0], {"coord_1_": coord_1}, dims=["coord_1"]) + obj = da.reset_index("coord_1") + assert_identical(expected, obj) + def test_reorder_levels(self): midx = self.mindex.reorder_levels(["level_2", "level_1"]) expected = DataArray(self.mda.values, coords={"x": midx}, dims="x") @@ -1835,10 +1892,6 @@ def test_reorder_levels(self): obj = self.mda.reorder_levels(x=["level_2", "level_1"]) assert_identical(obj, expected) - with pytest.raises(TypeError): - array = self.mda.copy() - array.reorder_levels(x=["level_2", "level_1"], inplace=True) - array = DataArray([1, 2], dims="x") with pytest.raises(KeyError): array.reorder_levels(x=["level_1", "level_2"]) @@ -1865,6 +1918,39 @@ def test_array_interface(self): bar = Variable(["x", "y"], np.zeros((10, 20))) assert_equal(self.dv, np.maximum(self.dv, bar)) + def test_astype_attrs(self): + for v in [self.va.copy(), self.mda.copy(), self.ds.copy()]: + v.attrs["foo"] = "bar" + assert v.attrs == v.astype(float).attrs + assert not v.astype(float, keep_attrs=False).attrs + + def test_astype_dtype(self): + original = DataArray([-1, 1, 2, 3, 1000]) + converted = original.astype(float) + 
assert_array_equal(original, converted) + assert np.issubdtype(original.dtype, np.integer) + assert np.issubdtype(converted.dtype, np.floating) + + def test_astype_order(self): + original = DataArray([[1, 2], [3, 4]]) + converted = original.astype("d", order="F") + assert_equal(original, converted) + assert original.values.flags["C_CONTIGUOUS"] + assert converted.values.flags["F_CONTIGUOUS"] + + def test_astype_subok(self): + class NdArraySubclass(np.ndarray): + pass + + original = DataArray(NdArraySubclass(np.arange(3))) + converted_not_subok = original.astype("d", subok=False) + converted_subok = original.astype("d", subok=True) + if not isinstance(original.data, NdArraySubclass): + pytest.xfail("DataArray cannot be backed yet by a subclasses of np.ndarray") + assert isinstance(converted_not_subok.data, np.ndarray) + assert not isinstance(converted_not_subok.data, NdArraySubclass) + assert isinstance(converted_subok.data, NdArraySubclass) + def test_is_null(self): x = np.random.RandomState(42).randn(5, 6) x[x < 0] = np.nan @@ -1921,9 +2007,9 @@ def test_inplace_math_basics(self): def test_inplace_math_automatic_alignment(self): a = DataArray(range(5), [("x", range(5))]) b = DataArray(range(1, 6), [("x", range(1, 6))]) - with pytest.raises(xr.MergeError): + with pytest.raises(xr.MergeError, match="Automatic alignment is not supported"): a += b - with pytest.raises(xr.MergeError): + with pytest.raises(xr.MergeError, match="Automatic alignment is not supported"): b += a def test_math_name(self): @@ -2158,11 +2244,20 @@ def test_transpose(self): actual = da.transpose("z", ..., "x", transpose_coords=True) assert_equal(expected, actual) + # same as previous but with a missing dimension + actual = da.transpose( + "z", "y", "x", "not_a_dim", transpose_coords=True, missing_dims="ignore" + ) + assert_equal(expected, actual) + with pytest.raises(ValueError): da.transpose("x", "y") - with pytest.warns(FutureWarning): - da.transpose() + with pytest.raises(ValueError): + da.transpose("not_a_dim", "z", "x", ...) 
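
The new `test_astype_*` cases above pin down how `astype` forwards numpy's casting keywords and how `keep_attrs` behaves. A small sketch of the same calls outside the test class, assuming an xarray build matching this patch (where `order`/`subok` are still forwarded); the values and attrs are placeholders:

```python
import numpy as np
import xarray as xr

da = xr.DataArray([[1, 2], [3, 4]], attrs={"units": "m"})

converted = da.astype(float)                       # attrs survive by default
assert converted.attrs == {"units": "m"}
assert not da.astype(float, keep_attrs=False).attrs

# numpy keywords such as order= are passed through to ndarray.astype
fortran = da.astype("d", order="F")
assert fortran.values.flags["F_CONTIGUOUS"]
```
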
+ + with pytest.warns(UserWarning): + da.transpose("not_a_dim", "y", "x", ..., missing_dims="warn") def test_squeeze(self): assert_equal(self.dv.variable.squeeze(), self.dv.squeeze().variable) @@ -2452,6 +2547,21 @@ def test_assign_attrs(self): assert_identical(new_actual, expected) assert actual.attrs == {"a": 1, "b": 2} + @pytest.mark.parametrize( + "func", [lambda x: x.clip(0, 1), lambda x: np.float64(1.0) * x, np.abs, abs] + ) + def test_propagate_attrs(self, func): + da = DataArray(self.va) + + # test defaults + assert func(da).attrs == da.attrs + + with set_options(keep_attrs=False): + assert func(da).attrs == {} + + with set_options(keep_attrs=True): + assert func(da).attrs == da.attrs + def test_fillna(self): a = DataArray([np.nan, 1, np.nan, 3], coords={"x": range(4)}, dims="x") actual = a.fillna(-1) @@ -2753,9 +2863,6 @@ def test_groupby_restore_coord_dims(self): )["c"] assert result.dims == expected_dims - with pytest.warns(FutureWarning): - array.groupby("x").map(lambda x: x.squeeze()) - def test_groupby_first_and_last(self): array = DataArray([1, 2, 3, 4, 5], dims="x") by = DataArray(["a"] * 2 + ["b"] * 3, dims="x", name="ab") @@ -3144,7 +3251,8 @@ def test_upsample_interpolate_regression_1605(self): @requires_dask @requires_scipy - def test_upsample_interpolate_dask(self): + @pytest.mark.parametrize("chunked_time", [True, False]) + def test_upsample_interpolate_dask(self, chunked_time): from scipy.interpolate import interp1d xs = np.arange(6) @@ -3155,6 +3263,8 @@ def test_upsample_interpolate_dask(self): data = np.tile(z, (6, 3, 1)) array = DataArray(data, {"time": times, "x": xs, "y": ys}, ("x", "y", "time")) chunks = {"x": 2, "y": 1} + if chunked_time: + chunks["time"] = 3 expected_times = times.to_series().resample("1H").asfreq().index # Split the times into equal sub-intervals to simulate the 6 hour @@ -3182,13 +3292,6 @@ def test_upsample_interpolate_dask(self): # done here due to floating point arithmetic assert_allclose(expected, actual, rtol=1e-16) - # Check that an error is raised if an attempt is made to interpolate - # over a chunked dimension - with raises_regex( - NotImplementedError, "Chunking along the dimension to be interpolated" - ): - array.chunk({"time": 1}).resample(time="1H").interpolate("linear") - def test_align(self): array = DataArray( np.random.random((6, 8)), coords={"x": list("abcdef")}, dims=["x", "y"] @@ -3345,6 +3448,26 @@ def test_align_without_indexes_errors(self): DataArray([1, 2], coords=[("x", [0, 1])]), ) + def test_align_str_dtype(self): + + a = DataArray([0, 1], dims=["x"], coords={"x": ["a", "b"]}) + b = DataArray([1, 2], dims=["x"], coords={"x": ["b", "c"]}) + + expected_a = DataArray( + [0, 1, np.NaN], dims=["x"], coords={"x": ["a", "b", "c"]} + ) + expected_b = DataArray( + [np.NaN, 1, 2], dims=["x"], coords={"x": ["a", "b", "c"]} + ) + + actual_a, actual_b = xr.align(a, b, join="outer") + + assert_identical(expected_a, actual_a) + assert expected_a.x.dtype == actual_a.x.dtype + + assert_identical(expected_b, actual_b) + assert expected_b.x.dtype == actual_b.x.dtype + def test_broadcast_arrays(self): x = DataArray([1, 2], coords=[("a", [-1, -2])], name="x") y = DataArray([1, 2], coords=[("b", [3, 4])], name="y") @@ -3464,15 +3587,18 @@ def test_to_pandas(self): def test_to_dataframe(self): # regression test for #260 - arr = DataArray( - np.random.randn(3, 4), [("B", [1, 2, 3]), ("A", list("cdef"))], name="foo" - ) + arr_np = np.random.randn(3, 4) + + arr = DataArray(arr_np, [("B", [1, 2, 3]), ("A", list("cdef"))], name="foo") 
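
`test_to_dataframe`, continuing just below, now also asserts a `dim_order` argument. A standalone sketch of what that argument does, with a made-up toy array:

```python
import numpy as np
import xarray as xr

arr = xr.DataArray(
    np.arange(6).reshape(2, 3),
    coords=[("B", [1, 2]), ("A", ["x", "y", "z"])],
    name="foo",
)

# Default: the MultiIndex follows the DataArray's own dimension order (B, A).
arr.to_dataframe()

# dim_order re-orders the index levels (and therefore the row order).
arr.to_dataframe(dim_order=["A", "B"])
```
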
expected = arr.to_series() actual = arr.to_dataframe()["foo"] assert_array_equal(expected.values, actual.values) assert_array_equal(expected.name, actual.name) assert_array_equal(expected.index.values, actual.index.values) + actual = arr.to_dataframe(dim_order=["A", "B"])["foo"] + assert_array_equal(arr_np.transpose().reshape(-1), actual.values) + # regression test for coords with different dimensions arr.coords["C"] = ("B", [-1, -2, -3]) expected = arr.to_series().to_frame() @@ -3483,6 +3609,12 @@ def test_to_dataframe(self): assert_array_equal(expected.columns.values, actual.columns.values) assert_array_equal(expected.index.values, actual.index.values) + with pytest.raises(ValueError, match="does not match the set of dimensions"): + arr.to_dataframe(dim_order=["B", "A", "C"]) + + with pytest.raises(ValueError, match=r"cannot convert a scalar"): + arr.sel(A="c", B=2).to_dataframe() + arr.name = None # unnamed with raises_regex(ValueError, "unnamed"): arr.to_dataframe() @@ -3536,6 +3668,24 @@ def test_from_series_sparse(self): actual_sparse.data = actual_sparse.data.todense() assert_identical(actual_sparse, actual_dense) + @requires_sparse + def test_from_multiindex_series_sparse(self): + # regression test for GH4019 + import sparse + + idx = pd.MultiIndex.from_product([np.arange(3), np.arange(5)], names=["a", "b"]) + series = pd.Series(np.random.RandomState(0).random(len(idx)), index=idx).sample( + n=5, random_state=3 + ) + + dense = DataArray.from_series(series, sparse=False) + expected_coords = sparse.COO.from_numpy(dense.data, np.nan).coords + + actual_sparse = xr.DataArray.from_series(series, sparse=True) + actual_coords = actual_sparse.data.coords + + np.testing.assert_equal(actual_coords, expected_coords) + def test_to_and_from_empty_series(self): # GH697 expected = pd.Series([], dtype=np.float64) @@ -4265,8 +4415,14 @@ def test_polyfit(self, use_dask, use_datetime): ).T assert_allclose(out.polyfit_coefficients, expected, rtol=1e-3) + # Full output and deficient rank + with warnings.catch_warnings(): + warnings.simplefilter("ignore", np.RankWarning) + out = da.polyfit("x", 12, full=True) + assert out.polyfit_residuals.isnull().all() + # With NaN - da_raw[0, 1] = np.nan + da_raw[0, 1:3] = np.nan if use_dask: da = da_raw.chunk({"d": 1}) else: @@ -4281,6 +4437,11 @@ def test_polyfit(self, use_dask, use_datetime): assert out.x_matrix_rank == 3 np.testing.assert_almost_equal(out.polyfit_residuals, [0, 0]) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", np.RankWarning) + out = da.polyfit("x", 8, full=True) + np.testing.assert_array_equal(out.polyfit_residuals.isnull(), [True, False]) + def test_pad_constant(self): ar = DataArray(np.arange(3 * 4 * 5).reshape(3, 4, 5)) actual = ar.pad(dim_0=(1, 3)) @@ -4472,6 +4633,9 @@ def test_max(self, x, minindex, maxindex, nanindex): assert_identical(result2, expected2) + @pytest.mark.filterwarnings( + "ignore:Behaviour of argmin/argmax with neither dim nor :DeprecationWarning" + ) def test_argmin(self, x, minindex, maxindex, nanindex): ar = xr.DataArray( x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs @@ -4501,6 +4665,9 @@ def test_argmin(self, x, minindex, maxindex, nanindex): assert_identical(result2, expected2) + @pytest.mark.filterwarnings( + "ignore:Behaviour of argmin/argmax with neither dim nor :DeprecationWarning" + ) def test_argmax(self, x, minindex, maxindex, nanindex): ar = xr.DataArray( x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs @@ -4530,11 +4697,21 @@ def 
test_argmax(self, x, minindex, maxindex, nanindex): assert_identical(result2, expected2) - def test_idxmin(self, x, minindex, maxindex, nanindex): - ar0 = xr.DataArray( + @pytest.mark.parametrize("use_dask", [True, False]) + def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask): + if use_dask and not has_dask: + pytest.skip("requires dask") + if use_dask and x.dtype.kind == "M": + pytest.xfail("dask operation 'argmin' breaks when dtype is datetime64 (M)") + ar0_raw = xr.DataArray( x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs ) + if use_dask: + ar0 = ar0_raw.chunk({}) + else: + ar0 = ar0_raw + # dim doesn't exist with pytest.raises(KeyError): ar0.idxmin(dim="spam") @@ -4626,11 +4803,21 @@ def test_idxmin(self, x, minindex, maxindex, nanindex): result7 = ar0.idxmin(fill_value=-1j) assert_identical(result7, expected7) - def test_idxmax(self, x, minindex, maxindex, nanindex): - ar0 = xr.DataArray( + @pytest.mark.parametrize("use_dask", [True, False]) + def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask): + if use_dask and not has_dask: + pytest.skip("requires dask") + if use_dask and x.dtype.kind == "M": + pytest.xfail("dask operation 'argmax' breaks when dtype is datetime64 (M)") + ar0_raw = xr.DataArray( x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs ) + if use_dask: + ar0 = ar0_raw.chunk({}) + else: + ar0 = ar0_raw + # dim doesn't exist with pytest.raises(KeyError): ar0.idxmax(dim="spam") @@ -4722,6 +4909,78 @@ def test_idxmax(self, x, minindex, maxindex, nanindex): result7 = ar0.idxmax(fill_value=-1j) assert_identical(result7, expected7) + @pytest.mark.filterwarnings( + "ignore:Behaviour of argmin/argmax with neither dim nor :DeprecationWarning" + ) + def test_argmin_dim(self, x, minindex, maxindex, nanindex): + ar = xr.DataArray( + x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs + ) + indarr = xr.DataArray(np.arange(x.size, dtype=np.intp), dims=["x"]) + + if np.isnan(minindex): + with pytest.raises(ValueError): + ar.argmin() + return + + expected0 = {"x": indarr[minindex]} + result0 = ar.argmin(...) + for key in expected0: + assert_identical(result0[key], expected0[key]) + + result1 = ar.argmin(..., keep_attrs=True) + expected1 = deepcopy(expected0) + for da in expected1.values(): + da.attrs = self.attrs + for key in expected1: + assert_identical(result1[key], expected1[key]) + + result2 = ar.argmin(..., skipna=False) + if nanindex is not None and ar.dtype.kind != "O": + expected2 = {"x": indarr.isel(x=nanindex, drop=True)} + expected2["x"].attrs = {} + else: + expected2 = expected0 + + for key in expected2: + assert_identical(result2[key], expected2[key]) + + @pytest.mark.filterwarnings( + "ignore:Behaviour of argmin/argmax with neither dim nor :DeprecationWarning" + ) + def test_argmax_dim(self, x, minindex, maxindex, nanindex): + ar = xr.DataArray( + x, dims=["x"], coords={"x": np.arange(x.size) * 4}, attrs=self.attrs + ) + indarr = xr.DataArray(np.arange(x.size, dtype=np.intp), dims=["x"]) + + if np.isnan(maxindex): + with pytest.raises(ValueError): + ar.argmax() + return + + expected0 = {"x": indarr[maxindex]} + result0 = ar.argmax(...) 
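
These `test_argmin_dim`/`test_argmax_dim` cases index their results with `result0[key]`: when given a list of dimensions or `...`, `argmin`/`argmax` now return a dict of index arrays rather than a single DataArray, and that dict can be fed straight back into `isel`. A toy illustration (the array values are invented):

```python
import numpy as np
import xarray as xr

da = xr.DataArray([[3.0, 1.0, 4.0], [0.5, 2.0, 6.0]], dims=("y", "x"))

idx = da.argmin(...)              # {"y": <index along y>, "x": <index along x>}
assert da.isel(idx) == da.min()   # the dict locates the minimum element
```
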
+ for key in expected0: + assert_identical(result0[key], expected0[key]) + + result1 = ar.argmax(..., keep_attrs=True) + expected1 = deepcopy(expected0) + for da in expected1.values(): + da.attrs = self.attrs + for key in expected1: + assert_identical(result1[key], expected1[key]) + + result2 = ar.argmax(..., skipna=False) + if nanindex is not None and ar.dtype.kind != "O": + expected2 = {"x": indarr.isel(x=nanindex, drop=True)} + expected2["x"].attrs = {} + else: + expected2 = expected0 + + for key in expected2: + assert_identical(result2[key], expected2[key]) + @pytest.mark.parametrize( "x, minindex, maxindex, nanindex", @@ -4950,14 +5209,31 @@ def test_argmax(self, x, minindex, maxindex, nanindex): assert_identical(result3, expected2) - def test_idxmin(self, x, minindex, maxindex, nanindex): - ar0 = xr.DataArray( + @pytest.mark.parametrize("use_dask", [True, False]) + def test_idxmin(self, x, minindex, maxindex, nanindex, use_dask): + if use_dask and not has_dask: + pytest.skip("requires dask") + if use_dask and x.dtype.kind == "M": + pytest.xfail("dask operation 'argmin' breaks when dtype is datetime64 (M)") + + if x.dtype.kind == "O": + # TODO: nanops._nan_argminmax_object computes once to check for all-NaN slices. + max_computes = 1 + else: + max_computes = 0 + + ar0_raw = xr.DataArray( x, dims=["y", "x"], coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])}, attrs=self.attrs, ) + if use_dask: + ar0 = ar0_raw.chunk({}) + else: + ar0 = ar0_raw + assert_identical(ar0, ar0) # No dimension specified @@ -4988,15 +5264,18 @@ def test_idxmin(self, x, minindex, maxindex, nanindex): expected0.name = "x" # Default fill value (NaN) - result0 = ar0.idxmin(dim="x") + with raise_if_dask_computes(max_computes=max_computes): + result0 = ar0.idxmin(dim="x") assert_identical(result0, expected0) # Manually specify NaN fill_value - result1 = ar0.idxmin(dim="x", fill_value=np.NaN) + with raise_if_dask_computes(max_computes=max_computes): + result1 = ar0.idxmin(dim="x", fill_value=np.NaN) assert_identical(result1, expected0) # keep_attrs - result2 = ar0.idxmin(dim="x", keep_attrs=True) + with raise_if_dask_computes(max_computes=max_computes): + result2 = ar0.idxmin(dim="x", keep_attrs=True) expected2 = expected0.copy() expected2.attrs = self.attrs assert_identical(result2, expected2) @@ -5014,11 +5293,13 @@ def test_idxmin(self, x, minindex, maxindex, nanindex): expected3.name = "x" expected3.attrs = {} - result3 = ar0.idxmin(dim="x", skipna=False) + with raise_if_dask_computes(max_computes=max_computes): + result3 = ar0.idxmin(dim="x", skipna=False) assert_identical(result3, expected3) # fill_value should be ignored with skipna=False - result4 = ar0.idxmin(dim="x", skipna=False, fill_value=-100j) + with raise_if_dask_computes(max_computes=max_computes): + result4 = ar0.idxmin(dim="x", skipna=False, fill_value=-100j) assert_identical(result4, expected3) # Float fill_value @@ -5030,7 +5311,8 @@ def test_idxmin(self, x, minindex, maxindex, nanindex): expected5 = xr.concat(expected5, dim="y") expected5.name = "x" - result5 = ar0.idxmin(dim="x", fill_value=-1.1) + with raise_if_dask_computes(max_computes=max_computes): + result5 = ar0.idxmin(dim="x", fill_value=-1.1) assert_identical(result5, expected5) # Integer fill_value @@ -5042,7 +5324,8 @@ def test_idxmin(self, x, minindex, maxindex, nanindex): expected6 = xr.concat(expected6, dim="y") expected6.name = "x" - result6 = ar0.idxmin(dim="x", fill_value=-1) + with raise_if_dask_computes(max_computes=max_computes): + result6 = 
ar0.idxmin(dim="x", fill_value=-1) assert_identical(result6, expected6) # Complex fill_value @@ -5054,17 +5337,35 @@ def test_idxmin(self, x, minindex, maxindex, nanindex): expected7 = xr.concat(expected7, dim="y") expected7.name = "x" - result7 = ar0.idxmin(dim="x", fill_value=-5j) + with raise_if_dask_computes(max_computes=max_computes): + result7 = ar0.idxmin(dim="x", fill_value=-5j) assert_identical(result7, expected7) - def test_idxmax(self, x, minindex, maxindex, nanindex): - ar0 = xr.DataArray( + @pytest.mark.parametrize("use_dask", [True, False]) + def test_idxmax(self, x, minindex, maxindex, nanindex, use_dask): + if use_dask and not has_dask: + pytest.skip("requires dask") + if use_dask and x.dtype.kind == "M": + pytest.xfail("dask operation 'argmax' breaks when dtype is datetime64 (M)") + + if x.dtype.kind == "O": + # TODO: nanops._nan_argminmax_object computes once to check for all-NaN slices. + max_computes = 1 + else: + max_computes = 0 + + ar0_raw = xr.DataArray( x, dims=["y", "x"], coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])}, attrs=self.attrs, ) + if use_dask: + ar0 = ar0_raw.chunk({}) + else: + ar0 = ar0_raw + # No dimension specified with pytest.raises(ValueError): ar0.idxmax() @@ -5096,15 +5397,18 @@ def test_idxmax(self, x, minindex, maxindex, nanindex): expected0.name = "x" # Default fill value (NaN) - result0 = ar0.idxmax(dim="x") + with raise_if_dask_computes(max_computes=max_computes): + result0 = ar0.idxmax(dim="x") assert_identical(result0, expected0) # Manually specify NaN fill_value - result1 = ar0.idxmax(dim="x", fill_value=np.NaN) + with raise_if_dask_computes(max_computes=max_computes): + result1 = ar0.idxmax(dim="x", fill_value=np.NaN) assert_identical(result1, expected0) # keep_attrs - result2 = ar0.idxmax(dim="x", keep_attrs=True) + with raise_if_dask_computes(max_computes=max_computes): + result2 = ar0.idxmax(dim="x", keep_attrs=True) expected2 = expected0.copy() expected2.attrs = self.attrs assert_identical(result2, expected2) @@ -5122,11 +5426,13 @@ def test_idxmax(self, x, minindex, maxindex, nanindex): expected3.name = "x" expected3.attrs = {} - result3 = ar0.idxmax(dim="x", skipna=False) + with raise_if_dask_computes(max_computes=max_computes): + result3 = ar0.idxmax(dim="x", skipna=False) assert_identical(result3, expected3) # fill_value should be ignored with skipna=False - result4 = ar0.idxmax(dim="x", skipna=False, fill_value=-100j) + with raise_if_dask_computes(max_computes=max_computes): + result4 = ar0.idxmax(dim="x", skipna=False, fill_value=-100j) assert_identical(result4, expected3) # Float fill_value @@ -5138,7 +5444,8 @@ def test_idxmax(self, x, minindex, maxindex, nanindex): expected5 = xr.concat(expected5, dim="y") expected5.name = "x" - result5 = ar0.idxmax(dim="x", fill_value=-1.1) + with raise_if_dask_computes(max_computes=max_computes): + result5 = ar0.idxmax(dim="x", fill_value=-1.1) assert_identical(result5, expected5) # Integer fill_value @@ -5150,7 +5457,8 @@ def test_idxmax(self, x, minindex, maxindex, nanindex): expected6 = xr.concat(expected6, dim="y") expected6.name = "x" - result6 = ar0.idxmax(dim="x", fill_value=-1) + with raise_if_dask_computes(max_computes=max_computes): + result6 = ar0.idxmax(dim="x", fill_value=-1) assert_identical(result6, expected6) # Complex fill_value @@ -5162,9 +5470,774 @@ def test_idxmax(self, x, minindex, maxindex, nanindex): expected7 = xr.concat(expected7, dim="y") expected7.name = "x" - result7 = ar0.idxmax(dim="x", fill_value=-5j) + with 
raise_if_dask_computes(max_computes=max_computes): + result7 = ar0.idxmax(dim="x", fill_value=-5j) assert_identical(result7, expected7) + @pytest.mark.filterwarnings( + "ignore:Behaviour of argmin/argmax with neither dim nor :DeprecationWarning" + ) + def test_argmin_dim(self, x, minindex, maxindex, nanindex): + ar = xr.DataArray( + x, + dims=["y", "x"], + coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])}, + attrs=self.attrs, + ) + indarr = np.tile(np.arange(x.shape[1], dtype=np.intp), [x.shape[0], 1]) + indarr = xr.DataArray(indarr, dims=ar.dims, coords=ar.coords) + + if np.isnan(minindex).any(): + with pytest.raises(ValueError): + ar.argmin(dim="x") + return + + expected0 = [ + indarr.isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(minindex) + ] + expected0 = {"x": xr.concat(expected0, dim="y")} + + result0 = ar.argmin(dim=["x"]) + for key in expected0: + assert_identical(result0[key], expected0[key]) + + result1 = ar.argmin(dim=["x"], keep_attrs=True) + expected1 = deepcopy(expected0) + expected1["x"].attrs = self.attrs + for key in expected1: + assert_identical(result1[key], expected1[key]) + + minindex = [ + x if y is None or ar.dtype.kind == "O" else y + for x, y in zip(minindex, nanindex) + ] + expected2 = [ + indarr.isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(minindex) + ] + expected2 = {"x": xr.concat(expected2, dim="y")} + expected2["x"].attrs = {} + + result2 = ar.argmin(dim=["x"], skipna=False) + + for key in expected2: + assert_identical(result2[key], expected2[key]) + + result3 = ar.argmin(...) + min_xind = ar.isel(expected0).argmin() + expected3 = { + "y": DataArray(min_xind), + "x": DataArray(minindex[min_xind.item()]), + } + + for key in expected3: + assert_identical(result3[key], expected3[key]) + + @pytest.mark.filterwarnings( + "ignore:Behaviour of argmin/argmax with neither dim nor :DeprecationWarning" + ) + def test_argmax_dim(self, x, minindex, maxindex, nanindex): + ar = xr.DataArray( + x, + dims=["y", "x"], + coords={"x": np.arange(x.shape[1]) * 4, "y": 1 - np.arange(x.shape[0])}, + attrs=self.attrs, + ) + indarr = np.tile(np.arange(x.shape[1], dtype=np.intp), [x.shape[0], 1]) + indarr = xr.DataArray(indarr, dims=ar.dims, coords=ar.coords) + + if np.isnan(maxindex).any(): + with pytest.raises(ValueError): + ar.argmax(dim="x") + return + + expected0 = [ + indarr.isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(maxindex) + ] + expected0 = {"x": xr.concat(expected0, dim="y")} + + result0 = ar.argmax(dim=["x"]) + for key in expected0: + assert_identical(result0[key], expected0[key]) + + result1 = ar.argmax(dim=["x"], keep_attrs=True) + expected1 = deepcopy(expected0) + expected1["x"].attrs = self.attrs + for key in expected1: + assert_identical(result1[key], expected1[key]) + + maxindex = [ + x if y is None or ar.dtype.kind == "O" else y + for x, y in zip(maxindex, nanindex) + ] + expected2 = [ + indarr.isel(y=yi).isel(x=indi, drop=True) + for yi, indi in enumerate(maxindex) + ] + expected2 = {"x": xr.concat(expected2, dim="y")} + expected2["x"].attrs = {} + + result2 = ar.argmax(dim=["x"], skipna=False) + + for key in expected2: + assert_identical(result2[key], expected2[key]) + + result3 = ar.argmax(...) 
+ max_xind = ar.isel(expected0).argmax() + expected3 = { + "y": DataArray(max_xind), + "x": DataArray(maxindex[max_xind.item()]), + } + + for key in expected3: + assert_identical(result3[key], expected3[key]) + + +@pytest.mark.parametrize( + "x, minindices_x, minindices_y, minindices_z, minindices_xy, " + "minindices_xz, minindices_yz, minindices_xyz, maxindices_x, " + "maxindices_y, maxindices_z, maxindices_xy, maxindices_xz, maxindices_yz, " + "maxindices_xyz, nanindices_x, nanindices_y, nanindices_z, nanindices_xy, " + "nanindices_xz, nanindices_yz, nanindices_xyz", + [ + ( + np.array( + [ + [[0, 1, 2, 0], [-2, -4, 2, 0]], + [[1, 1, 1, 1], [1, 1, 1, 1]], + [[0, 0, -10, 5], [20, 0, 0, 0]], + ] + ), + {"x": np.array([[0, 2, 2, 0], [0, 0, 2, 0]])}, + {"y": np.array([[1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 1]])}, + {"z": np.array([[0, 1], [0, 0], [2, 1]])}, + {"x": np.array([0, 0, 2, 0]), "y": np.array([1, 1, 0, 0])}, + {"x": np.array([2, 0]), "z": np.array([2, 1])}, + {"y": np.array([1, 0, 0]), "z": np.array([1, 0, 2])}, + {"x": np.array(2), "y": np.array(0), "z": np.array(2)}, + {"x": np.array([[1, 0, 0, 2], [2, 1, 0, 1]])}, + {"y": np.array([[0, 0, 0, 0], [0, 0, 0, 0], [1, 0, 1, 0]])}, + {"z": np.array([[2, 2], [0, 0], [3, 0]])}, + {"x": np.array([2, 0, 0, 2]), "y": np.array([1, 0, 0, 0])}, + {"x": np.array([2, 2]), "z": np.array([3, 0])}, + {"y": np.array([0, 0, 1]), "z": np.array([2, 0, 0])}, + {"x": np.array(2), "y": np.array(1), "z": np.array(0)}, + {"x": np.array([[None, None, None, None], [None, None, None, None]])}, + { + "y": np.array( + [ + [None, None, None, None], + [None, None, None, None], + [None, None, None, None], + ] + ) + }, + {"z": np.array([[None, None], [None, None], [None, None]])}, + { + "x": np.array([None, None, None, None]), + "y": np.array([None, None, None, None]), + }, + {"x": np.array([None, None]), "z": np.array([None, None])}, + {"y": np.array([None, None, None]), "z": np.array([None, None, None])}, + {"x": np.array(None), "y": np.array(None), "z": np.array(None)}, + ), + ( + np.array( + [ + [[2.0, 1.0, 2.0, 0.0], [-2.0, -4.0, 2.0, 0.0]], + [[-4.0, np.NaN, 2.0, np.NaN], [-2.0, -4.0, 2.0, 0.0]], + [[np.NaN] * 4, [np.NaN] * 4], + ] + ), + {"x": np.array([[1, 0, 0, 0], [0, 0, 0, 0]])}, + { + "y": np.array( + [[1, 1, 0, 0], [0, 1, 0, 1], [np.NaN, np.NaN, np.NaN, np.NaN]] + ) + }, + {"z": np.array([[3, 1], [0, 1], [np.NaN, np.NaN]])}, + {"x": np.array([1, 0, 0, 0]), "y": np.array([0, 1, 0, 0])}, + {"x": np.array([1, 0]), "z": np.array([0, 1])}, + {"y": np.array([1, 0, np.NaN]), "z": np.array([1, 0, np.NaN])}, + {"x": np.array(0), "y": np.array(1), "z": np.array(1)}, + {"x": np.array([[0, 0, 0, 0], [0, 0, 0, 0]])}, + { + "y": np.array( + [[0, 0, 0, 0], [1, 1, 0, 1], [np.NaN, np.NaN, np.NaN, np.NaN]] + ) + }, + {"z": np.array([[0, 2], [2, 2], [np.NaN, np.NaN]])}, + {"x": np.array([0, 0, 0, 0]), "y": np.array([0, 0, 0, 0])}, + {"x": np.array([0, 0]), "z": np.array([2, 2])}, + {"y": np.array([0, 0, np.NaN]), "z": np.array([0, 2, np.NaN])}, + {"x": np.array(0), "y": np.array(0), "z": np.array(0)}, + {"x": np.array([[2, 1, 2, 1], [2, 2, 2, 2]])}, + { + "y": np.array( + [[None, None, None, None], [None, 0, None, 0], [0, 0, 0, 0]] + ) + }, + {"z": np.array([[None, None], [1, None], [0, 0]])}, + {"x": np.array([2, 1, 2, 1]), "y": np.array([0, 0, 0, 0])}, + {"x": np.array([1, 2]), "z": np.array([1, 0])}, + {"y": np.array([None, 0, 0]), "z": np.array([None, 1, 0])}, + {"x": np.array(1), "y": np.array(0), "z": np.array(1)}, + ), + ( + np.array( + [ + [[2.0, 1.0, 2.0, 0.0], 
[-2.0, -4.0, 2.0, 0.0]], + [[-4.0, np.NaN, 2.0, np.NaN], [-2.0, -4.0, 2.0, 0.0]], + [[np.NaN] * 4, [np.NaN] * 4], + ] + ).astype("object"), + {"x": np.array([[1, 0, 0, 0], [0, 0, 0, 0]])}, + { + "y": np.array( + [[1, 1, 0, 0], [0, 1, 0, 1], [np.NaN, np.NaN, np.NaN, np.NaN]] + ) + }, + {"z": np.array([[3, 1], [0, 1], [np.NaN, np.NaN]])}, + {"x": np.array([1, 0, 0, 0]), "y": np.array([0, 1, 0, 0])}, + {"x": np.array([1, 0]), "z": np.array([0, 1])}, + {"y": np.array([1, 0, np.NaN]), "z": np.array([1, 0, np.NaN])}, + {"x": np.array(0), "y": np.array(1), "z": np.array(1)}, + {"x": np.array([[0, 0, 0, 0], [0, 0, 0, 0]])}, + { + "y": np.array( + [[0, 0, 0, 0], [1, 1, 0, 1], [np.NaN, np.NaN, np.NaN, np.NaN]] + ) + }, + {"z": np.array([[0, 2], [2, 2], [np.NaN, np.NaN]])}, + {"x": np.array([0, 0, 0, 0]), "y": np.array([0, 0, 0, 0])}, + {"x": np.array([0, 0]), "z": np.array([2, 2])}, + {"y": np.array([0, 0, np.NaN]), "z": np.array([0, 2, np.NaN])}, + {"x": np.array(0), "y": np.array(0), "z": np.array(0)}, + {"x": np.array([[2, 1, 2, 1], [2, 2, 2, 2]])}, + { + "y": np.array( + [[None, None, None, None], [None, 0, None, 0], [0, 0, 0, 0]] + ) + }, + {"z": np.array([[None, None], [1, None], [0, 0]])}, + {"x": np.array([2, 1, 2, 1]), "y": np.array([0, 0, 0, 0])}, + {"x": np.array([1, 2]), "z": np.array([1, 0])}, + {"y": np.array([None, 0, 0]), "z": np.array([None, 1, 0])}, + {"x": np.array(1), "y": np.array(0), "z": np.array(1)}, + ), + ( + np.array( + [ + [["2015-12-31", "2020-01-02"], ["2020-01-01", "2016-01-01"]], + [["2020-01-02", "2020-01-02"], ["2020-01-02", "2020-01-02"]], + [["1900-01-01", "1-02-03"], ["1900-01-02", "1-02-03"]], + ], + dtype="datetime64[ns]", + ), + {"x": np.array([[2, 2], [2, 2]])}, + {"y": np.array([[0, 1], [0, 0], [0, 0]])}, + {"z": np.array([[0, 1], [0, 0], [1, 1]])}, + {"x": np.array([2, 2]), "y": np.array([0, 0])}, + {"x": np.array([2, 2]), "z": np.array([1, 1])}, + {"y": np.array([0, 0, 0]), "z": np.array([0, 0, 1])}, + {"x": np.array(2), "y": np.array(0), "z": np.array(1)}, + {"x": np.array([[1, 0], [1, 1]])}, + {"y": np.array([[1, 0], [0, 0], [1, 0]])}, + {"z": np.array([[1, 0], [0, 0], [0, 0]])}, + {"x": np.array([1, 0]), "y": np.array([0, 0])}, + {"x": np.array([0, 1]), "z": np.array([1, 0])}, + {"y": np.array([0, 0, 1]), "z": np.array([1, 0, 0])}, + {"x": np.array(0), "y": np.array(0), "z": np.array(1)}, + {"x": np.array([[None, None], [None, None]])}, + {"y": np.array([[None, None], [None, None], [None, None]])}, + {"z": np.array([[None, None], [None, None], [None, None]])}, + {"x": np.array([None, None]), "y": np.array([None, None])}, + {"x": np.array([None, None]), "z": np.array([None, None])}, + {"y": np.array([None, None, None]), "z": np.array([None, None, None])}, + {"x": np.array(None), "y": np.array(None), "z": np.array(None)}, + ), + ], +) +class TestReduce3D(TestReduce): + def test_argmin_dim( + self, + x, + minindices_x, + minindices_y, + minindices_z, + minindices_xy, + minindices_xz, + minindices_yz, + minindices_xyz, + maxindices_x, + maxindices_y, + maxindices_z, + maxindices_xy, + maxindices_xz, + maxindices_yz, + maxindices_xyz, + nanindices_x, + nanindices_y, + nanindices_z, + nanindices_xy, + nanindices_xz, + nanindices_yz, + nanindices_xyz, + ): + + ar = xr.DataArray( + x, + dims=["x", "y", "z"], + coords={ + "x": np.arange(x.shape[0]) * 4, + "y": 1 - np.arange(x.shape[1]), + "z": 2 + 3 * np.arange(x.shape[2]), + }, + attrs=self.attrs, + ) + xindarr = np.tile( + np.arange(x.shape[0], dtype=np.intp)[:, np.newaxis, np.newaxis], + [1, x.shape[1], 
x.shape[2]], + ) + xindarr = xr.DataArray(xindarr, dims=ar.dims, coords=ar.coords) + yindarr = np.tile( + np.arange(x.shape[1], dtype=np.intp)[np.newaxis, :, np.newaxis], + [x.shape[0], 1, x.shape[2]], + ) + yindarr = xr.DataArray(yindarr, dims=ar.dims, coords=ar.coords) + zindarr = np.tile( + np.arange(x.shape[2], dtype=np.intp)[np.newaxis, np.newaxis, :], + [x.shape[0], x.shape[1], 1], + ) + zindarr = xr.DataArray(zindarr, dims=ar.dims, coords=ar.coords) + + for inds in [ + minindices_x, + minindices_y, + minindices_z, + minindices_xy, + minindices_xz, + minindices_yz, + minindices_xyz, + ]: + if np.array([np.isnan(i) for i in inds.values()]).any(): + with pytest.raises(ValueError): + ar.argmin(dim=[d for d in inds]) + return + + result0 = ar.argmin(dim=["x"]) + expected0 = { + key: xr.DataArray(value, dims=("y", "z")) + for key, value in minindices_x.items() + } + for key in expected0: + assert_identical(result0[key].drop_vars(["y", "z"]), expected0[key]) + + result1 = ar.argmin(dim=["y"]) + expected1 = { + key: xr.DataArray(value, dims=("x", "z")) + for key, value in minindices_y.items() + } + for key in expected1: + assert_identical(result1[key].drop_vars(["x", "z"]), expected1[key]) + + result2 = ar.argmin(dim=["z"]) + expected2 = { + key: xr.DataArray(value, dims=("x", "y")) + for key, value in minindices_z.items() + } + for key in expected2: + assert_identical(result2[key].drop_vars(["x", "y"]), expected2[key]) + + result3 = ar.argmin(dim=("x", "y")) + expected3 = { + key: xr.DataArray(value, dims=("z")) for key, value in minindices_xy.items() + } + for key in expected3: + assert_identical(result3[key].drop_vars("z"), expected3[key]) + + result4 = ar.argmin(dim=("x", "z")) + expected4 = { + key: xr.DataArray(value, dims=("y")) for key, value in minindices_xz.items() + } + for key in expected4: + assert_identical(result4[key].drop_vars("y"), expected4[key]) + + result5 = ar.argmin(dim=("y", "z")) + expected5 = { + key: xr.DataArray(value, dims=("x")) for key, value in minindices_yz.items() + } + for key in expected5: + assert_identical(result5[key].drop_vars("x"), expected5[key]) + + result6 = ar.argmin(...) 
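
The `use_dask` variants of `test_idxmin`/`test_idxmax` above wrap every call in `raise_if_dask_computes`: the point being tested is that these reductions no longer trigger eager computation on dask-backed data. A short sketch of the usage, assuming dask is installed (the array here is a random placeholder):

```python
import numpy as np
import xarray as xr

da = xr.DataArray(
    np.random.rand(4, 3),
    dims=("y", "x"),
    coords={"x": [10, 20, 30]},
).chunk({})                  # requires dask

lazy = da.idxmax(dim="x")    # no computation is triggered by this call
lazy.compute()               # labels along "x" where the maximum occurs
```
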
+ expected6 = {key: xr.DataArray(value) for key, value in minindices_xyz.items()} + for key in expected6: + assert_identical(result6[key], expected6[key]) + + minindices_x = { + key: xr.where( + nanindices_x[key] == None, # noqa: E711 + minindices_x[key], + nanindices_x[key], + ) + for key in minindices_x + } + expected7 = { + key: xr.DataArray(value, dims=("y", "z")) + for key, value in minindices_x.items() + } + + result7 = ar.argmin(dim=["x"], skipna=False) + for key in expected7: + assert_identical(result7[key].drop_vars(["y", "z"]), expected7[key]) + + minindices_y = { + key: xr.where( + nanindices_y[key] == None, # noqa: E711 + minindices_y[key], + nanindices_y[key], + ) + for key in minindices_y + } + expected8 = { + key: xr.DataArray(value, dims=("x", "z")) + for key, value in minindices_y.items() + } + + result8 = ar.argmin(dim=["y"], skipna=False) + for key in expected8: + assert_identical(result8[key].drop_vars(["x", "z"]), expected8[key]) + + minindices_z = { + key: xr.where( + nanindices_z[key] == None, # noqa: E711 + minindices_z[key], + nanindices_z[key], + ) + for key in minindices_z + } + expected9 = { + key: xr.DataArray(value, dims=("x", "y")) + for key, value in minindices_z.items() + } + + result9 = ar.argmin(dim=["z"], skipna=False) + for key in expected9: + assert_identical(result9[key].drop_vars(["x", "y"]), expected9[key]) + + minindices_xy = { + key: xr.where( + nanindices_xy[key] == None, # noqa: E711 + minindices_xy[key], + nanindices_xy[key], + ) + for key in minindices_xy + } + expected10 = { + key: xr.DataArray(value, dims="z") for key, value in minindices_xy.items() + } + + result10 = ar.argmin(dim=("x", "y"), skipna=False) + for key in expected10: + assert_identical(result10[key].drop_vars("z"), expected10[key]) + + minindices_xz = { + key: xr.where( + nanindices_xz[key] == None, # noqa: E711 + minindices_xz[key], + nanindices_xz[key], + ) + for key in minindices_xz + } + expected11 = { + key: xr.DataArray(value, dims="y") for key, value in minindices_xz.items() + } + + result11 = ar.argmin(dim=("x", "z"), skipna=False) + for key in expected11: + assert_identical(result11[key].drop_vars("y"), expected11[key]) + + minindices_yz = { + key: xr.where( + nanindices_yz[key] == None, # noqa: E711 + minindices_yz[key], + nanindices_yz[key], + ) + for key in minindices_yz + } + expected12 = { + key: xr.DataArray(value, dims="x") for key, value in minindices_yz.items() + } + + result12 = ar.argmin(dim=("y", "z"), skipna=False) + for key in expected12: + assert_identical(result12[key].drop_vars("x"), expected12[key]) + + minindices_xyz = { + key: xr.where( + nanindices_xyz[key] == None, # noqa: E711 + minindices_xyz[key], + nanindices_xyz[key], + ) + for key in minindices_xyz + } + expected13 = {key: xr.DataArray(value) for key, value in minindices_xyz.items()} + + result13 = ar.argmin(..., skipna=False) + for key in expected13: + assert_identical(result13[key], expected13[key]) + + def test_argmax_dim( + self, + x, + minindices_x, + minindices_y, + minindices_z, + minindices_xy, + minindices_xz, + minindices_yz, + minindices_xyz, + maxindices_x, + maxindices_y, + maxindices_z, + maxindices_xy, + maxindices_xz, + maxindices_yz, + maxindices_xyz, + nanindices_x, + nanindices_y, + nanindices_z, + nanindices_xy, + nanindices_xz, + nanindices_yz, + nanindices_xyz, + ): + + ar = xr.DataArray( + x, + dims=["x", "y", "z"], + coords={ + "x": np.arange(x.shape[0]) * 4, + "y": 1 - np.arange(x.shape[1]), + "z": 2 + 3 * np.arange(x.shape[2]), + }, + attrs=self.attrs, + ) + 
xindarr = np.tile( + np.arange(x.shape[0], dtype=np.intp)[:, np.newaxis, np.newaxis], + [1, x.shape[1], x.shape[2]], + ) + xindarr = xr.DataArray(xindarr, dims=ar.dims, coords=ar.coords) + yindarr = np.tile( + np.arange(x.shape[1], dtype=np.intp)[np.newaxis, :, np.newaxis], + [x.shape[0], 1, x.shape[2]], + ) + yindarr = xr.DataArray(yindarr, dims=ar.dims, coords=ar.coords) + zindarr = np.tile( + np.arange(x.shape[2], dtype=np.intp)[np.newaxis, np.newaxis, :], + [x.shape[0], x.shape[1], 1], + ) + zindarr = xr.DataArray(zindarr, dims=ar.dims, coords=ar.coords) + + for inds in [ + maxindices_x, + maxindices_y, + maxindices_z, + maxindices_xy, + maxindices_xz, + maxindices_yz, + maxindices_xyz, + ]: + if np.array([np.isnan(i) for i in inds.values()]).any(): + with pytest.raises(ValueError): + ar.argmax(dim=[d for d in inds]) + return + + result0 = ar.argmax(dim=["x"]) + expected0 = { + key: xr.DataArray(value, dims=("y", "z")) + for key, value in maxindices_x.items() + } + for key in expected0: + assert_identical(result0[key].drop_vars(["y", "z"]), expected0[key]) + + result1 = ar.argmax(dim=["y"]) + expected1 = { + key: xr.DataArray(value, dims=("x", "z")) + for key, value in maxindices_y.items() + } + for key in expected1: + assert_identical(result1[key].drop_vars(["x", "z"]), expected1[key]) + + result2 = ar.argmax(dim=["z"]) + expected2 = { + key: xr.DataArray(value, dims=("x", "y")) + for key, value in maxindices_z.items() + } + for key in expected2: + assert_identical(result2[key].drop_vars(["x", "y"]), expected2[key]) + + result3 = ar.argmax(dim=("x", "y")) + expected3 = { + key: xr.DataArray(value, dims=("z")) for key, value in maxindices_xy.items() + } + for key in expected3: + assert_identical(result3[key].drop_vars("z"), expected3[key]) + + result4 = ar.argmax(dim=("x", "z")) + expected4 = { + key: xr.DataArray(value, dims=("y")) for key, value in maxindices_xz.items() + } + for key in expected4: + assert_identical(result4[key].drop_vars("y"), expected4[key]) + + result5 = ar.argmax(dim=("y", "z")) + expected5 = { + key: xr.DataArray(value, dims=("x")) for key, value in maxindices_yz.items() + } + for key in expected5: + assert_identical(result5[key].drop_vars("x"), expected5[key]) + + result6 = ar.argmax(...) 
+ expected6 = {key: xr.DataArray(value) for key, value in maxindices_xyz.items()} + for key in expected6: + assert_identical(result6[key], expected6[key]) + + maxindices_x = { + key: xr.where( + nanindices_x[key] == None, # noqa: E711 + maxindices_x[key], + nanindices_x[key], + ) + for key in maxindices_x + } + expected7 = { + key: xr.DataArray(value, dims=("y", "z")) + for key, value in maxindices_x.items() + } + + result7 = ar.argmax(dim=["x"], skipna=False) + for key in expected7: + assert_identical(result7[key].drop_vars(["y", "z"]), expected7[key]) + + maxindices_y = { + key: xr.where( + nanindices_y[key] == None, # noqa: E711 + maxindices_y[key], + nanindices_y[key], + ) + for key in maxindices_y + } + expected8 = { + key: xr.DataArray(value, dims=("x", "z")) + for key, value in maxindices_y.items() + } + + result8 = ar.argmax(dim=["y"], skipna=False) + for key in expected8: + assert_identical(result8[key].drop_vars(["x", "z"]), expected8[key]) + + maxindices_z = { + key: xr.where( + nanindices_z[key] == None, # noqa: E711 + maxindices_z[key], + nanindices_z[key], + ) + for key in maxindices_z + } + expected9 = { + key: xr.DataArray(value, dims=("x", "y")) + for key, value in maxindices_z.items() + } + + result9 = ar.argmax(dim=["z"], skipna=False) + for key in expected9: + assert_identical(result9[key].drop_vars(["x", "y"]), expected9[key]) + + maxindices_xy = { + key: xr.where( + nanindices_xy[key] == None, # noqa: E711 + maxindices_xy[key], + nanindices_xy[key], + ) + for key in maxindices_xy + } + expected10 = { + key: xr.DataArray(value, dims="z") for key, value in maxindices_xy.items() + } + + result10 = ar.argmax(dim=("x", "y"), skipna=False) + for key in expected10: + assert_identical(result10[key].drop_vars("z"), expected10[key]) + + maxindices_xz = { + key: xr.where( + nanindices_xz[key] == None, # noqa: E711 + maxindices_xz[key], + nanindices_xz[key], + ) + for key in maxindices_xz + } + expected11 = { + key: xr.DataArray(value, dims="y") for key, value in maxindices_xz.items() + } + + result11 = ar.argmax(dim=("x", "z"), skipna=False) + for key in expected11: + assert_identical(result11[key].drop_vars("y"), expected11[key]) + + maxindices_yz = { + key: xr.where( + nanindices_yz[key] == None, # noqa: E711 + maxindices_yz[key], + nanindices_yz[key], + ) + for key in maxindices_yz + } + expected12 = { + key: xr.DataArray(value, dims="x") for key, value in maxindices_yz.items() + } + + result12 = ar.argmax(dim=("y", "z"), skipna=False) + for key in expected12: + assert_identical(result12[key].drop_vars("x"), expected12[key]) + + maxindices_xyz = { + key: xr.where( + nanindices_xyz[key] == None, # noqa: E711 + maxindices_xyz[key], + nanindices_xyz[key], + ) + for key in maxindices_xyz + } + expected13 = {key: xr.DataArray(value) for key, value in maxindices_xyz.items()} + + result13 = ar.argmax(..., skipna=False) + for key in expected13: + assert_identical(result13[key], expected13[key]) + + +class TestReduceND(TestReduce): + @pytest.mark.parametrize("op", ["idxmin", "idxmax"]) + @pytest.mark.parametrize("ndim", [3, 5]) + def test_idxminmax_dask(self, op, ndim): + if not has_dask: + pytest.skip("requires dask") + + ar0_raw = xr.DataArray( + np.random.random_sample(size=[10] * ndim), + dims=[i for i in "abcdefghij"[: ndim - 1]] + ["x"], + coords={"x": np.arange(10)}, + attrs=self.attrs, + ) + + ar0_dsk = ar0_raw.chunk({}) + # Assert idx is the same with dask and without + assert_equal(getattr(ar0_dsk, op)(dim="x"), getattr(ar0_raw, op)(dim="x")) + @pytest.fixture(params=[1]) 
def da(request): @@ -5199,7 +6272,6 @@ def da_dask(seed=123): @pytest.mark.parametrize("da", ("repeating_ints",), indirect=True) def test_isin(da): - expected = DataArray( np.asarray([[0, 0, 0], [1, 0, 0]]), dims=list("yx"), @@ -5218,13 +6290,39 @@ def test_isin(da): assert_equal(result, expected) +def test_coarsen_keep_attrs(): + _attrs = {"units": "test", "long_name": "testing"} + + da = xr.DataArray( + np.linspace(0, 364, num=364), + dims="time", + coords={"time": pd.date_range("15/12/1999", periods=364)}, + attrs=_attrs, + ) + + da2 = da.copy(deep=True) + + # Test dropped attrs + dat = da.coarsen(time=3, boundary="trim").mean() + assert dat.attrs == {} + + # Test kept attrs using dataset keyword + dat = da.coarsen(time=3, boundary="trim", keep_attrs=True).mean() + assert dat.attrs == _attrs + + # Test kept attrs using global option + with xr.set_options(keep_attrs=True): + dat = da.coarsen(time=3, boundary="trim").mean() + assert dat.attrs == _attrs + + # Test kept attrs in original object + xr.testing.assert_identical(da, da2) + + @pytest.mark.parametrize("da", (1, 2), indirect=True) def test_rolling_iter(da): - rolling_obj = da.rolling(time=7) - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "Mean of empty slice") - rolling_obj_mean = rolling_obj.mean() + rolling_obj_mean = rolling_obj.mean() assert len(rolling_obj.window_labels) == len(da["time"]) assert_identical(rolling_obj.window_labels, da["time"]) @@ -5232,10 +6330,8 @@ def test_rolling_iter(da): for i, (label, window_da) in enumerate(rolling_obj): assert label == da["time"].isel(time=i) - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "Mean of empty slice") - actual = rolling_obj_mean.isel(time=i) - expected = window_da.mean("time") + actual = rolling_obj_mean.isel(time=i) + expected = window_da.mean("time") # TODO add assert_allclose_with_nan, which compares nan position # as well as the closeness of the values. 
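
`test_coarsen_keep_attrs`, added above, pins down how attribute handling on `coarsen` interacts with the global `keep_attrs` option. The same pattern outside the test, with placeholder dates and attrs:

```python
import numpy as np
import pandas as pd
import xarray as xr

da = xr.DataArray(
    np.arange(12.0),
    dims="time",
    coords={"time": pd.date_range("2000-01-01", periods=12)},
    attrs={"units": "K"},
)

assert da.coarsen(time=3).mean().attrs == {}                      # dropped by default
assert da.coarsen(time=3, keep_attrs=True).mean().attrs == {"units": "K"}

with xr.set_options(keep_attrs=True):                             # or keep them globally
    assert da.coarsen(time=3).mean().attrs == {"units": "K"}
```
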
@@ -5247,6 +6343,16 @@ def test_rolling_iter(da): ) +@pytest.mark.parametrize("da", (1,), indirect=True) +def test_rolling_repr(da): + rolling_obj = da.rolling(time=7) + assert repr(rolling_obj) == "DataArrayRolling [time->7]" + rolling_obj = da.rolling(time=7, center=True) + assert repr(rolling_obj) == "DataArrayRolling [time->7(center)]" + rolling_obj = da.rolling(time=7, x=3, center=True) + assert repr(rolling_obj) == "DataArrayRolling [time->7(center),x->3(center)]" + + def test_rolling_doc(da): rolling_obj = da.rolling(time=7) @@ -5260,10 +6366,9 @@ def test_rolling_properties(da): assert rolling_obj.obj.get_axis_num("time") == 1 # catching invalid args - with pytest.raises(ValueError, match="exactly one dim/window should"): - da.rolling(time=7, x=2) with pytest.raises(ValueError, match="window must be > 0"): da.rolling(time=-2) + with pytest.raises(ValueError, match="min_periods must be greater than zero"): da.rolling(time=2, min_periods=0) @@ -5284,7 +6389,7 @@ def test_rolling_wrapped_bottleneck(da, name, center, min_periods): ) assert_array_equal(actual.values, expected) - with pytest.warns(DeprecationWarning, match="Reductions will be applied"): + with pytest.warns(DeprecationWarning, match="Reductions are applied"): getattr(rolling_obj, name)(dim="time") # Test center @@ -5303,7 +6408,7 @@ def test_rolling_wrapped_dask(da_dask, name, center, min_periods, window): rolling_obj = da_dask.rolling(time=window, min_periods=min_periods, center=center) actual = getattr(rolling_obj, name)().load() if name != "count": - with pytest.warns(DeprecationWarning, match="Reductions will be applied"): + with pytest.warns(DeprecationWarning, match="Reductions are applied"): getattr(rolling_obj, name)(dim="time") # numpy version rolling_obj = da_dask.load().rolling( @@ -5390,7 +6495,6 @@ def test_rolling_construct(center, window): @pytest.mark.parametrize("window", (1, 2, 3, 4)) @pytest.mark.parametrize("name", ("sum", "mean", "std", "max")) def test_rolling_reduce(da, center, min_periods, window, name): - if min_periods is not None and window < min_periods: min_periods = window @@ -5429,7 +6533,6 @@ def test_rolling_reduce_nonnumeric(center, min_periods, window, name): def test_rolling_count_correct(): - da = DataArray([0, np.nan, 1, 2, np.nan, 3, 4, 5, np.nan, 6, 7], dims="time") kwargs = [ @@ -5466,12 +6569,142 @@ def test_rolling_count_correct(): assert_equal(result, expected) +@pytest.mark.parametrize("da", (1,), indirect=True) +@pytest.mark.parametrize("center", (True, False)) +@pytest.mark.parametrize("min_periods", (None, 1)) +@pytest.mark.parametrize("name", ("sum", "mean", "max")) +def test_ndrolling_reduce(da, center, min_periods, name): + rolling_obj = da.rolling(time=3, x=2, center=center, min_periods=min_periods) + + actual = getattr(rolling_obj, name)() + expected = getattr( + getattr( + da.rolling(time=3, center=center, min_periods=min_periods), name + )().rolling(x=2, center=center, min_periods=min_periods), + name, + )() + + assert_allclose(actual, expected) + assert actual.dims == expected.dims + + +@pytest.mark.parametrize("center", (True, False, (True, False))) +@pytest.mark.parametrize("fill_value", (np.nan, 0.0)) +def test_ndrolling_construct(center, fill_value): + da = DataArray( + np.arange(5 * 6 * 7).reshape(5, 6, 7).astype(float), + dims=["x", "y", "z"], + coords={"x": ["a", "b", "c", "d", "e"], "y": np.arange(6)}, + ) + actual = da.rolling(x=3, z=2, center=center).construct( + x="x1", z="z1", fill_value=fill_value + ) + if not isinstance(center, tuple): + center = 
(center, center) + expected = ( + da.rolling(x=3, center=center[0]) + .construct(x="x1", fill_value=fill_value) + .rolling(z=2, center=center[1]) + .construct(z="z1", fill_value=fill_value) + ) + assert_allclose(actual, expected) + + +@pytest.mark.parametrize( + "funcname, argument", + [ + ("reduce", (np.mean,)), + ("mean", ()), + ("construct", ("window_dim",)), + ("count", ()), + ], +) +def test_rolling_keep_attrs(funcname, argument): + attrs_da = {"da_attr": "test"} + + data = np.linspace(10, 15, 100) + coords = np.linspace(1, 10, 100) + + da = DataArray( + data, dims=("coord"), coords={"coord": coords}, attrs=attrs_da, name="name" + ) + + # attrs are now kept per default + func = getattr(da.rolling(dim={"coord": 5}), funcname) + result = func(*argument) + assert result.attrs == attrs_da + assert result.name == "name" + + # discard attrs + func = getattr(da.rolling(dim={"coord": 5}), funcname) + result = func(*argument, keep_attrs=False) + assert result.attrs == {} + assert result.name == "name" + + # test discard attrs using global option + func = getattr(da.rolling(dim={"coord": 5}), funcname) + with set_options(keep_attrs=False): + result = func(*argument) + assert result.attrs == {} + assert result.name == "name" + + # keyword takes precedence over global option + func = getattr(da.rolling(dim={"coord": 5}), funcname) + with set_options(keep_attrs=False): + result = func(*argument, keep_attrs=True) + assert result.attrs == attrs_da + assert result.name == "name" + + func = getattr(da.rolling(dim={"coord": 5}), funcname) + with set_options(keep_attrs=True): + result = func(*argument, keep_attrs=False) + assert result.attrs == {} + assert result.name == "name" + + +def test_rolling_keep_attrs_deprecated(): + attrs_da = {"da_attr": "test"} + + data = np.linspace(10, 15, 100) + coords = np.linspace(1, 10, 100) + + da = DataArray( + data, + dims=("coord"), + coords={"coord": coords}, + attrs=attrs_da, + ) + + # deprecated option + with pytest.warns( + FutureWarning, match="Passing ``keep_attrs`` to ``rolling`` is deprecated" + ): + result = da.rolling(dim={"coord": 5}, keep_attrs=False).construct("window_dim") + + assert result.attrs == {} + + # the keep_attrs in the reduction function takes precedence + with pytest.warns( + FutureWarning, match="Passing ``keep_attrs`` to ``rolling`` is deprecated" + ): + result = da.rolling(dim={"coord": 5}, keep_attrs=True).construct( + "window_dim", keep_attrs=False + ) + + assert result.attrs == {} + + def test_raise_no_warning_for_nan_in_binary_ops(): with pytest.warns(None) as record: xr.DataArray([1, 2, np.NaN]) > 0 assert len(record) == 0 +@pytest.mark.filterwarnings("error") +def test_no_warning_for_all_nan(): + _ = xr.DataArray([np.NaN, np.NaN]).mean() + + def test_name_in_masking(): name = "RingoStarr" da = xr.DataArray(range(10), coords=[("x", range(10))], name=name) @@ -5484,8 +6717,8 @@ def test_name_in_masking(): class TestIrisConversion: @requires_iris def test_to_and_from_iris(self): - import iris import cf_units # iris requirement + import iris # to iris coord_dict = {} @@ -5555,9 +6788,9 @@ def test_to_and_from_iris(self): @requires_iris @requires_dask def test_to_and_from_iris_dask(self): + import cf_units # iris requirement import dask.array as da import iris - import cf_units # iris requirement coord_dict = {} coord_dict["distance"] = ("distance", [-2, 2], {"units": "meters"}) @@ -5690,8 +6923,8 @@ def test_da_name_from_cube(self, std_name, long_name, var_name, name, attrs): ], ) def test_da_coord_name_from_cube(self, std_name, 
long_name, var_name, name, attrs): - from iris.cube import Cube from iris.coords import DimCoord + from iris.cube import Cube latitude = DimCoord( [-90, 0, 90], standard_name=std_name, var_name=var_name, long_name=long_name @@ -5704,8 +6937,8 @@ def test_da_coord_name_from_cube(self, std_name, long_name, var_name, name, attr @requires_iris def test_prevent_duplicate_coord_names(self): - from iris.cube import Cube from iris.coords import DimCoord + from iris.cube import Cube # Iris enforces unique coordinate names. Because we use a different # name resolution order a valid iris Cube with coords that have the @@ -5726,8 +6959,8 @@ def test_prevent_duplicate_coord_names(self): [["IA", "IL", "IN"], [0, 2, 1]], # non-numeric values # non-monotonic values ) def test_fallback_to_iris_AuxCoord(self, coord_values): - from iris.cube import Cube from iris.coords import AuxCoord + from iris.cube import Cube data = [0, 0, 0] da = xr.DataArray(data, coords=[coord_values], dims=["space"]) @@ -5761,6 +6994,34 @@ def test_rolling_exp(da, dim, window_type, window): assert_allclose(expected.variable, result.variable) +@requires_numbagg +def test_rolling_exp_keep_attrs(da): + attrs = {"attrs": "da"} + da.attrs = attrs + + # attrs are kept per default + result = da.rolling_exp(time=10).mean() + assert result.attrs == attrs + + # discard attrs + result = da.rolling_exp(time=10).mean(keep_attrs=False) + assert result.attrs == {} + + # test discard attrs using global option + with set_options(keep_attrs=False): + result = da.rolling_exp(time=10).mean() + assert result.attrs == {} + + # keyword takes precedence over global option + with set_options(keep_attrs=False): + result = da.rolling_exp(time=10).mean(keep_attrs=True) + assert result.attrs == attrs + + with set_options(keep_attrs=True): + result = da.rolling_exp(time=10).mean(keep_attrs=False) + assert result.attrs == {} + + def test_no_dict(): d = DataArray() with pytest.raises(AttributeError): @@ -5816,3 +7077,9 @@ def test_delete_coords(): assert a1.dims == ("y", "x") assert set(a0.coords.keys()) == {"x", "y"} assert set(a1.coords.keys()) == {"x"} + + +def test_deepcopy_obj_array(): + x0 = DataArray(np.array([object()])) + x1 = deepcopy(x0) + assert x0.values[0] is not x1.values[0] diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index cc78feb40db..d385ec3a48b 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -32,7 +32,6 @@ from . 
import ( InaccessibleArray, - LooseVersion, UnexpectedDataAccess, assert_allclose, assert_array_equal, @@ -55,6 +54,11 @@ except ImportError: pass +pytestmark = [ + pytest.mark.filterwarnings("error:Mean of empty slice"), + pytest.mark.filterwarnings("error:All-NaN (slice|axis) encountered"), +] + def create_test_data(seed=None): rs = np.random.RandomState(seed) @@ -99,8 +103,8 @@ def create_append_test_data(seed=None): datetime_var_to_append = np.array( ["2019-01-04", "2019-01-05"], dtype="datetime64[s]" ) - bool_var = np.array([True, False, True], dtype=np.bool) - bool_var_to_append = np.array([False, True], dtype=np.bool) + bool_var = np.array([True, False, True], dtype=bool) + bool_var_to_append = np.array([False, True], dtype=bool) ds = xr.Dataset( data_vars={ @@ -496,16 +500,11 @@ def test_constructor_pandas_single(self): DataArray(np.random.rand(4, 3), dims=["a", "b"]), # df ] - if LooseVersion(pd.__version__) < "0.25.0": - das.append(DataArray(np.random.rand(4, 3, 2), dims=["a", "b", "c"])) - - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", r"\W*Panel is deprecated") - for a in das: - pandas_obj = a.to_pandas() - ds_based_on_pandas = Dataset(pandas_obj) - for dim in ds_based_on_pandas.data_vars: - assert_array_equal(ds_based_on_pandas[dim], pandas_obj[dim]) + for a in das: + pandas_obj = a.to_pandas() + ds_based_on_pandas = Dataset(pandas_obj) + for dim in ds_based_on_pandas.data_vars: + assert_array_equal(ds_based_on_pandas[dim], pandas_obj[dim]) def test_constructor_compat(self): data = {"x": DataArray(0, coords={"y": 1}), "y": ("z", [1, 1, 1])} @@ -1025,14 +1024,14 @@ def test_isel(self): data.isel(not_a_dim=slice(0, 2)) with raises_regex( ValueError, - r"dimensions {'not_a_dim'} do not exist. Expected " + r"Dimensions {'not_a_dim'} do not exist. Expected " r"one or more of " r"[\w\W]*'time'[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*", ): data.isel(not_a_dim=slice(0, 2)) with pytest.warns( UserWarning, - match=r"dimensions {'not_a_dim'} do not exist. " + match=r"Dimensions {'not_a_dim'} do not exist. 
" r"Expected one or more of " r"[\w\W]*'time'[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*'dim\d'[\w\W]*", ): @@ -1419,7 +1418,7 @@ def test_sel_dataarray_mindex(self): with raises_regex( ValueError, - "Vectorized selection is " "not available along MultiIndex variable:" " x", + "Vectorized selection is not available along MultiIndex variable: x", ): mds.sel( x=xr.DataArray( @@ -1828,6 +1827,25 @@ def test_reindex(self): actual = data.reindex(dim2=data["dim2"][:5:-1]) assert_identical(actual, expected) + # multiple fill values + expected = data.reindex(dim2=[0.1, 2.1, 3.1, 4.1]).assign( + var1=lambda ds: ds.var1.copy(data=[[-10, -10, -10, -10]] * len(ds.dim1)), + var2=lambda ds: ds.var2.copy(data=[[-20, -20, -20, -20]] * len(ds.dim1)), + ) + actual = data.reindex( + dim2=[0.1, 2.1, 3.1, 4.1], fill_value={"var1": -10, "var2": -20} + ) + assert_identical(actual, expected) + # use the default value + expected = data.reindex(dim2=[0.1, 2.1, 3.1, 4.1]).assign( + var1=lambda ds: ds.var1.copy(data=[[-10, -10, -10, -10]] * len(ds.dim1)), + var2=lambda ds: ds.var2.copy( + data=[[np.nan, np.nan, np.nan, np.nan]] * len(ds.dim1) + ), + ) + actual = data.reindex(dim2=[0.1, 2.1, 3.1, 4.1], fill_value={"var1": -10}) + assert_identical(actual, expected) + # regression test for #279 expected = Dataset({"x": ("time", np.random.randn(5))}, {"time": range(5)}) time2 = DataArray(np.arange(5), dims="time2") @@ -1884,32 +1902,64 @@ def test_reindex_method(self): actual = ds.reindex_like(alt, method="pad") assert_identical(expected, actual) - @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0]) + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0, {"x": 2, "z": 1}]) def test_reindex_fill_value(self, fill_value): - ds = Dataset({"x": ("y", [10, 20]), "y": [0, 1]}) + ds = Dataset({"x": ("y", [10, 20]), "z": ("y", [-20, -10]), "y": [0, 1]}) y = [0, 1, 2] actual = ds.reindex(y=y, fill_value=fill_value) if fill_value == dtypes.NA: # if we supply the default, we expect the missing value for a # float array - fill_value = np.nan - expected = Dataset({"x": ("y", [10, 20, fill_value]), "y": y}) + fill_value_x = fill_value_z = np.nan + elif isinstance(fill_value, dict): + fill_value_x = fill_value["x"] + fill_value_z = fill_value["z"] + else: + fill_value_x = fill_value_z = fill_value + expected = Dataset( + { + "x": ("y", [10, 20, fill_value_x]), + "z": ("y", [-20, -10, fill_value_z]), + "y": y, + } + ) assert_identical(expected, actual) - @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0]) + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0, {"x": 2, "z": 1}]) def test_reindex_like_fill_value(self, fill_value): - ds = Dataset({"x": ("y", [10, 20]), "y": [0, 1]}) + ds = Dataset({"x": ("y", [10, 20]), "z": ("y", [-20, -10]), "y": [0, 1]}) y = [0, 1, 2] alt = Dataset({"y": y}) actual = ds.reindex_like(alt, fill_value=fill_value) if fill_value == dtypes.NA: # if we supply the default, we expect the missing value for a # float array - fill_value = np.nan - expected = Dataset({"x": ("y", [10, 20, fill_value]), "y": y}) + fill_value_x = fill_value_z = np.nan + elif isinstance(fill_value, dict): + fill_value_x = fill_value["x"] + fill_value_z = fill_value["z"] + else: + fill_value_x = fill_value_z = fill_value + expected = Dataset( + { + "x": ("y", [10, 20, fill_value_x]), + "z": ("y", [-20, -10, fill_value_z]), + "y": y, + } + ) assert_identical(expected, actual) - @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0]) + @pytest.mark.parametrize("dtype", [str, bytes]) + def test_reindex_str_dtype(self, 
dtype): + data = Dataset({"data": ("x", [1, 2]), "x": np.array(["a", "b"], dtype=dtype)}) + + actual = data.reindex(x=data.x) + expected = data + + assert_identical(expected, actual) + assert actual.x.dtype == expected.x.dtype + + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0, {"foo": 2, "bar": 1}]) def test_align_fill_value(self, fill_value): x = Dataset({"foo": DataArray([1, 2], dims=["x"], coords={"x": [1, 2]})}) y = Dataset({"bar": DataArray([1, 2], dims=["x"], coords={"x": [1, 3]})}) @@ -1917,13 +1967,26 @@ def test_align_fill_value(self, fill_value): if fill_value == dtypes.NA: # if we supply the default, we expect the missing value for a # float array - fill_value = np.nan + fill_value_foo = fill_value_bar = np.nan + elif isinstance(fill_value, dict): + fill_value_foo = fill_value["foo"] + fill_value_bar = fill_value["bar"] + else: + fill_value_foo = fill_value_bar = fill_value expected_x2 = Dataset( - {"foo": DataArray([1, 2, fill_value], dims=["x"], coords={"x": [1, 2, 3]})} + { + "foo": DataArray( + [1, 2, fill_value_foo], dims=["x"], coords={"x": [1, 2, 3]} + ) + } ) expected_y2 = Dataset( - {"bar": DataArray([1, fill_value, 2], dims=["x"], coords={"x": [1, 2, 3]})} + { + "bar": DataArray( + [1, fill_value_bar, 2], dims=["x"], coords={"x": [1, 2, 3]} + ) + } ) assert_identical(expected_x2, x2) assert_identical(expected_y2, y2) @@ -2074,12 +2137,29 @@ def test_align_indexes(self): def test_align_non_unique(self): x = Dataset({"foo": ("x", [3, 4, 5]), "x": [0, 0, 1]}) x1, x2 = align(x, x) - assert x1.identical(x) and x2.identical(x) + assert_identical(x1, x) + assert_identical(x2, x) y = Dataset({"bar": ("x", [6, 7]), "x": [0, 1]}) with raises_regex(ValueError, "cannot reindex or align"): align(x, y) + def test_align_str_dtype(self): + + a = Dataset({"foo": ("x", [0, 1]), "x": ["a", "b"]}) + b = Dataset({"foo": ("x", [1, 2]), "x": ["b", "c"]}) + + expected_a = Dataset({"foo": ("x", [0, 1, np.NaN]), "x": ["a", "b", "c"]}) + expected_b = Dataset({"foo": ("x", [np.NaN, 1, 2]), "x": ["a", "b", "c"]}) + + actual_a, actual_b = xr.align(a, b, join="outer") + + assert_identical(expected_a, actual_a) + assert expected_a.x.dtype == actual_a.x.dtype + + assert_identical(expected_b, actual_b) + assert expected_b.x.dtype == actual_b.x.dtype + def test_broadcast(self): ds = Dataset( {"foo": 0, "bar": ("x", [1]), "baz": ("y", [2, 3])}, {"c": ("x", [4])} @@ -2521,12 +2601,6 @@ def test_rename_same_name(self): renamed = data.rename(newnames) assert_identical(renamed, data) - def test_rename_inplace(self): - times = pd.date_range("2000-01-01", periods=3) - data = Dataset({"z": ("x", [2, 3, 4]), "t": ("t", times)}) - with pytest.raises(TypeError): - data.rename({"x": "y"}, inplace=True) - def test_rename_dims(self): original = Dataset({"x": ("x", [0, 1, 2]), "y": ("x", [10, 11, 12]), "z": 42}) expected = Dataset( @@ -2842,10 +2916,6 @@ def test_set_index(self): obj = ds.set_index(x=mindex.names) assert_identical(obj, expected) - with pytest.raises(TypeError): - ds.set_index(x=mindex.names, inplace=True) - assert_identical(ds, expected) - # ensure set_index with no existing index and a single data var given # doesn't return multi-index ds = Dataset(data_vars={"x_var": ("x", [0, 1, 2])}) @@ -2867,8 +2937,12 @@ def test_reset_index(self): obj = ds.reset_index("x") assert_identical(obj, expected) - with pytest.raises(TypeError): - ds.reset_index("x", inplace=True) + def test_reset_index_keep_attrs(self): + coord_1 = DataArray([1, 2], dims=["coord_1"], attrs={"attrs": True}) + ds = 
Dataset({}, {"coord_1": coord_1}) + expected = Dataset({}, {"coord_1_": coord_1}) + obj = ds.reset_index("coord_1") + assert_identical(expected, obj) def test_reorder_levels(self): ds = create_test_multiindex() @@ -2879,9 +2953,6 @@ def test_reorder_levels(self): reindexed = ds.reorder_levels(x=["level_2", "level_1"]) assert_identical(reindexed, expected) - with pytest.raises(TypeError): - ds.reorder_levels(x=["level_2", "level_1"], inplace=True) - ds = Dataset({}, coords={"x": [1, 2]}) with raises_regex(ValueError, "has no MultiIndex"): ds.reorder_levels(x=["level_1", "level_2"]) @@ -2935,20 +3006,24 @@ def test_unstack_errors(self): def test_unstack_fill_value(self): ds = xr.Dataset( - {"var": (("x",), np.arange(6))}, + {"var": (("x",), np.arange(6)), "other_var": (("x",), np.arange(3, 9))}, coords={"x": [0, 1, 2] * 2, "y": (("x",), ["a"] * 3 + ["b"] * 3)}, ) # make ds incomplete ds = ds.isel(x=[0, 2, 3, 4]).set_index(index=["x", "y"]) # test fill_value actual = ds.unstack("index", fill_value=-1) - expected = ds.unstack("index").fillna(-1).astype(np.int) - assert actual["var"].dtype == np.int + expected = ds.unstack("index").fillna(-1).astype(int) + assert actual["var"].dtype == int assert_equal(actual, expected) actual = ds["var"].unstack("index", fill_value=-1) - expected = ds["var"].unstack("index").fillna(-1).astype(np.int) - assert actual.equals(expected) + expected = ds["var"].unstack("index").fillna(-1).astype(int) + assert_equal(actual, expected) + + actual = ds.unstack("index", fill_value={"var": -1, "other_var": 1}) + expected = ds.unstack("index").fillna({"var": -1, "other_var": 1}).astype(int) + assert_equal(actual, expected) @requires_sparse def test_unstack_sparse(self): @@ -3030,6 +3105,14 @@ def test_to_stacked_array_dtype_dims(self): assert y.dims == ("x", "features") def test_to_stacked_array_to_unstacked_dataset(self): + + # single dimension: regression test for GH4049 + arr = xr.DataArray(np.arange(3), coords=[("x", [0, 1, 2])]) + data = xr.Dataset({"a": arr, "b": arr}) + stacked = data.to_stacked_array("y", sample_dims=["x"]) + unstacked = stacked.to_unstacked_dataset("y") + assert_identical(unstacked, data) + # make a two dimensional dataset a, b = create_test_stacked_array() D = xr.Dataset({"a": a, "b": b}) @@ -3067,9 +3150,6 @@ def test_update(self): assert actual_result is actual assert_identical(expected, actual) - with pytest.raises(TypeError): - actual = data.update(data, inplace=False) - other = Dataset(attrs={"new": "attr"}) actual = data.copy() actual.update(other) @@ -3366,6 +3446,14 @@ def test_setitem_align_new_indexes(self): ) assert_identical(ds, expected) + @pytest.mark.parametrize("dtype", [str, bytes]) + def test_setitem_str_dtype(self, dtype): + + ds = xr.Dataset(coords={"x": np.array(["x", "y"], dtype=dtype)}) + ds["foo"] = xr.DataArray(np.array([0, 0]), dims=["x"]) + + assert np.issubdtype(ds.x.dtype, dtype) + def test_assign(self): ds = Dataset() actual = ds.assign(x=[0, 1, 2], y=2) @@ -3930,6 +4018,33 @@ def test_to_and_from_dataframe(self): # check roundtrip assert_identical(ds.assign_coords(x=[0, 1]), Dataset.from_dataframe(actual)) + # Check multiindex reordering + new_order = ["x", "y"] + actual = ds.to_dataframe(dim_order=new_order) + assert expected.equals(actual) + + new_order = ["y", "x"] + exp_index = pd.MultiIndex.from_arrays( + [["a", "a", "b", "b", "c", "c"], [0, 1, 0, 1, 0, 1]], names=["y", "x"] + ) + expected = pd.DataFrame( + w.transpose().reshape(-1), columns=["w"], index=exp_index + ) + actual = 
ds.to_dataframe(dim_order=new_order) + assert expected.equals(actual) + + invalid_order = ["x"] + with pytest.raises( + ValueError, match="does not match the set of dimensions of this" + ): + ds.to_dataframe(dim_order=invalid_order) + + invalid_order = ["x", "z"] + with pytest.raises( + ValueError, match="does not match the set of dimensions of this" + ): + ds.to_dataframe(dim_order=invalid_order) + # check pathological cases df = pd.DataFrame([1]) actual = Dataset.from_dataframe(df) @@ -4012,6 +4127,49 @@ def test_to_and_from_empty_dataframe(self): assert len(actual) == 0 assert expected.equals(actual) + def test_from_dataframe_multiindex(self): + index = pd.MultiIndex.from_product([["a", "b"], [1, 2, 3]], names=["x", "y"]) + df = pd.DataFrame({"z": np.arange(6)}, index=index) + + expected = Dataset( + {"z": (("x", "y"), [[0, 1, 2], [3, 4, 5]])}, + coords={"x": ["a", "b"], "y": [1, 2, 3]}, + ) + actual = Dataset.from_dataframe(df) + assert_identical(actual, expected) + + df2 = df.iloc[[3, 2, 1, 0, 4, 5], :] + actual = Dataset.from_dataframe(df2) + assert_identical(actual, expected) + + df3 = df.iloc[:4, :] + expected3 = Dataset( + {"z": (("x", "y"), [[0, 1, 2], [3, np.nan, np.nan]])}, + coords={"x": ["a", "b"], "y": [1, 2, 3]}, + ) + actual = Dataset.from_dataframe(df3) + assert_identical(actual, expected3) + + df_nonunique = df.iloc[[0, 0], :] + with raises_regex(ValueError, "non-unique MultiIndex"): + Dataset.from_dataframe(df_nonunique) + + def test_from_dataframe_unsorted_levels(self): + # regression test for GH-4186 + index = pd.MultiIndex( + levels=[["b", "a"], ["foo"]], codes=[[0, 1], [0, 0]], names=["lev1", "lev2"] + ) + df = pd.DataFrame({"c1": [0, 2], "c2": [1, 3]}, index=index) + expected = Dataset( + { + "c1": (("lev1", "lev2"), [[0], [2]]), + "c2": (("lev1", "lev2"), [[1], [3]]), + }, + coords={"lev1": ["b", "a"], "lev2": ["foo"]}, + ) + actual = Dataset.from_dataframe(df) + assert_identical(actual, expected) + def test_from_dataframe_non_unique_columns(self): # regression test for GH449 df = pd.DataFrame(np.zeros((2, 2))) @@ -4113,7 +4271,7 @@ def test_to_and_from_dict(self): "t": {"data": t, "dims": "t"}, "b": {"dims": "t", "data": y}, } - with raises_regex(ValueError, "cannot convert dict " "without the key 'dims'"): + with raises_regex(ValueError, "cannot convert dict without the key 'dims'"): Dataset.from_dict(d) def test_to_and_from_dict_with_time_dim(self): @@ -4336,6 +4494,28 @@ def test_fillna(self): assert actual.a.name == "a" assert actual.a.attrs == ds.a.attrs + @pytest.mark.parametrize( + "func", [lambda x: x.clip(0, 1), lambda x: np.float64(1.0) * x, np.abs, abs] + ) + def test_propagate_attrs(self, func): + + da = DataArray(range(5), name="a", attrs={"attr": "da"}) + ds = Dataset({"a": da}, attrs={"attr": "ds"}) + + # test defaults + assert func(ds).attrs == ds.attrs + with set_options(keep_attrs=False): + assert func(ds).attrs != ds.attrs + assert func(ds).a.attrs != ds.a.attrs + + with set_options(keep_attrs=False): + assert func(ds).attrs != ds.attrs + assert func(ds).a.attrs != ds.a.attrs + + with set_options(keep_attrs=True): + assert func(ds).attrs == ds.attrs + assert func(ds).a.attrs == ds.a.attrs + def test_where(self): ds = Dataset({"a": ("x", range(5))}) expected = Dataset({"a": ("x", [np.nan, np.nan, 2, 3, 4])}) @@ -4596,6 +4776,9 @@ def test_reduce_non_numeric(self): assert_equal(data1.mean(), data2.mean()) assert_equal(data1.mean(dim="dim1"), data2.mean(dim="dim1")) + @pytest.mark.filterwarnings( + "ignore:Once the behaviour of 
DataArray:DeprecationWarning" + ) def test_reduce_strings(self): expected = Dataset({"x": "a"}) ds = Dataset({"x": ("y", ["a", "b"])}) @@ -4667,6 +4850,9 @@ def test_reduce_keep_attrs(self): for k, v in ds.data_vars.items(): assert v.attrs == data[k].attrs + @pytest.mark.filterwarnings( + "ignore:Once the behaviour of DataArray:DeprecationWarning" + ) def test_reduce_argmin(self): # regression test for #205 ds = Dataset({"a": ("x", [0, 1])}) @@ -4698,9 +4884,7 @@ def mean_only_one_axis(x, axis): actual = ds.reduce(mean_only_one_axis, "y") assert_identical(expected, actual) - with raises_regex( - TypeError, "missing 1 required positional argument: " "'axis'" - ): + with raises_regex(TypeError, "missing 1 required positional argument: 'axis'"): ds.reduce(mean_only_one_axis) with raises_regex(TypeError, "non-integer axis"): @@ -5117,7 +5301,7 @@ def test_dataset_diff_exception_label_str(self): with raises_regex(ValueError, "'label' argument has to"): ds.diff("dim2", label="raise_me") - @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0]) + @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0, {"foo": -10}]) def test_shift(self, fill_value): coords = {"bar": ("x", list("abc")), "x": [-4, 3, 2]} attrs = {"meta": "data"} @@ -5127,6 +5311,8 @@ def test_shift(self, fill_value): # if we supply the default, we expect the missing value for a # float array fill_value = np.nan + elif isinstance(fill_value, dict): + fill_value = fill_value.get("foo", np.nan) expected = Dataset({"foo": ("x", [fill_value, 1, 2])}, coords, attrs) assert_identical(expected, actual) @@ -5320,21 +5506,35 @@ def test_full_like(self): ) actual = full_like(ds, 2) - expect = ds.copy(deep=True) - expect["d1"].values = [2, 2, 2] - expect["d2"].values = [2.0, 2.0, 2.0] - assert expect["d1"].dtype == int - assert expect["d2"].dtype == float - assert_identical(expect, actual) + expected = ds.copy(deep=True) + expected["d1"].values = [2, 2, 2] + expected["d2"].values = [2.0, 2.0, 2.0] + assert expected["d1"].dtype == int + assert expected["d2"].dtype == float + assert_identical(expected, actual) # override dtype actual = full_like(ds, fill_value=True, dtype=bool) - expect = ds.copy(deep=True) - expect["d1"].values = [True, True, True] - expect["d2"].values = [True, True, True] - assert expect["d1"].dtype == bool - assert expect["d2"].dtype == bool - assert_identical(expect, actual) + expected = ds.copy(deep=True) + expected["d1"].values = [True, True, True] + expected["d2"].values = [True, True, True] + assert expected["d1"].dtype == bool + assert expected["d2"].dtype == bool + assert_identical(expected, actual) + + # with multiple fill values + actual = full_like(ds, {"d1": 1, "d2": 2.3}) + expected = ds.assign(d1=("x", [1, 1, 1]), d2=("y", [2.3, 2.3, 2.3])) + assert expected["d1"].dtype == int + assert expected["d2"].dtype == float + assert_identical(expected, actual) + + # override multiple dtypes + actual = full_like(ds, fill_value={"d1": 1, "d2": 2.3}, dtype={"d1": bool}) + expected = ds.assign(d1=("x", [True, True, True]), d2=("y", [2.3, 2.3, 2.3])) + assert expected["d1"].dtype == bool + assert expected["d2"].dtype == float + assert_identical(expected, actual) def test_combine_first(self): dsx0 = DataArray([0, 0], [("x", ["a", "b"])]).to_dataset(name="dsx0") @@ -5536,6 +5736,16 @@ def test_polyfit_output(self): out = ds.polyfit("time", 2) assert len(out.data_vars) == 0 + def test_polyfit_warnings(self): + ds = create_test_data(seed=1) + + with warnings.catch_warnings(record=True) as ws: + ds.var1.polyfit("dim2", 10, 
full=False) + assert len(ws) == 1 + assert ws[0].category == np.RankWarning + ds.var1.polyfit("dim2", 10, full=True) + assert len(ws) == 1 + def test_pad(self): ds = create_test_data(seed=1) padded = ds.pad(dim2=(1, 1), constant_values=42) @@ -5549,6 +5759,15 @@ def test_pad(self): np.testing.assert_equal(padded["var1"].isel(dim2=[0, -1]).data, 42) np.testing.assert_equal(padded["dim2"][[0, -1]].data, np.nan) + def test_astype_attrs(self): + data = create_test_data(seed=123) + data.attrs["foo"] = "bar" + + assert data.attrs == data.astype(float).attrs + assert data.var1.attrs == data.astype(float).var1.attrs + assert not data.astype(float, keep_attrs=False).attrs + assert not data.astype(float, keep_attrs=False).var1.attrs + # Py.test tests @@ -5778,6 +5997,8 @@ def test_coarsen_keep_attrs(): attrs=_attrs, ) + ds2 = ds.copy(deep=True) + # Test dropped attrs dat = ds.coarsen(coord=5).mean() assert dat.attrs == {} @@ -5791,40 +6012,123 @@ def test_coarsen_keep_attrs(): dat = ds.coarsen(coord=5).mean() assert dat.attrs == _attrs + # Test kept attrs in original object + xr.testing.assert_identical(ds, ds2) -def test_rolling_keep_attrs(): - _attrs = {"units": "test", "long_name": "testing"} - var1 = np.linspace(10, 15, 100) - var2 = np.linspace(5, 10, 100) +@pytest.mark.parametrize( + "funcname, argument", + [ + ("reduce", (np.mean,)), + ("mean", ()), + ("construct", ("window_dim",)), + ("count", ()), + ], +) +def test_rolling_keep_attrs(funcname, argument): + global_attrs = {"units": "test", "long_name": "testing"} + da_attrs = {"da_attr": "test"} + da_not_rolled_attrs = {"da_not_rolled_attr": "test"} + + data = np.linspace(10, 15, 100) coords = np.linspace(1, 10, 100) ds = Dataset( - data_vars={"var1": ("coord", var1), "var2": ("coord", var2)}, + data_vars={"da": ("coord", data), "da_not_rolled": ("no_coord", data)}, coords={"coord": coords}, - attrs=_attrs, + attrs=global_attrs, ) + ds.da.attrs = da_attrs + ds.da_not_rolled.attrs = da_not_rolled_attrs + + # attrs are now kept per default + func = getattr(ds.rolling(dim={"coord": 5}), funcname) + result = func(*argument) + assert result.attrs == global_attrs + assert result.da.attrs == da_attrs + assert result.da_not_rolled.attrs == da_not_rolled_attrs + assert result.da.name == "da" + assert result.da_not_rolled.name == "da_not_rolled" + + # discard attrs + func = getattr(ds.rolling(dim={"coord": 5}), funcname) + result = func(*argument, keep_attrs=False) + assert result.attrs == {} + assert result.da.attrs == {} + assert result.da_not_rolled.attrs == {} + assert result.da.name == "da" + assert result.da_not_rolled.name == "da_not_rolled" + + # test discard attrs using global option + func = getattr(ds.rolling(dim={"coord": 5}), funcname) + with set_options(keep_attrs=False): + result = func(*argument) + + assert result.attrs == {} + assert result.da.attrs == {} + assert result.da_not_rolled.attrs == {} + assert result.da.name == "da" + assert result.da_not_rolled.name == "da_not_rolled" + + # keyword takes precedence over global option + func = getattr(ds.rolling(dim={"coord": 5}), funcname) + with set_options(keep_attrs=False): + result = func(*argument, keep_attrs=True) + + assert result.attrs == global_attrs + assert result.da.attrs == da_attrs + assert result.da_not_rolled.attrs == da_not_rolled_attrs + assert result.da.name == "da" + assert result.da_not_rolled.name == "da_not_rolled" + + func = getattr(ds.rolling(dim={"coord": 5}), funcname) + with set_options(keep_attrs=True): + result = func(*argument, keep_attrs=False) - # Test 
dropped attrs - dat = ds.rolling(dim={"coord": 5}, min_periods=None, center=False).mean() - assert dat.attrs == {} + assert result.attrs == {} + assert result.da.attrs == {} + assert result.da_not_rolled.attrs == {} + assert result.da.name == "da" + assert result.da_not_rolled.name == "da_not_rolled" - # Test kept attrs using dataset keyword - dat = ds.rolling( - dim={"coord": 5}, min_periods=None, center=False, keep_attrs=True - ).mean() - assert dat.attrs == _attrs - # Test kept attrs using global option - with set_options(keep_attrs=True): - dat = ds.rolling(dim={"coord": 5}, min_periods=None, center=False).mean() - assert dat.attrs == _attrs +def test_rolling_keep_attrs_deprecated(): + global_attrs = {"units": "test", "long_name": "testing"} + attrs_da = {"da_attr": "test"} + + data = np.linspace(10, 15, 100) + coords = np.linspace(1, 10, 100) + + ds = Dataset( + data_vars={"da": ("coord", data)}, + coords={"coord": coords}, + attrs=global_attrs, + ) + ds.da.attrs = attrs_da + + # deprecated option + with pytest.warns( + FutureWarning, match="Passing ``keep_attrs`` to ``rolling`` is deprecated" + ): + result = ds.rolling(dim={"coord": 5}, keep_attrs=False).construct("window_dim") + + assert result.attrs == {} + assert result.da.attrs == {} + + # the keep_attrs in the reduction function takes precedence + with pytest.warns( + FutureWarning, match="Passing ``keep_attrs`` to ``rolling`` is deprecated" + ): + result = ds.rolling(dim={"coord": 5}, keep_attrs=True).construct( + "window_dim", keep_attrs=False + ) + + assert result.attrs == {} + assert result.da.attrs == {} def test_rolling_properties(ds): # catching invalid args - with pytest.raises(ValueError, match="exactly one dim/window should"): - ds.rolling(time=7, x=2) with pytest.raises(ValueError, match="window must be > 0"): ds.rolling(time=-2) with pytest.raises(ValueError, match="min_periods must be greater than zero"): @@ -5851,6 +6155,8 @@ def test_rolling_wrapped_bottleneck(ds, name, center, min_periods, key): expected = getattr(bn, func_name)( ds[key].values, window=7, axis=0, min_count=min_periods ) + else: + raise ValueError assert_array_equal(actual[key].values, expected) # Test center @@ -5866,6 +6172,43 @@ def test_rolling_exp(ds): assert isinstance(result, Dataset) +@requires_numbagg +def test_rolling_exp_keep_attrs(ds): + + attrs_global = {"attrs": "global"} + attrs_z1 = {"attr": "z1"} + + ds.attrs = attrs_global + ds.z1.attrs = attrs_z1 + + # attrs are kept per default + result = ds.rolling_exp(time=10).mean() + assert result.attrs == attrs_global + assert result.z1.attrs == attrs_z1 + + # discard attrs + result = ds.rolling_exp(time=10).mean(keep_attrs=False) + assert result.attrs == {} + assert result.z1.attrs == {} + + # test discard attrs using global option + with set_options(keep_attrs=False): + result = ds.rolling_exp(time=10).mean() + assert result.attrs == {} + assert result.z1.attrs == {} + + # keyword takes precedence over global option + with set_options(keep_attrs=False): + result = ds.rolling_exp(time=10).mean(keep_attrs=True) + assert result.attrs == attrs_global + assert result.z1.attrs == attrs_z1 + + with set_options(keep_attrs=True): + result = ds.rolling_exp(time=10).mean(keep_attrs=False) + assert result.attrs == {} + assert result.z1.attrs == {} + + @pytest.mark.parametrize("center", (True, False)) @pytest.mark.parametrize("min_periods", (None, 1, 2, 3)) @pytest.mark.parametrize("window", (1, 2, 3, 4)) @@ -5949,12 +6292,98 @@ def test_rolling_reduce(ds, center, min_periods, window, name): assert 
src_var.dims == actual[key].dims +@pytest.mark.parametrize("ds", (2,), indirect=True) +@pytest.mark.parametrize("center", (True, False)) +@pytest.mark.parametrize("min_periods", (None, 1)) +@pytest.mark.parametrize("name", ("sum", "max")) +@pytest.mark.parametrize("dask", (True, False)) +def test_ndrolling_reduce(ds, center, min_periods, name, dask): + if dask and has_dask: + ds = ds.chunk({"x": 4}) + + rolling_obj = ds.rolling(time=4, x=3, center=center, min_periods=min_periods) + + actual = getattr(rolling_obj, name)() + expected = getattr( + getattr( + ds.rolling(time=4, center=center, min_periods=min_periods), name + )().rolling(x=3, center=center, min_periods=min_periods), + name, + )() + assert_allclose(actual, expected) + assert actual.dims == expected.dims + + # Do it in the opposite order + expected = getattr( + getattr( + ds.rolling(x=3, center=center, min_periods=min_periods), name + )().rolling(time=4, center=center, min_periods=min_periods), + name, + )() + + assert_allclose(actual, expected) + assert actual.dims == expected.dims + + +@pytest.mark.parametrize("center", (True, False, (True, False))) +@pytest.mark.parametrize("fill_value", (np.nan, 0.0)) +@pytest.mark.parametrize("dask", (True, False)) +def test_ndrolling_construct(center, fill_value, dask): + da = DataArray( + np.arange(5 * 6 * 7).reshape(5, 6, 7).astype(float), + dims=["x", "y", "z"], + coords={"x": ["a", "b", "c", "d", "e"], "y": np.arange(6)}, + ) + ds = xr.Dataset({"da": da}) + if dask and has_dask: + ds = ds.chunk({"x": 4}) + + actual = ds.rolling(x=3, z=2, center=center).construct( + x="x1", z="z1", fill_value=fill_value + ) + if not isinstance(center, tuple): + center = (center, center) + expected = ( + ds.rolling(x=3, center=center[0]) + .construct(x="x1", fill_value=fill_value) + .rolling(z=2, center=center[1]) + .construct(z="z1", fill_value=fill_value) + ) + assert_allclose(actual, expected) + + def test_raise_no_warning_for_nan_in_binary_ops(): with pytest.warns(None) as record: Dataset(data_vars={"x": ("y", [1, 2, np.NaN])}) > 0 assert len(record) == 0 +@pytest.mark.filterwarnings("error") +@pytest.mark.parametrize("ds", (2,), indirect=True) +def test_raise_no_warning_assert_close(ds): + assert_allclose(ds, ds) + + +@pytest.mark.xfail(reason="See https://github.com/pydata/xarray/pull/4369 or docstring") +@pytest.mark.filterwarnings("error") +@pytest.mark.parametrize("ds", (2,), indirect=True) +@pytest.mark.parametrize("name", ("mean", "max")) +def test_raise_no_warning_dask_rolling_assert_close(ds, name): + """ + This is a puzzle — I can't easily find the source of the warning. It + requires `assert_allclose` to be run, for the `ds` param to be 2, and is + different for `mean` and `max`. `sum` raises no warning. 
+    """
+ """ + + ds = ds.chunk({"x": 4}) + + rolling_obj = ds.rolling(time=4, x=3) + + actual = getattr(rolling_obj, name)() + expected = getattr(getattr(ds.rolling(time=4), name)().rolling(x=3), name)() + assert_allclose(actual, expected) + + @pytest.mark.parametrize("dask", [True, False]) @pytest.mark.parametrize("edge_order", [1, 2]) def test_differentiate(dask, edge_order): @@ -6221,3 +6650,9 @@ def test_weakref(): ds = Dataset() r = ref(ds) assert r() is ds + + +def test_deepcopy_obj_array(): + x0 = Dataset(dict(foo=DataArray(np.array([object()])))) + x1 = deepcopy(x0) + assert x0["foo"].values[0] is not x1["foo"].values[0] diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index 8011171d223..7886e9fd0d4 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -135,8 +135,8 @@ def test_dask_distributed_read_netcdf_integration_test( def test_dask_distributed_zarr_integration_test(loop, consolidated, compute): if consolidated: pytest.importorskip("zarr", minversion="2.2.1.dev2") - write_kwargs = dict(consolidated=True) - read_kwargs = dict(consolidated=True) + write_kwargs = {"consolidated": True} + read_kwargs = {"backend_kwargs": {"consolidated": True}} else: write_kwargs = read_kwargs = {} chunks = {"dim1": 4, "dim2": 3, "dim3": 5} @@ -151,7 +151,9 @@ def test_dask_distributed_zarr_integration_test(loop, consolidated, compute): ) if not compute: maybe_futures.compute() - with xr.open_zarr(filename, **read_kwargs) as restored: + with xr.open_dataset( + filename, chunks="auto", engine="zarr", **read_kwargs + ) as restored: assert isinstance(restored.var1.data, da.Array) computed = restored.compute() assert_allclose(original, computed) diff --git a/xarray/tests/test_dtypes.py b/xarray/tests/test_dtypes.py index 1f3aee84979..5ad1a6355e6 100644 --- a/xarray/tests/test_dtypes.py +++ b/xarray/tests/test_dtypes.py @@ -7,8 +7,8 @@ @pytest.mark.parametrize( "args, expected", [ - ([np.bool], np.bool), - ([np.bool, np.string_], np.object_), + ([bool], bool), + ([bool, np.string_], np.object_), ([np.float32, np.float64], np.float64), ([np.float32, np.string_], np.object_), ([np.unicode_, np.int64], np.object_), diff --git a/xarray/tests/test_duck_array_ops.py b/xarray/tests/test_duck_array_ops.py index e61881cfce3..1342950f3e5 100644 --- a/xarray/tests/test_duck_array_ops.py +++ b/xarray/tests/test_duck_array_ops.py @@ -33,6 +33,7 @@ arm_xfail, assert_array_equal, has_dask, + has_scipy, raises_regex, requires_cftime, requires_dask, @@ -119,11 +120,9 @@ def test_concatenate_type_promotion(self): result = concatenate([[1], ["b"]]) assert_array_equal(result, np.array([1, "b"], dtype=object)) + @pytest.mark.filterwarnings("error") def test_all_nan_arrays(self): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", "All-NaN slice") - warnings.filterwarnings("ignore", "Mean of empty slice") - assert np.isnan(mean([np.nan, np.nan])) + assert np.isnan(mean([np.nan, np.nan])) def test_cumsum_1d(): @@ -256,7 +255,7 @@ def from_series_or_scalar(se): def series_reduce(da, func, dim, **kwargs): - """ convert DataArray to pd.Series, apply pd.func, then convert back to + """convert DataArray to pd.Series, apply pd.func, then convert back to a DataArray. 
Multiple dims cannot be specified.""" if dim is None or da.ndim == 1: se = da.to_series() @@ -332,6 +331,40 @@ def test_cftime_datetime_mean(): assert_equal(result, expected) +@requires_cftime +def test_cftime_datetime_mean_long_time_period(): + import cftime + + times = np.array( + [ + [ + cftime.DatetimeNoLeap(400, 12, 31, 0, 0, 0, 0), + cftime.DatetimeNoLeap(520, 12, 31, 0, 0, 0, 0), + ], + [ + cftime.DatetimeNoLeap(520, 12, 31, 0, 0, 0, 0), + cftime.DatetimeNoLeap(640, 12, 31, 0, 0, 0, 0), + ], + [ + cftime.DatetimeNoLeap(640, 12, 31, 0, 0, 0, 0), + cftime.DatetimeNoLeap(760, 12, 31, 0, 0, 0, 0), + ], + ] + ) + + da = DataArray(times, dims=["time", "d2"]) + result = da.mean("d2") + expected = DataArray( + [ + cftime.DatetimeNoLeap(460, 12, 31, 0, 0, 0, 0), + cftime.DatetimeNoLeap(580, 12, 31, 0, 0, 0, 0), + cftime.DatetimeNoLeap(700, 12, 31, 0, 0, 0, 0), + ], + dims=["time"], + ) + assert_equal(result, expected) + + @requires_cftime @requires_dask def test_cftime_datetime_mean_dask_error(): @@ -384,7 +417,7 @@ def test_reduce(dim_num, dtype, dask, func, skipna, aggdim): actual = getattr(da, func)(skipna=skipna, dim=aggdim) assert_dask_array(actual, dask) - assert np.allclose( + np.testing.assert_allclose( actual.values, np.array(expected), rtol=1.0e-4, equal_nan=True ) except (TypeError, AttributeError, ZeroDivisionError): @@ -449,9 +482,7 @@ def test_argmin_max(dim_num, dtype, contains_nan, dask, func, skipna, aggdim): if contains_nan: if not skipna: - pytest.skip( - "numpy's argmin (not nanargmin) does not handle " "object-dtype" - ) + pytest.skip("numpy's argmin (not nanargmin) does not handle object-dtype") if skipna and np.dtype(dtype).kind in "iufc": pytest.skip("numpy's nanargmin raises ValueError for all nan axis") da = construct_dataarray(dim_num, dtype, contains_nan=contains_nan, dask=dask) @@ -547,15 +578,35 @@ def test_dask_gradient(axis, edge_order): @pytest.mark.parametrize("dask", [False, True]) @pytest.mark.parametrize("func", ["sum", "prod"]) @pytest.mark.parametrize("aggdim", [None, "x"]) -def test_min_count(dim_num, dtype, dask, func, aggdim): +@pytest.mark.parametrize("contains_nan", [True, False]) +@pytest.mark.parametrize("skipna", [True, False, None]) +def test_min_count(dim_num, dtype, dask, func, aggdim, contains_nan, skipna): if dask and not has_dask: pytest.skip("requires dask") - da = construct_dataarray(dim_num, dtype, contains_nan=True, dask=dask) + da = construct_dataarray(dim_num, dtype, contains_nan=contains_nan, dask=dask) min_count = 3 - actual = getattr(da, func)(dim=aggdim, skipna=True, min_count=min_count) - expected = series_reduce(da, func, skipna=True, dim=aggdim, min_count=min_count) + actual = getattr(da, func)(dim=aggdim, skipna=skipna, min_count=min_count) + expected = series_reduce(da, func, skipna=skipna, dim=aggdim, min_count=min_count) + assert_allclose(actual, expected) + assert_dask_array(actual, dask) + + +@pytest.mark.parametrize("dtype", [float, int, np.float32, np.bool_]) +@pytest.mark.parametrize("dask", [False, True]) +@pytest.mark.parametrize("func", ["sum", "prod"]) +def test_min_count_nd(dtype, dask, func): + if dask and not has_dask: + pytest.skip("requires dask") + + min_count = 3 + dim_num = 3 + da = construct_dataarray(dim_num, dtype, contains_nan=True, dask=dask) + actual = getattr(da, func)(dim=["x", "y", "z"], skipna=True, min_count=min_count) + # Supplying all dims is equivalent to supplying `...` or `None` + expected = getattr(da, func)(dim=..., skipna=True, min_count=min_count) + assert_allclose(actual, expected) 
assert_dask_array(actual, dask) @@ -571,14 +622,15 @@ def test_min_count_dataset(func): @pytest.mark.parametrize("dtype", [float, int, np.float32, np.bool_]) @pytest.mark.parametrize("dask", [False, True]) +@pytest.mark.parametrize("skipna", [False, True]) @pytest.mark.parametrize("func", ["sum", "prod"]) -def test_multiple_dims(dtype, dask, func): +def test_multiple_dims(dtype, dask, skipna, func): if dask and not has_dask: pytest.skip("requires dask") da = construct_dataarray(3, dtype, contains_nan=True, dask=dask) - actual = getattr(da, func)(("x", "y")) - expected = getattr(getattr(da, func)("x"), func)("y") + actual = getattr(da, func)(("x", "y"), skipna=skipna) + expected = getattr(getattr(da, func)("x", skipna=skipna), func)("y", skipna=skipna) assert_allclose(actual, expected) @@ -602,7 +654,7 @@ def test_docs(): skips missing values for float dtypes; other dtypes either do not have a sentinel missing value (int) or skipna=True has not been implemented (object, datetime64 or timedelta64). - min_count : int, default None + min_count : int, default: None The required number of valid values to perform the operation. If fewer than min_count non-NA values are present the result will be NA. New in version 0.10.8: Added with the default being None. @@ -767,8 +819,8 @@ def test_timedelta_to_numeric(td): @pytest.mark.parametrize("use_dask", [True, False]) @pytest.mark.parametrize("skipna", [True, False]) def test_least_squares(use_dask, skipna): - if use_dask and not has_dask: - pytest.skip("requires dask") + if use_dask and (not has_dask or not has_scipy): + pytest.skip("requires dask and scipy") lhs = np.array([[1, 2], [1, 2], [3, 2]]) rhs = DataArray(np.array([3, 5, 7]), dims=("y",)) diff --git a/xarray/tests/test_extensions.py b/xarray/tests/test_extensions.py index 5af0f6d8a42..fa91e5c813d 100644 --- a/xarray/tests/test_extensions.py +++ b/xarray/tests/test_extensions.py @@ -4,7 +4,7 @@ import xarray as xr -from . import raises_regex +from . import assert_identical, raises_regex @xr.register_dataset_accessor("example_accessor") @@ -61,20 +61,20 @@ class Foo: def test_pickle_dataset(self): ds = xr.Dataset() ds_restored = pickle.loads(pickle.dumps(ds)) - assert ds.identical(ds_restored) + assert_identical(ds, ds_restored) # state save on the accessor is restored assert ds.example_accessor is ds.example_accessor ds.example_accessor.value = "foo" ds_restored = pickle.loads(pickle.dumps(ds)) - assert ds.identical(ds_restored) + assert_identical(ds, ds_restored) assert ds_restored.example_accessor.value == "foo" def test_pickle_dataarray(self): array = xr.Dataset() assert array.example_accessor is array.example_accessor array_restored = pickle.loads(pickle.dumps(array)) - assert array.identical(array_restored) + assert_identical(array, array_restored) def test_broken_accessor(self): # regression test for GH933 diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index 6881c0bc0ff..f2facf5b481 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -7,6 +7,7 @@ import xarray as xr from xarray.core import formatting +from xarray.core.npcompat import IS_NEP18_ACTIVE from . 
import raises_regex @@ -86,6 +87,9 @@ def test_format_item(self): (b"foo", "b'foo'"), (1, "1"), (1.0, "1.0"), + (np.float16(1.1234), "1.123"), + (np.float32(1.0111111), "1.011"), + (np.float64(22.222222), "22.22"), ] for item, expected in cases: actual = formatting.format_item(item) @@ -391,6 +395,44 @@ def test_array_repr(self): assert actual == expected +@pytest.mark.skipif(not IS_NEP18_ACTIVE, reason="requires __array_function__") +def test_inline_variable_array_repr_custom_repr(): + class CustomArray: + def __init__(self, value, attr): + self.value = value + self.attr = attr + + def _repr_inline_(self, width): + formatted = f"({self.attr}) {self.value}" + if len(formatted) > width: + formatted = f"({self.attr}) ..." + + return formatted + + def __array_function__(self, *args, **kwargs): + return NotImplemented + + @property + def shape(self): + return self.value.shape + + @property + def dtype(self): + return self.value.dtype + + @property + def ndim(self): + return self.value.ndim + + value = CustomArray(np.array([20, 40]), "m") + variable = xr.Variable("x", value) + + max_width = 10 + actual = formatting.inline_variable_array_repr(variable, max_width=10) + + assert actual == value._repr_inline_(max_width) + + def test_set_numpy_options(): original_options = np.get_printoptions() with formatting.set_numpy_options(threshold=10): @@ -405,10 +447,52 @@ def test_short_numpy_repr(): np.random.randn(20, 20), np.random.randn(5, 10, 15), np.random.randn(5, 10, 15, 3), + np.random.randn(100, 5, 1), ] # number of lines: - # for default numpy repr: 167, 140, 254, 248 - # for short_numpy_repr: 1, 7, 24, 19 + # for default numpy repr: 167, 140, 254, 248, 599 + # for short_numpy_repr: 1, 7, 24, 19, 25 for array in cases: num_lines = formatting.short_numpy_repr(array).count("\n") + 1 assert num_lines < 30 + + +def test_large_array_repr_length(): + + da = xr.DataArray(np.random.randn(100, 5, 1)) + + result = repr(da).splitlines() + assert len(result) < 50 + + +@pytest.mark.parametrize( + "display_max_rows, n_vars, n_attr", + [(50, 40, 30), (35, 40, 30), (11, 40, 30), (1, 40, 30)], +) +def test__mapping_repr(display_max_rows, n_vars, n_attr): + long_name = "long_name" + a = np.core.defchararray.add(long_name, np.arange(0, n_vars).astype(str)) + b = np.core.defchararray.add("attr_", np.arange(0, n_attr).astype(str)) + attrs = {k: 2 for k in b} + coords = dict(time=np.array([0, 1])) + data_vars = dict() + for v in a: + data_vars[v] = xr.DataArray( + name=v, + data=np.array([3, 4]), + dims=["time"], + coords=coords, + ) + ds = xr.Dataset(data_vars) + ds.attrs = attrs + + with xr.set_options(display_max_rows=display_max_rows): + + # Parse the data_vars print and show only data_vars rows: + summary = formatting.data_vars_repr(ds.data_vars).split("\n") + summary = [v for v in summary if long_name in v] + + # The length should be less than or equal to display_max_rows: + len_summary = len(summary) + data_vars_print_size = min(display_max_rows, len_summary) + assert len_summary == data_vars_print_size diff --git a/xarray/tests/test_formatting_html.py b/xarray/tests/test_formatting_html.py index 239f339208d..9a210ad6fa3 100644 --- a/xarray/tests/test_formatting_html.py +++ b/xarray/tests/test_formatting_html.py @@ -48,7 +48,7 @@ def dataset(): def test_short_data_repr_html(dataarray): data_repr = fh.short_data_repr_html(dataarray) - assert data_repr.startswith("array") + assert data_repr.startswith("
    array")
     
     
     def test_short_data_repr_html_non_str_keys(dataset):
    @@ -108,8 +108,8 @@ def test_summarize_attrs_with_unsafe_attr_name_and_value():
     def test_repr_of_dataarray(dataarray):
         formatted = fh.array_repr(dataarray)
         assert "dim_0" in formatted
    -    # has an expandable data section
    -    assert formatted.count("class='xr-array-in' type='checkbox' >") == 1
    +    # has an expanded data section
    +    assert formatted.count("class='xr-array-in' type='checkbox' checked>") == 1
     # coords and attrs don't have any items so they'll be disabled and collapsed
         assert (
             formatted.count("class='xr-section-summary-in' type='checkbox' disabled >") == 2
    @@ -137,3 +137,22 @@ def test_repr_of_dataset(dataset):
         )
         assert "<U4" in formatted or ">U4" in formatted
         assert "<IA>" in formatted
    +
    +
    +def test_repr_text_fallback(dataset):
    +    formatted = fh.dataset_repr(dataset)
    +
    +    # Just test that the "pre" block used for fallback to plain text is present.
    +    assert "
    " in formatted
    +
    +
    +def test_variable_repr_html():
    +    v = xr.Variable(["time", "x"], [[1, 2, 3], [4, 5, 6]], {"foo": "bar"})
    +    assert hasattr(v, "_repr_html_")
    +    with xr.set_options(display_style="html"):
    +        html = v._repr_html_().strip()
    +    # We don't do a complete string identity since
    +    # html output is probably subject to change, is long and... reasons.
    +    # Just test that something reasonable was produced.
+    assert html.startswith("<div") and html.endswith("</div>")
    +    assert "xarray.Variable" in html
    diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py
    index 866d5fb0899..85f729a9f7a 100644
    --- a/xarray/tests/test_groupby.py
    +++ b/xarray/tests/test_groupby.py
    @@ -447,8 +447,7 @@ def test_groupby_drops_nans():
     
         # reduction operation along a different dimension
         actual = grouped.mean("time")
    -    with pytest.warns(RuntimeWarning):  # mean of empty slice
    -        expected = ds.mean("time").where(ds.id.notnull())
    +    expected = ds.mean("time").where(ds.id.notnull())
         assert_identical(actual, expected)
     
         # NaN in non-dimensional coordinate
    @@ -538,4 +537,16 @@ def test_groupby_bins_timeseries():
         assert_identical(actual, expected)
     
     
    +def test_groupby_none_group_name():
    +    # GH158
    +    # xarray should not fail if a DataArray's name attribute is None
    +
    +    data = np.arange(10) + 10
    +    da = xr.DataArray(data)  # da.name = None
    +    key = xr.DataArray(np.floor_divide(data, 2))
    +
    +    mean = da.groupby(key).mean()
    +    assert "group" in mean.dims
    +
    +
     # TODO: move other groupby tests from test_dataset and test_dataarray over here
    diff --git a/xarray/tests/test_indexing.py b/xarray/tests/test_indexing.py
    index d7ed16b9573..4ef7536e1f2 100644
    --- a/xarray/tests/test_indexing.py
    +++ b/xarray/tests/test_indexing.py
    @@ -86,6 +86,15 @@ def test_convert_label_indexer(self):
             with pytest.raises(IndexError):
                 indexing.convert_label_indexer(mindex, (slice(None), 1, "no_level"))
     
    +    def test_convert_label_indexer_datetime(self):
    +        index = pd.to_datetime(["2000-01-01", "2001-01-01", "2002-01-01"])
    +        actual = indexing.convert_label_indexer(index, "2001-01-01")
    +        expected = (1, None)
    +        assert actual == expected
    +
    +        actual = indexing.convert_label_indexer(index, index.to_numpy()[1])
    +        assert actual == expected
    +
         def test_convert_unsorted_datetime_index_raises(self):
             index = pd.to_datetime(["2001", "2000", "2002"])
             with pytest.raises(KeyError):
    diff --git a/xarray/tests/test_interp.py b/xarray/tests/test_interp.py
    index 0502348160e..20d5fb12a62 100644
    --- a/xarray/tests/test_interp.py
    +++ b/xarray/tests/test_interp.py
    @@ -1,9 +1,18 @@
    +from itertools import combinations, permutations
    +
     import numpy as np
     import pandas as pd
     import pytest
     
     import xarray as xr
    -from xarray.tests import assert_allclose, assert_equal, requires_cftime, requires_scipy
    +from xarray.tests import (
    +    assert_allclose,
    +    assert_equal,
    +    assert_identical,
    +    requires_cftime,
    +    requires_dask,
    +    requires_scipy,
    +)
     
     from ..coding.cftimeindex import _parse_array_of_cftime_strings
     from . import has_dask, has_scipy
    @@ -63,12 +72,6 @@ def test_interpolate_1d(method, dim, case):
     
         da = get_example_data(case)
         xdest = np.linspace(0.0, 0.9, 80)
    -
    -    if dim == "y" and case == 1:
    -        with pytest.raises(NotImplementedError):
    -            actual = da.interp(method=method, **{dim: xdest})
    -        pytest.skip("interpolation along chunked dimension is " "not yet supported")
    -
         actual = da.interp(method=method, **{dim: xdest})
     
         # scipy interpolation for the reference
    @@ -274,6 +277,32 @@ def test_interpolate_nd_nd():
             da.interp(a=ia)
     
     
    +@requires_scipy
    +def test_interpolate_nd_with_nan():
    +    """Interpolate an array with an nd indexer and `NaN` values."""
    +
    +    # Create indexer into `a` with dimensions (y, x)
    +    x = [0, 1, 2]
    +    y = [10, 20]
    +    c = {"x": x, "y": y}
    +    a = np.arange(6, dtype=float).reshape(2, 3)
    +    a[0, 1] = np.nan
    +    ia = xr.DataArray(a, dims=("y", "x"), coords=c)
    +
    +    da = xr.DataArray([1, 2, 2], dims=("a"), coords={"a": [0, 2, 4]})
    +    out = da.interp(a=ia)
    +    expected = xr.DataArray(
    +        [[1.0, np.nan, 2.0], [2.0, 2.0, np.nan]], dims=("y", "x"), coords=c
    +    )
    +    xr.testing.assert_allclose(out.drop_vars("a"), expected)
    +
    +    db = 2 * da
    +    ds = xr.Dataset({"da": da, "db": db})
    +    out = ds.interp(a=ia)
    +    expected_ds = xr.Dataset({"da": expected, "db": 2 * expected})
    +    xr.testing.assert_allclose(out.drop_vars("a"), expected_ds)
    +
    +
     @pytest.mark.parametrize("method", ["linear"])
     @pytest.mark.parametrize("case", [0, 1])
     def test_interpolate_scalar(method, case):
    @@ -376,8 +405,6 @@ def test_errors(use_dask):
         # invalid method
         with pytest.raises(ValueError):
             da.interp(x=[2, 0], method="boo")
    -    with pytest.raises(ValueError):
    -        da.interp(x=[2, 0], y=2, method="cubic")
         with pytest.raises(ValueError):
             da.interp(y=[2, 0], method="boo")
     
    @@ -552,6 +579,7 @@ def test_interp_like():
                 [0.5, 1.5],
             ),
             (["2000-01-01T12:00", "2000-01-02T12:00"], [0.5, 1.5]),
    +        (["2000-01-01T12:00", "2000-01-02T12:00", "NaT"], [0.5, 1.5, np.nan]),
             (["2000-01-01T12:00"], 0.5),
             pytest.param("2000-01-01T12:00", 0.5, marks=pytest.mark.xfail),
         ],
    @@ -699,3 +727,142 @@ def test_3641():
         times = xr.cftime_range("0001", periods=3, freq="500Y")
         da = xr.DataArray(range(3), dims=["time"], coords=[times])
         da.interp(time=["0002-05-01"])
    +
    +
    +@requires_scipy
    +@pytest.mark.parametrize("method", ["nearest", "linear"])
    +def test_decompose(method):
    +    da = xr.DataArray(
    +        np.arange(6).reshape(3, 2),
    +        dims=["x", "y"],
    +        coords={"x": [0, 1, 2], "y": [-0.1, -0.3]},
    +    )
    +    x_new = xr.DataArray([0.5, 1.5, 2.5], dims=["x1"])
    +    y_new = xr.DataArray([-0.15, -0.25], dims=["y1"])
    +    x_broadcast, y_broadcast = xr.broadcast(x_new, y_new)
    +    assert x_broadcast.ndim == 2
    +
    +    actual = da.interp(x=x_new, y=y_new, method=method).drop_vars(("x", "y"))
    +    expected = da.interp(x=x_broadcast, y=y_broadcast, method=method).drop_vars(
    +        ("x", "y")
    +    )
    +    assert_allclose(actual, expected)
    +
    +
    +@requires_scipy
    +@requires_dask
    +@pytest.mark.parametrize(
    +    "method", ["linear", "nearest", "zero", "slinear", "quadratic", "cubic"]
    +)
    +@pytest.mark.parametrize("chunked", [True, False])
    +@pytest.mark.parametrize(
    +    "data_ndim,interp_ndim,nscalar",
    +    [
    +        (data_ndim, interp_ndim, nscalar)
    +        for data_ndim in range(1, 4)
    +        for interp_ndim in range(1, data_ndim + 1)
    +        for nscalar in range(0, interp_ndim + 1)
    +    ],
    +)
    +def test_interpolate_chunk_1d(method, data_ndim, interp_ndim, nscalar, chunked):
    +    """Interpolate nd array with multiple independant indexers
    +
+    It should do a series of 1d interpolations
    +    """
    +
    +    # 3d non chunked data
    +    x = np.linspace(0, 1, 5)
    +    y = np.linspace(2, 4, 7)
    +    z = np.linspace(-0.5, 0.5, 11)
    +    da = xr.DataArray(
    +        data=np.sin(x[:, np.newaxis, np.newaxis])
    +        * np.cos(y[:, np.newaxis])
    +        * np.exp(z),
    +        coords=[("x", x), ("y", y), ("z", z)],
    +    )
    +    kwargs = {"fill_value": "extrapolate"}
    +
    +    # choose the data dimensions
    +    for data_dims in permutations(da.dims, data_ndim):
    +
    +        # select only data_ndim dim
    +        da = da.isel(  # take the middle line
    +            {dim: len(da.coords[dim]) // 2 for dim in da.dims if dim not in data_dims}
    +        )
    +
    +        # chunk data
    +        da = da.chunk(chunks={dim: i + 1 for i, dim in enumerate(da.dims)})
    +
    +        # choose the interpolation dimensions
    +        for interp_dims in permutations(da.dims, interp_ndim):
    +            # choose the scalar interpolation dimensions
    +            for scalar_dims in combinations(interp_dims, nscalar):
    +                dest = {}
    +                for dim in interp_dims:
    +                    if dim in scalar_dims:
    +                        # take the middle point
    +                        dest[dim] = 0.5 * (da.coords[dim][0] + da.coords[dim][-1])
    +                    else:
    +                        # pick some points, including outside the domain
    +                        before = 2 * da.coords[dim][0] - da.coords[dim][1]
    +                        after = 2 * da.coords[dim][-1] - da.coords[dim][-2]
    +
    +                        dest[dim] = np.linspace(before, after, len(da.coords[dim]) * 13)
    +                        if chunked:
    +                            dest[dim] = xr.DataArray(data=dest[dim], dims=[dim])
    +                            dest[dim] = dest[dim].chunk(2)
    +                actual = da.interp(method=method, **dest, kwargs=kwargs)
    +                expected = da.compute().interp(method=method, **dest, kwargs=kwargs)
    +
    +                assert_identical(actual, expected)
    +
    +                # all the combinations are usually not necessary
    +                break
    +            break
    +        break
    +
    +
    +@requires_scipy
    +@requires_dask
    +@pytest.mark.parametrize("method", ["linear", "nearest"])
    +def test_interpolate_chunk_advanced(method):
    +    """Interpolate nd array with an nd indexer sharing coordinates."""
    +    # Create original array
    +    x = np.linspace(-1, 1, 5)
    +    y = np.linspace(-1, 1, 7)
    +    z = np.linspace(-1, 1, 11)
    +    t = np.linspace(0, 1, 13)
    +    q = np.linspace(0, 1, 17)
    +    da = xr.DataArray(
    +        data=np.sin(x[:, np.newaxis, np.newaxis, np.newaxis, np.newaxis])
    +        * np.cos(y[:, np.newaxis, np.newaxis, np.newaxis])
    +        * np.exp(z[:, np.newaxis, np.newaxis])
    +        * t[:, np.newaxis]
    +        + q,
    +        dims=("x", "y", "z", "t", "q"),
    +        coords={"x": x, "y": y, "z": z, "t": t, "q": q, "label": "dummy_attr"},
    +    )
    +
    +    # Create indexer into `da` with shared coordinate ("full-twist" Möbius strip)
    +    theta = np.linspace(0, 2 * np.pi, 5)
    +    w = np.linspace(-0.25, 0.25, 7)
    +    r = xr.DataArray(
    +        data=1 + w[:, np.newaxis] * np.cos(theta),
    +        coords=[("w", w), ("theta", theta)],
    +    )
    +
    +    x = r * np.cos(theta)
    +    y = r * np.sin(theta)
    +    z = xr.DataArray(
    +        data=w[:, np.newaxis] * np.sin(theta),
    +        coords=[("w", w), ("theta", theta)],
    +    )
    +
    +    kwargs = {"fill_value": None}
    +    expected = da.interp(t=0.5, x=x, y=y, z=z, kwargs=kwargs, method=method)
    +
    +    da = da.chunk(2)
    +    x = x.chunk(1)
    +    z = z.chunk(3)
    +    actual = da.interp(t=0.5, x=x, y=y, z=z, kwargs=kwargs, method=method)
    +    assert_identical(actual, expected)
    diff --git a/xarray/tests/test_merge.py b/xarray/tests/test_merge.py
    index 9057575b38c..34b138e1f6a 100644
    --- a/xarray/tests/test_merge.py
    +++ b/xarray/tests/test_merge.py
    @@ -4,7 +4,7 @@
     import xarray as xr
     from xarray.core import dtypes, merge
     from xarray.core.merge import MergeError
    -from xarray.testing import assert_identical
    +from xarray.testing import assert_equal, assert_identical
     
     from . import raises_regex
     from .test_dataset import create_test_data
    @@ -33,17 +33,17 @@ def test_merge_arrays(self):
             data = create_test_data()
             actual = xr.merge([data.var1, data.var2])
             expected = data[["var1", "var2"]]
    -        assert actual.identical(expected)
    +        assert_identical(actual, expected)
     
         def test_merge_datasets(self):
             data = create_test_data()
     
             actual = xr.merge([data[["var1"]], data[["var2"]]])
             expected = data[["var1", "var2"]]
    -        assert actual.identical(expected)
    +        assert_identical(actual, expected)
     
             actual = xr.merge([data, data])
    -        assert actual.identical(data)
    +        assert_identical(actual, data)
     
         def test_merge_dataarray_unnamed(self):
             data = xr.DataArray([1, 2], dims="x")
    @@ -61,10 +61,10 @@ def test_merge_arrays_attrs_default(self):
             actual = xr.merge([data.var1, data.var2])
             expected = data[["var1", "var2"]]
             expected.attrs = expected_attrs
    -        assert actual.identical(expected)
    +        assert_identical(actual, expected)
     
         @pytest.mark.parametrize(
    -        "combine_attrs, var1_attrs, var2_attrs, expected_attrs, " "expect_exception",
    +        "combine_attrs, var1_attrs, var2_attrs, expected_attrs, expect_exception",
             [
                 (
                     "no_conflicts",
    @@ -107,17 +107,24 @@ def test_merge_arrays_attrs(
                 actual = xr.merge([data.var1, data.var2], combine_attrs=combine_attrs)
                 expected = data[["var1", "var2"]]
                 expected.attrs = expected_attrs
    -            assert actual.identical(expected)
    +            assert_identical(actual, expected)
    +
    +    def test_merge_attrs_override_copy(self):
    +        ds1 = xr.Dataset(attrs={"x": 0})
    +        ds2 = xr.Dataset(attrs={"x": 1})
    +        ds3 = xr.merge([ds1, ds2], combine_attrs="override")
    +        ds3.attrs["x"] = 2
    +        assert ds1.x == 0
     
         def test_merge_dicts_simple(self):
             actual = xr.merge([{"foo": 0}, {"bar": "one"}, {"baz": 3.5}])
             expected = xr.Dataset({"foo": 0, "bar": "one", "baz": 3.5})
    -        assert actual.identical(expected)
    +        assert_identical(actual, expected)
     
         def test_merge_dicts_dims(self):
             actual = xr.merge([{"y": ("x", [13])}, {"x": [12]}])
             expected = xr.Dataset({"x": [12], "y": ("x", [13])})
    -        assert actual.identical(expected)
    +        assert_identical(actual, expected)
     
         def test_merge_error(self):
             ds = xr.Dataset({"x": 0})
    @@ -167,7 +174,7 @@ def test_merge_no_conflicts_multi_var(self):
     
             expected = data[["var1", "var2"]]
             actual = xr.merge([data1.var1, data2.var2], compat="no_conflicts")
    -        assert expected.identical(actual)
    +        assert_identical(expected, actual)
     
             data1["var1"][:, :5] = np.nan
             data2["var1"][:, 5:] = np.nan
    @@ -176,22 +183,22 @@ def test_merge_no_conflicts_multi_var(self):
             del data2["var3"]
     
             actual = xr.merge([data1, data2], compat="no_conflicts")
    -        assert data.equals(actual)
    +        assert_equal(data, actual)
     
         def test_merge_no_conflicts_preserve_attrs(self):
             data = xr.Dataset({"x": ([], 0, {"foo": "bar"})})
             actual = xr.merge([data, data])
    -        assert data.identical(actual)
    +        assert_identical(data, actual)
     
         def test_merge_no_conflicts_broadcast(self):
             datasets = [xr.Dataset({"x": ("y", [0])}), xr.Dataset({"x": np.nan})]
             actual = xr.merge(datasets)
             expected = xr.Dataset({"x": ("y", [0])})
    -        assert expected.identical(actual)
    +        assert_identical(expected, actual)
     
             datasets = [xr.Dataset({"x": ("y", [np.nan])}), xr.Dataset({"x": 0})]
             actual = xr.merge(datasets)
    -        assert expected.identical(actual)
    +        assert_identical(expected, actual)
     
     
     class TestMergeMethod:
    @@ -201,17 +208,17 @@ def test_merge(self):
             ds2 = data[["var3"]]
             expected = data[["var1", "var3"]]
             actual = ds1.merge(ds2)
    -        assert expected.identical(actual)
    +        assert_identical(expected, actual)
     
             actual = ds2.merge(ds1)
    -        assert expected.identical(actual)
    +        assert_identical(expected, actual)
     
             actual = data.merge(data)
    -        assert data.identical(actual)
    +        assert_identical(data, actual)
             actual = data.reset_coords(drop=True).merge(data)
    -        assert data.identical(actual)
    +        assert_identical(data, actual)
             actual = data.merge(data.reset_coords(drop=True))
    -        assert data.identical(actual)
    +        assert_identical(data, actual)
     
             with pytest.raises(ValueError):
                 ds1.merge(ds2.rename({"var3": "var1"}))
    @@ -224,19 +231,19 @@ def test_merge_broadcast_equals(self):
             ds1 = xr.Dataset({"x": 0})
             ds2 = xr.Dataset({"x": ("y", [0, 0])})
             actual = ds1.merge(ds2)
    -        assert ds2.identical(actual)
    +        assert_identical(ds2, actual)
     
             actual = ds2.merge(ds1)
    -        assert ds2.identical(actual)
    +        assert_identical(ds2, actual)
     
             actual = ds1.copy()
             actual.update(ds2)
    -        assert ds2.identical(actual)
    +        assert_identical(ds2, actual)
     
             ds1 = xr.Dataset({"x": np.nan})
             ds2 = xr.Dataset({"x": ("y", [np.nan, np.nan])})
             actual = ds1.merge(ds2)
    -        assert ds2.identical(actual)
    +        assert_identical(ds2, actual)
     
         def test_merge_compat(self):
             ds1 = xr.Dataset({"x": 0})
    @@ -276,16 +283,22 @@ def test_merge_auto_align(self):
             assert expected.identical(ds1.merge(ds2, join="inner"))
             assert expected.identical(ds2.merge(ds1, join="inner"))
     
    -    @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0])
    +    @pytest.mark.parametrize("fill_value", [dtypes.NA, 2, 2.0, {"a": 2, "b": 1}])
         def test_merge_fill_value(self, fill_value):
             ds1 = xr.Dataset({"a": ("x", [1, 2]), "x": [0, 1]})
             ds2 = xr.Dataset({"b": ("x", [3, 4]), "x": [1, 2]})
             if fill_value == dtypes.NA:
                 # if we supply the default, we expect the missing value for a
                 # float array
    -            fill_value = np.nan
    +            fill_value_a = fill_value_b = np.nan
    +        elif isinstance(fill_value, dict):
    +            fill_value_a = fill_value["a"]
    +            fill_value_b = fill_value["b"]
    +        else:
    +            fill_value_a = fill_value_b = fill_value
    +
             expected = xr.Dataset(
    -            {"a": ("x", [1, 2, fill_value]), "b": ("x", [fill_value, 3, 4])},
    +            {"a": ("x", [1, 2, fill_value_a]), "b": ("x", [fill_value_b, 3, 4])},
                 {"x": [0, 1, 2]},
             )
             assert expected.identical(ds1.merge(ds2, fill_value=fill_value))
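
    The new dict case above exercises per-variable fill values during merge
    alignment. A minimal standalone sketch of that behaviour, mirroring the
    expectation in the test (assumes an xarray version that includes this
    change):

    ```python
    import xarray as xr

    ds1 = xr.Dataset({"a": ("x", [1, 2]), "x": [0, 1]})
    ds2 = xr.Dataset({"b": ("x", [3, 4]), "x": [1, 2]})

    # outer join on x=[0, 1, 2]; each variable gets its own fill value
    merged = ds1.merge(ds2, fill_value={"a": 2, "b": 1})
    print(merged["a"].values)  # [1 2 2]
    print(merged["b"].values)  # [1 3 4]
    ```
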
    diff --git a/xarray/tests/test_missing.py b/xarray/tests/test_missing.py
    index 731cd165244..2ab3508b667 100644
    --- a/xarray/tests/test_missing.py
    +++ b/xarray/tests/test_missing.py
    @@ -365,10 +365,25 @@ def test_interpolate_dask():
     def test_interpolate_dask_raises_for_invalid_chunk_dim():
         da, _ = make_interpolate_example_data((40, 40), 0.5)
         da = da.chunk({"time": 5})
    -    with raises_regex(ValueError, "dask='parallelized' consists of multiple"):
    +    # this checks for ValueError in dask.array.apply_gufunc
    +    with raises_regex(ValueError, "consists of multiple chunks"):
             da.interpolate_na("time")
     
     
    +@requires_dask
    +@requires_scipy
    +@pytest.mark.parametrize("dtype, method", [(int, "linear"), (int, "nearest")])
    +def test_interpolate_dask_expected_dtype(dtype, method):
    +    da = xr.DataArray(
    +        data=np.array([0, 1], dtype=dtype),
    +        dims=["time"],
    +        coords=dict(time=np.array([0, 1])),
    +    ).chunk(dict(time=2))
    +    da = da.interp(time=np.array([0, 0.5, 1, 2]), method=method)
    +
    +    assert da.dtype == da.compute().dtype
    +
    +
     @requires_bottleneck
     def test_ffill():
         da = xr.DataArray(np.array([4, 5, np.nan], dtype=np.float64), dims="x")
    @@ -534,6 +549,18 @@ def test_get_clean_interp_index_potential_overflow():
         get_clean_interp_index(da, "time")
     
     
    +@pytest.mark.parametrize("index", ([0, 2, 1], [0, 1, 1]))
    +def test_get_clean_interp_index_strict(index):
    +    da = xr.DataArray([0, 1, 2], dims=("x",), coords={"x": index})
    +
    +    with pytest.raises(ValueError):
    +        get_clean_interp_index(da, "x")
    +
    +    clean = get_clean_interp_index(da, "x", strict=False)
    +    np.testing.assert_array_equal(index, clean)
    +    assert clean.dtype == np.float64
    +
    +
     @pytest.fixture
     def da_time():
         return xr.DataArray(
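
    The new `strict` keyword covered by `test_get_clean_interp_index_strict`
    can be sketched as follows; `get_clean_interp_index` lives in the internal
    `xarray.core.missing` module, so treat this as illustrative rather than
    public API:

    ```python
    import xarray as xr
    from xarray.core.missing import get_clean_interp_index

    da = xr.DataArray([0, 1, 2], dims=("x",), coords={"x": [0, 2, 1]})

    # strict=True (the default) rejects non-monotonic or non-unique indexes
    try:
        get_clean_interp_index(da, "x")
    except ValueError as err:
        print("rejected:", err)

    # strict=False skips those checks and returns the index as float64
    index = get_clean_interp_index(da, "x", strict=False)
    print(index, index.dtype)
    ```
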
    diff --git a/xarray/tests/test_nputils.py b/xarray/tests/test_nputils.py
    index 1002a9dd9e3..ccb825dc7e9 100644
    --- a/xarray/tests/test_nputils.py
    +++ b/xarray/tests/test_nputils.py
    @@ -1,4 +1,5 @@
     import numpy as np
    +import pytest
     from numpy.testing import assert_array_equal
     
     from xarray.core.nputils import NumpyVIndexAdapter, _is_contiguous, rolling_window
    @@ -47,3 +48,19 @@ def test_rolling():
         actual = rolling_window(x, axis=-1, window=3, center=False, fill_value=0.0)
         expected = np.stack([expected, expected * 1.1], axis=0)
         assert_array_equal(actual, expected)
    +
    +
    +@pytest.mark.parametrize("center", [[True, True], [False, False]])
    +@pytest.mark.parametrize("axis", [(0, 1), (1, 2), (2, 0)])
    +def test_nd_rolling(center, axis):
    +    x = np.arange(7 * 6 * 8).reshape(7, 6, 8).astype(float)
    +    window = [3, 3]
    +    actual = rolling_window(
    +        x, axis=axis, window=window, center=center, fill_value=np.nan
    +    )
    +    expected = x
    +    for ax, win, cent in zip(axis, window, center):
    +        expected = rolling_window(
    +            expected, axis=ax, window=win, center=cent, fill_value=np.nan
    +        )
    +    assert_array_equal(actual, expected)
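
    The new `test_nd_rolling` checks that rolling over several axes at once is
    equivalent to composing single-axis calls. A minimal sketch of the
    multi-axis form being tested (`rolling_window` is an internal helper, so
    this is illustrative only):

    ```python
    import numpy as np

    from xarray.core.nputils import rolling_window

    x = np.arange(4 * 5, dtype=float).reshape(4, 5)

    # roll over both axes at once; the window dimensions are appended at the end
    windows = rolling_window(
        x, axis=(0, 1), window=[3, 3], center=[True, True], fill_value=np.nan
    )
    print(windows.shape)  # expected (4, 5, 3, 3)
    ```
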
    diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py
    index bf1f9ed60bb..47b15446f1d 100644
    --- a/xarray/tests/test_plot.py
    +++ b/xarray/tests/test_plot.py
    @@ -1,5 +1,6 @@
    +import contextlib
     import inspect
    -from copy import deepcopy
    +from copy import copy
     from datetime import datetime
     
     import numpy as np
    @@ -15,6 +16,7 @@
         _build_discrete_cmap,
         _color_palette,
         _determine_cmap_params,
    +    get_axis,
         label_from_attrs,
     )
     
    @@ -23,6 +25,7 @@
         assert_equal,
         has_nc_time_axis,
         raises_regex,
    +    requires_cartopy,
         requires_cftime,
         requires_matplotlib,
         requires_nc_time_axis,
    @@ -36,6 +39,39 @@
     except ImportError:
         pass
     
    +try:
    +    import cartopy as ctpy  # type: ignore
    +except ImportError:
    +    ctpy = None
    +
    +
    +@contextlib.contextmanager
    +def figure_context(*args, **kwargs):
    +    """context manager which autocloses a figure (even if the test failed)"""
    +
    +    try:
    +        yield None
    +    finally:
    +        plt.close("all")
    +
    +
    +@pytest.fixture(scope="function", autouse=True)
    +def test_all_figures_closed():
    +    """meta-test to ensure all figures are closed at the end of a test
    +
    +    Notes: this fixture is function-scoped and autouse, so it runs after
    +    every test. Disadvantage: it only reports figures that are still open
    +    when a test finishes, and it may give a false positive if tests are
    +    run in parallel (locally).
    +    """
    +    yield None
    +
    +    open_figs = len(plt.get_fignums())
    +    if open_figs:
    +        raise RuntimeError(
    +            f"tests did not close all figures ({open_figs} figures open)"
    +        )
    +
     
     @pytest.mark.flaky
     @pytest.mark.skip(reason="maybe flaky")
    @@ -81,6 +117,13 @@ def easy_array(shape, start=0, stop=1):
         return a.reshape(shape)
     
     
    +def get_colorbar_label(colorbar):
    +    if colorbar.orientation == "vertical":
    +        return colorbar.ax.get_ylabel()
    +    else:
    +        return colorbar.ax.get_xlabel()
    +
    +
     @requires_matplotlib
     class PlotTestCase:
         @pytest.fixture(autouse=True)
    @@ -111,6 +154,12 @@ class TestPlot(PlotTestCase):
         def setup_array(self):
             self.darray = DataArray(easy_array((2, 3, 4)))
     
    +    def test_accessor(self):
    +        from ..plot.plot import _PlotMethods
    +
    +        assert DataArray.plot is _PlotMethods
    +        assert isinstance(self.darray.plot, _PlotMethods)
    +
         def test_label_from_attrs(self):
             da = self.darray.copy()
             assert "" == label_from_attrs(da)
    @@ -136,14 +185,14 @@ def test_label_from_attrs(self):
         def test1d(self):
             self.darray[:, 0, 0].plot()
     
    -        with raises_regex(ValueError, "None"):
    +        with raises_regex(ValueError, "x must be one of None, 'dim_0'"):
                 self.darray[:, 0, 0].plot(x="dim_1")
     
             with raises_regex(TypeError, "complex128"):
                 (self.darray[:, 0, 0] + 1j).plot()
     
         def test_1d_bool(self):
    -        xr.ones_like(self.darray[:, 0, 0], dtype=np.bool).plot()
    +        xr.ones_like(self.darray[:, 0, 0], dtype=bool).plot()
     
         def test_1d_x_y_kw(self):
             z = np.arange(10)
    @@ -155,14 +204,31 @@ def test_1d_x_y_kw(self):
             for aa, (x, y) in enumerate(xy):
                 da.plot(x=x, y=y, ax=ax.flat[aa])
     
    -        with raises_regex(ValueError, "cannot"):
    +        with raises_regex(ValueError, "Cannot specify both"):
                 da.plot(x="z", y="z")
     
    -        with raises_regex(ValueError, "None"):
    -            da.plot(x="f", y="z")
    +        error_msg = "must be one of None, 'z'"
    +        with raises_regex(ValueError, f"x {error_msg}"):
    +            da.plot(x="f")
     
    -        with raises_regex(ValueError, "None"):
    -            da.plot(x="z", y="f")
    +        with raises_regex(ValueError, f"y {error_msg}"):
    +            da.plot(y="f")
    +
    +    def test_multiindex_level_as_coord(self):
    +        da = xr.DataArray(
    +            np.arange(5),
    +            dims="x",
    +            coords=dict(a=("x", np.arange(5)), b=("x", np.arange(5, 10))),
    +        )
    +        da = da.set_index(x=["a", "b"])
    +
    +        for x in ["a", "b"]:
    +            h = da.plot(x=x)[0]
    +            assert_array_equal(h.get_xdata(), da[x].values)
    +
    +        for y in ["a", "b"]:
    +            h = da.plot(y=y)[0]
    +            assert_array_equal(h.get_ydata(), da[y].values)
     
         # Test for bug in GH issue #2725
         def test_infer_line_data(self):
    @@ -211,7 +277,7 @@ def test_2d_line(self):
             self.darray[:, :, 0].plot.line(x="dim_0", hue="dim_1")
             self.darray[:, :, 0].plot.line(y="dim_0", hue="dim_1")
     
    -        with raises_regex(ValueError, "cannot"):
    +        with raises_regex(ValueError, "Cannot"):
                 self.darray[:, :, 0].plot.line(x="dim_1", y="dim_0", hue="dim_1")
     
         def test_2d_line_accepts_legend_kw(self):
    @@ -247,12 +313,13 @@ def test_2d_coords_line_plot(self):
                 coords={"lat": (("y", "x"), lat), "lon": (("y", "x"), lon)},
             )
     
    -        hdl = da.plot.line(x="lon", hue="x")
    -        assert len(hdl) == 5
    +        with figure_context():
    +            hdl = da.plot.line(x="lon", hue="x")
    +            assert len(hdl) == 5
     
    -        plt.clf()
    -        hdl = da.plot.line(x="lon", hue="y")
    -        assert len(hdl) == 4
    +        with figure_context():
    +            hdl = da.plot.line(x="lon", hue="y")
    +            assert len(hdl) == 4
     
             with pytest.raises(ValueError, match="For 2D inputs, hue must be a dimension"):
                 da.plot.line(x="lon", hue="lat")
    @@ -298,7 +365,7 @@ def test_contourf_cmap_set(self):
     
             cmap = mpl.cm.viridis
     
    -        # deepcopy to ensure cmap is not changed by contourf()
    +        # use copy to ensure cmap is not changed by contourf()
             # Set vmin and vmax so that _build_discrete_colormap is called with
             # extend='both'. extend is passed to
             # mpl.colors.from_levels_and_colors(), which returns a result with
    @@ -306,12 +373,12 @@ def test_contourf_cmap_set(self):
             # extend='neither' (but if extend='neither' the under and over values
             # would not be used because the data would all be within the plotted
             # range)
    -        pl = a.plot.contourf(cmap=deepcopy(cmap), vmin=0.1, vmax=0.9)
    +        pl = a.plot.contourf(cmap=copy(cmap), vmin=0.1, vmax=0.9)
     
             # check the set_bad color
    -        assert np.all(
    -            pl.cmap(np.ma.masked_invalid([np.nan]))[0]
    -            == cmap(np.ma.masked_invalid([np.nan]))[0]
    +        assert_array_equal(
    +            pl.cmap(np.ma.masked_invalid([np.nan]))[0],
    +            cmap(np.ma.masked_invalid([np.nan]))[0],
             )
     
             # check the set_under color
    @@ -323,10 +390,8 @@ def test_contourf_cmap_set(self):
         def test_contourf_cmap_set_with_bad_under_over(self):
             a = DataArray(easy_array((4, 4)), dims=["z", "time"])
     
    -        # Make a copy here because we want a local cmap that we will modify.
    -        # Use deepcopy because matplotlib Colormap objects have tuple members
    -        # and we want to ensure we do not change the original.
    -        cmap = deepcopy(mpl.cm.viridis)
    +        # make a copy here because we want a local cmap that we will modify.
    +        cmap = copy(mpl.cm.viridis)
     
             cmap.set_bad("w")
             # check we actually changed the set_bad color
    @@ -343,13 +408,13 @@ def test_contourf_cmap_set_with_bad_under_over(self):
             # check we actually changed the set_over color
             assert cmap(np.inf) != mpl.cm.viridis(-np.inf)
     
    -        # deepcopy to ensure cmap is not changed by contourf()
    -        pl = a.plot.contourf(cmap=deepcopy(cmap))
    +        # copy to ensure cmap is not changed by contourf()
    +        pl = a.plot.contourf(cmap=copy(cmap))
     
             # check the set_bad color has been kept
    -        assert np.all(
    -            pl.cmap(np.ma.masked_invalid([np.nan]))[0]
    -            == cmap(np.ma.masked_invalid([np.nan]))[0]
    +        assert_array_equal(
    +            pl.cmap(np.ma.masked_invalid([np.nan]))[0],
    +            cmap(np.ma.masked_invalid([np.nan]))[0],
             )
     
             # check the set_under color has been kept
    @@ -527,6 +592,20 @@ def test_coord_with_interval_xy(self):
             bins = [-1, 0, 1, 2]
             self.darray.groupby_bins("dim_0", bins).mean(...).dim_0_bins.plot()
     
    +    @pytest.mark.parametrize("dim", ("x", "y"))
    +    def test_labels_with_units_with_interval(self, dim):
    +        """Test line plot with intervals and a units attribute."""
    +        bins = [-1, 0, 1, 2]
    +        arr = self.darray.groupby_bins("dim_0", bins).mean(...)
    +        arr.dim_0_bins.attrs["units"] = "m"
    +
    +        (mappable,) = arr.plot(**{dim: "dim_0_bins"})
    +        ax = mappable.figure.gca()
    +        actual = getattr(ax, f"get_{dim}label")()
    +
    +        expected = "dim_0_bins_center [m]"
    +        assert actual == expected
    +
     
     class TestPlot1D(PlotTestCase):
         @pytest.fixture(autouse=True)
    @@ -606,11 +685,13 @@ def setUp(self):
             self.darray = DataArray(easy_array((2, 3, 4)))
     
         def test_step(self):
    -        self.darray[0, 0].plot.step()
    +        hdl = self.darray[0, 0].plot.step()
    +        assert "steps" in hdl[0].get_drawstyle()
     
    -    @pytest.mark.parametrize("ds", ["pre", "post", "mid"])
    -    def test_step_with_drawstyle(self, ds):
    -        self.darray[0, 0].plot.step(drawstyle=ds)
    +    @pytest.mark.parametrize("where", ["pre", "post", "mid"])
    +    def test_step_with_where(self, where):
    +        hdl = self.darray[0, 0].plot.step(where=where)
    +        assert hdl[0].get_drawstyle() == f"steps-{where}"
     
         def test_coord_with_interval_step(self):
             """Test step plot with intervals."""
    @@ -728,18 +809,22 @@ def test_integer_levels(self):
             # default is to cover full data range but with no guarantee on Nlevels
             for level in np.arange(2, 10, dtype=int):
                 cmap_params = _determine_cmap_params(data, levels=level)
    -            assert cmap_params["vmin"] == cmap_params["levels"][0]
    -            assert cmap_params["vmax"] == cmap_params["levels"][-1]
    +            assert cmap_params["vmin"] is None
    +            assert cmap_params["vmax"] is None
    +            assert cmap_params["norm"].vmin == cmap_params["levels"][0]
    +            assert cmap_params["norm"].vmax == cmap_params["levels"][-1]
                 assert cmap_params["extend"] == "neither"
     
             # with min max we are more strict
             cmap_params = _determine_cmap_params(
                 data, levels=5, vmin=0, vmax=5, cmap="Blues"
             )
    -        assert cmap_params["vmin"] == 0
    -        assert cmap_params["vmax"] == 5
    -        assert cmap_params["vmin"] == cmap_params["levels"][0]
    -        assert cmap_params["vmax"] == cmap_params["levels"][-1]
    +        assert cmap_params["vmin"] is None
    +        assert cmap_params["vmax"] is None
    +        assert cmap_params["norm"].vmin == 0
    +        assert cmap_params["norm"].vmax == 5
    +        assert cmap_params["norm"].vmin == cmap_params["levels"][0]
    +        assert cmap_params["norm"].vmax == cmap_params["levels"][-1]
             assert cmap_params["cmap"].name == "Blues"
             assert cmap_params["extend"] == "neither"
             assert cmap_params["cmap"].N == 4
    @@ -763,8 +848,10 @@ def test_list_levels(self):
             orig_levels = [0, 1, 2, 3, 4, 5]
             # vmin and vmax should be ignored if levels are explicitly provided
             cmap_params = _determine_cmap_params(data, levels=orig_levels, vmin=0, vmax=3)
    -        assert cmap_params["vmin"] == 0
    -        assert cmap_params["vmax"] == 5
    +        assert cmap_params["vmin"] is None
    +        assert cmap_params["vmax"] is None
    +        assert cmap_params["norm"].vmin == 0
    +        assert cmap_params["norm"].vmax == 5
             assert cmap_params["cmap"].N == 5
             assert cmap_params["norm"].N == 6
     
    @@ -854,23 +941,26 @@ def test_norm_sets_vmin_vmax(self):
             vmin = self.data.min()
             vmax = self.data.max()
     
    -        for norm, extend in zip(
    +        for norm, extend, levels in zip(
                 [
    +                mpl.colors.Normalize(),
                     mpl.colors.Normalize(),
                     mpl.colors.Normalize(vmin + 0.1, vmax - 0.1),
                     mpl.colors.Normalize(None, vmax - 0.1),
                     mpl.colors.Normalize(vmin + 0.1, None),
                 ],
    -            ["neither", "both", "max", "min"],
    +            ["neither", "neither", "both", "max", "min"],
    +            [7, None, None, None, None],
             ):
     
                 test_min = vmin if norm.vmin is None else norm.vmin
                 test_max = vmax if norm.vmax is None else norm.vmax
     
    -            cmap_params = _determine_cmap_params(self.data, norm=norm)
    -
    -            assert cmap_params["vmin"] == test_min
    -            assert cmap_params["vmax"] == test_max
    +            cmap_params = _determine_cmap_params(self.data, norm=norm, levels=levels)
    +            assert cmap_params["vmin"] is None
    +            assert cmap_params["vmax"] is None
    +            assert cmap_params["norm"].vmin == test_min
    +            assert cmap_params["norm"].vmax == test_max
                 assert cmap_params["extend"] == extend
                 assert cmap_params["norm"] == norm
     
    @@ -886,6 +976,9 @@ def setUp(self):
             self.darray = DataArray(distance, list(zip(("y", "x"), (y, x))))
             self.data_min = distance.min()
             self.data_max = distance.max()
    +        yield
    +        # Remove all matplotlib figures
    +        plt.close("all")
     
         @pytest.mark.slow
         def test_recover_from_seaborn_jet_exception(self):
    @@ -1013,7 +1106,7 @@ def test_1d_raises_valueerror(self):
                 self.plotfunc(self.darray[0, :])
     
         def test_bool(self):
    -        xr.ones_like(self.darray, dtype=np.bool).plot()
    +        xr.ones_like(self.darray, dtype=bool).plot()
     
         def test_complex_raises_typeerror(self):
             with raises_regex(TypeError, "complex128"):
    @@ -1031,6 +1124,16 @@ def test_nonnumeric_index_raises_typeerror(self):
             with raises_regex(TypeError, r"[Pp]lot"):
                 self.plotfunc(a)
     
    +    def test_multiindex_raises_typeerror(self):
    +        a = DataArray(
    +            easy_array((3, 2)),
    +            dims=("x", "y"),
    +            coords=dict(x=("x", [0, 1, 2]), a=("y", [0, 1]), b=("y", [2, 3])),
    +        )
    +        a = a.set_index(y=("a", "b"))
    +        with raises_regex(TypeError, r"[Pp]lot"):
    +            self.plotfunc(a)
    +
         def test_can_pass_in_axis(self):
             self.pass_in_axis(self.plotmethod)
     
    @@ -1139,15 +1242,16 @@ def test_positional_coord_string(self):
             assert "y_long_name [y_units]" == ax.get_ylabel()
     
         def test_bad_x_string_exception(self):
    -        with raises_regex(ValueError, "x and y must be coordinate variables"):
    +
    +        with raises_regex(ValueError, "x and y cannot be equal."):
    +            self.plotmethod(x="y", y="y")
    +
    +        error_msg = "must be one of None, 'x', 'x2d', 'y', 'y2d'"
    +        with raises_regex(ValueError, f"x {error_msg}"):
                 self.plotmethod("not_a_real_dim", "y")
    -        with raises_regex(
    -            ValueError, "x must be a dimension name if y is not supplied"
    -        ):
    +        with raises_regex(ValueError, f"x {error_msg}"):
                 self.plotmethod(x="not_a_real_dim")
    -        with raises_regex(
    -            ValueError, "y must be a dimension name if x is not supplied"
    -        ):
    +        with raises_regex(ValueError, f"y {error_msg}"):
                 self.plotmethod(y="not_a_real_dim")
             self.darray.coords["z"] = 100
     
    @@ -1182,6 +1286,27 @@ def test_non_linked_coords_transpose(self):
             # simply ensure that these high coords were passed over
             assert np.min(ax.get_xlim()) > 100.0
     
    +    def test_multiindex_level_as_coord(self):
    +        da = DataArray(
    +            easy_array((3, 2)),
    +            dims=("x", "y"),
    +            coords=dict(x=("x", [0, 1, 2]), a=("y", [0, 1]), b=("y", [2, 3])),
    +        )
    +        da = da.set_index(y=["a", "b"])
    +
    +        for x, y in (("a", "x"), ("b", "x"), ("x", "a"), ("x", "b")):
    +            self.plotfunc(da, x=x, y=y)
    +
    +            ax = plt.gca()
    +            assert x == ax.get_xlabel()
    +            assert y == ax.get_ylabel()
    +
    +        with raises_regex(ValueError, "levels of the same MultiIndex"):
    +            self.plotfunc(da, x="a", y="b")
    +
    +        with raises_regex(ValueError, "y must be one of None, 'a', 'b', 'x'"):
    +            self.plotfunc(da, x="a", y="y")
    +
         def test_default_title(self):
             a = DataArray(easy_array((4, 3, 2)), dims=["a", "b", "c"])
             a.coords["c"] = [0, 1]
    @@ -1351,7 +1476,7 @@ def test_facetgrid_cbar_kwargs(self):
     
             # catch contour case
             if hasattr(g, "cbar"):
    -            assert g.cbar._label == "test_label"
    +            assert get_colorbar_label(g.cbar) == "test_label"
     
         def test_facetgrid_no_cbar_ax(self):
             a = easy_array((10, 15, 2, 3))
    @@ -1649,14 +1774,15 @@ def test_regression_rgb_imshow_dim_size_one(self):
     
         def test_origin_overrides_xyincrease(self):
             da = DataArray(easy_array((3, 2)), coords=[[-2, 0, 2], [-1, 1]])
    -        da.plot.imshow(origin="upper")
    -        assert plt.xlim()[0] < 0
    -        assert plt.ylim()[1] < 0
    +        with figure_context():
    +            da.plot.imshow(origin="upper")
    +            assert plt.xlim()[0] < 0
    +            assert plt.ylim()[1] < 0
     
    -        plt.clf()
    -        da.plot.imshow(origin="lower")
    -        assert plt.xlim()[0] < 0
    -        assert plt.ylim()[0] < 0
    +        with figure_context():
    +            da.plot.imshow(origin="lower")
    +            assert plt.xlim()[0] < 0
    +            assert plt.ylim()[0] < 0
     
     
     class TestFacetGrid(PlotTestCase):
    @@ -2048,6 +2174,12 @@ def setUp(self):
             ds.B.attrs["units"] = "Bunits"
             self.ds = ds
     
    +    def test_accessor(self):
    +        from ..plot.dataset_plot import _Dataset_PlotMethods
    +
    +        assert Dataset.plot is _Dataset_PlotMethods
    +        assert isinstance(self.ds.plot, _Dataset_PlotMethods)
    +
         @pytest.mark.parametrize(
             "add_guide, hue_style, legend, colorbar",
             [
    @@ -2122,7 +2254,7 @@ def test_datetime_hue(self, hue_style):
             ds2.plot.scatter(x="A", y="B", hue="hue", hue_style=hue_style)
     
         def test_facetgrid_hue_style(self):
    -        # Can't move this to pytest.mark.parametrize because py36-bare-minimum
    +        # Can't move this to pytest.mark.parametrize because py37-bare-minimum
             # doesn't have matplotlib.
             for hue_style, map_type in (
                 ("discrete", list),
    @@ -2151,6 +2283,24 @@ def test_non_numeric_legend(self):
             with pytest.raises(ValueError):
                 ds2.plot.scatter(x="A", y="B", hue="hue", hue_style="continuous")
     
    +    def test_legend_labels(self):
    +        # regression test for #4126: incorrect legend labels
    +        ds2 = self.ds.copy()
    +        ds2["hue"] = ["a", "a", "b", "b"]
    +        lines = ds2.plot.scatter(x="A", y="B", hue="hue")
    +        assert [t.get_text() for t in lines[0].axes.get_legend().texts] == ["a", "b"]
    +
    +    def test_legend_labels_facetgrid(self):
    +        ds2 = self.ds.copy()
    +        ds2["hue"] = ["d", "a", "c", "b"]
    +        g = ds2.plot.scatter(x="A", y="B", hue="hue", col="col")
    +        legend_labels = tuple(t.get_text() for t in g.figlegend.texts)
    +        attached_labels = [
    +            tuple(m.get_label() for m in mappables_per_ax)
    +            for mappables_per_ax in g._mappables
    +        ]
    +        assert list(set(attached_labels)) == [legend_labels]
    +
         def test_add_legend_by_default(self):
             sc = self.ds.plot.scatter(x="A", y="B", hue="hue")
             assert len(sc.figure.axes) == 2
    @@ -2175,6 +2325,7 @@ def test_datetime_line_plot(self):
             self.darray.plot.line()
     
     
    +@pytest.mark.filterwarnings("ignore:setting an array element with a sequence")
     @requires_nc_time_axis
     @requires_cftime
     class TestCFDatetimePlot(PlotTestCase):
    @@ -2237,60 +2388,60 @@ class TestAxesKwargs:
         @pytest.mark.parametrize("da", test_da_list)
         @pytest.mark.parametrize("xincrease", [True, False])
         def test_xincrease_kwarg(self, da, xincrease):
    -        plt.clf()
    -        da.plot(xincrease=xincrease)
    -        assert plt.gca().xaxis_inverted() == (not xincrease)
    +        with figure_context():
    +            da.plot(xincrease=xincrease)
    +            assert plt.gca().xaxis_inverted() == (not xincrease)
     
         @pytest.mark.parametrize("da", test_da_list)
         @pytest.mark.parametrize("yincrease", [True, False])
         def test_yincrease_kwarg(self, da, yincrease):
    -        plt.clf()
    -        da.plot(yincrease=yincrease)
    -        assert plt.gca().yaxis_inverted() == (not yincrease)
    +        with figure_context():
    +            da.plot(yincrease=yincrease)
    +            assert plt.gca().yaxis_inverted() == (not yincrease)
     
         @pytest.mark.parametrize("da", test_da_list)
         @pytest.mark.parametrize("xscale", ["linear", "log", "logit", "symlog"])
         def test_xscale_kwarg(self, da, xscale):
    -        plt.clf()
    -        da.plot(xscale=xscale)
    -        assert plt.gca().get_xscale() == xscale
    +        with figure_context():
    +            da.plot(xscale=xscale)
    +            assert plt.gca().get_xscale() == xscale
     
         @pytest.mark.parametrize(
             "da", [DataArray(easy_array((10,))), DataArray(easy_array((10, 3)))]
         )
         @pytest.mark.parametrize("yscale", ["linear", "log", "logit", "symlog"])
         def test_yscale_kwarg(self, da, yscale):
    -        plt.clf()
    -        da.plot(yscale=yscale)
    -        assert plt.gca().get_yscale() == yscale
    +        with figure_context():
    +            da.plot(yscale=yscale)
    +            assert plt.gca().get_yscale() == yscale
     
         @pytest.mark.parametrize("da", test_da_list)
         def test_xlim_kwarg(self, da):
    -        plt.clf()
    -        expected = (0.0, 1000.0)
    -        da.plot(xlim=[0, 1000])
    -        assert plt.gca().get_xlim() == expected
    +        with figure_context():
    +            expected = (0.0, 1000.0)
    +            da.plot(xlim=[0, 1000])
    +            assert plt.gca().get_xlim() == expected
     
         @pytest.mark.parametrize("da", test_da_list)
         def test_ylim_kwarg(self, da):
    -        plt.clf()
    -        da.plot(ylim=[0, 1000])
    -        expected = (0.0, 1000.0)
    -        assert plt.gca().get_ylim() == expected
    +        with figure_context():
    +            da.plot(ylim=[0, 1000])
    +            expected = (0.0, 1000.0)
    +            assert plt.gca().get_ylim() == expected
     
         @pytest.mark.parametrize("da", test_da_list)
         def test_xticks_kwarg(self, da):
    -        plt.clf()
    -        da.plot(xticks=np.arange(5))
    -        expected = np.arange(5).tolist()
    -        assert np.all(plt.gca().get_xticks() == expected)
    +        with figure_context():
    +            da.plot(xticks=np.arange(5))
    +            expected = np.arange(5).tolist()
    +            assert_array_equal(plt.gca().get_xticks(), expected)
     
         @pytest.mark.parametrize("da", test_da_list)
         def test_yticks_kwarg(self, da):
    -        plt.clf()
    -        da.plot(yticks=np.arange(5))
    -        expected = np.arange(5)
    -        assert np.all(plt.gca().get_yticks() == expected)
    +        with figure_context():
    +            da.plot(yticks=np.arange(5))
    +            expected = np.arange(5)
    +            assert_array_equal(plt.gca().get_yticks(), expected)
     
     
     @requires_matplotlib
    @@ -2305,8 +2456,10 @@ def test_plot_transposed_nondim_coord(plotfunc):
             dims=["s", "x"],
             coords={"x": x, "s": s, "z": (("s", "x"), z), "zt": (("x", "s"), z.T)},
         )
    -    getattr(da.plot, plotfunc)(x="x", y="zt")
    -    getattr(da.plot, plotfunc)(x="zt", y="x")
    +    with figure_context():
    +        getattr(da.plot, plotfunc)(x="x", y="zt")
    +    with figure_context():
    +        getattr(da.plot, plotfunc)(x="zt", y="x")
     
     
     @requires_matplotlib
    @@ -2314,11 +2467,12 @@ def test_plot_transposed_nondim_coord(plotfunc):
     def test_plot_transposes_properly(plotfunc):
         # test that we aren't mistakenly transposing when the 2 dimensions have equal sizes.
         da = xr.DataArray([np.sin(2 * np.pi / 10 * np.arange(10))] * 10, dims=("y", "x"))
    -    hdl = getattr(da.plot, plotfunc)(x="x", y="y")
    -    # get_array doesn't work for contour, contourf. It returns the colormap intervals.
    -    # pcolormesh returns 1D array but imshow returns a 2D array so it is necessary
    -    # to ravel() on the LHS
    -    assert np.all(hdl.get_array().ravel() == da.to_masked_array().ravel())
    +    with figure_context():
    +        hdl = getattr(da.plot, plotfunc)(x="x", y="y")
    +        # get_array doesn't work for contour, contourf. It returns the colormap intervals.
    +        # pcolormesh returns 1D array but imshow returns a 2D array so it is necessary
    +        # to ravel() on the LHS
    +        assert_array_equal(hdl.get_array().ravel(), da.to_masked_array().ravel())
     
     
     @requires_matplotlib
    @@ -2330,4 +2484,40 @@ def test_facetgrid_single_contour():
         ds = xr.concat([z, z2], dim="time")
         ds["time"] = [0, 1]
     
    -    ds.plot.contour(col="time", levels=[4], colors=["k"])
    +    with figure_context():
    +        ds.plot.contour(col="time", levels=[4], colors=["k"])
    +
    +
    +@requires_matplotlib
    +def test_get_axis():
    +    # test get_axis works with different args combinations
    +    # and return the right type
    +
    +    # cannot provide both ax and figsize
    +    with pytest.raises(ValueError, match="both `figsize` and `ax`"):
    +        get_axis(figsize=[4, 4], size=None, aspect=None, ax="something")
    +
    +    # cannot provide both ax and size
    +    with pytest.raises(ValueError, match="both `size` and `ax`"):
    +        get_axis(figsize=None, size=200, aspect=4 / 3, ax="something")
    +
    +    # cannot provide both size and figsize
    +    with pytest.raises(ValueError, match="both `figsize` and `size`"):
    +        get_axis(figsize=[4, 4], size=200, aspect=None, ax=None)
    +
    +    # cannot provide aspect and size
    +    with pytest.raises(ValueError, match="`aspect` argument without `size`"):
    +        get_axis(figsize=None, size=None, aspect=4 / 3, ax=None)
    +
    +    with figure_context():
    +        ax = get_axis()
    +        assert isinstance(ax, mpl.axes.Axes)
    +
    +
    +@requires_cartopy
    +def test_get_axis_cartopy():
    +
    +    kwargs = {"projection": ctpy.crs.PlateCarree()}
    +    with figure_context():
    +        ax = get_axis(**kwargs)
    +        assert isinstance(ax, ctpy.mpl.geoaxes.GeoAxesSubplot)
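
    Many of the hunks above wrap plotting calls in the new `figure_context`
    helper so figures get closed even when an assertion fails. A self-contained
    sketch of that pattern outside the test suite (the helper is redefined here
    purely for illustration):

    ```python
    import contextlib

    import matplotlib

    matplotlib.use("Agg")  # headless backend, e.g. for CI
    import matplotlib.pyplot as plt


    @contextlib.contextmanager
    def figure_context():
        """Close all figures when the block exits, even on failure."""
        try:
            yield
        finally:
            plt.close("all")


    with figure_context():
        (line,) = plt.plot([0, 1, 2])
        assert line.get_xdata().shape == (3,)

    assert not plt.get_fignums()  # nothing leaked out of the block
    ```
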
    diff --git a/xarray/tests/test_plugins.py b/xarray/tests/test_plugins.py
    new file mode 100644
    index 00000000000..110ef47209f
    --- /dev/null
    +++ b/xarray/tests/test_plugins.py
    @@ -0,0 +1,111 @@
    +from unittest import mock
    +
    +import pkg_resources
    +import pytest
    +
    +from xarray.backends import common, plugins
    +
    +
    +def dummy_open_dataset_args(filename_or_obj, *args):
    +    pass
    +
    +
    +def dummy_open_dataset_kwargs(filename_or_obj, **kwargs):
    +    pass
    +
    +
    +def dummy_open_dataset(filename_or_obj, *, decoder):
    +    pass
    +
    +
    +dummy_cfgrib = common.BackendEntrypoint(dummy_open_dataset)
    +
    +
    +@pytest.fixture
    +def dummy_duplicated_entrypoints():
    +    specs = [
    +        "engine1 = xarray.tests.test_plugins:backend_1",
    +        "engine1 = xarray.tests.test_plugins:backend_2",
    +        "engine2 = xarray.tests.test_plugins:backend_1",
    +        "engine2 = xarray.tests.test_plugins:backend_2",
    +    ]
    +    eps = [pkg_resources.EntryPoint.parse(spec) for spec in specs]
    +    return eps
    +
    +
    +@pytest.mark.filterwarnings("ignore:Found")
    +def test_remove_duplicates(dummy_duplicated_entrypoints):
    +    with pytest.warns(RuntimeWarning):
    +        entrypoints = plugins.remove_duplicates(dummy_duplicated_entrypoints)
    +    assert len(entrypoints) == 2
    +
    +
    +def test_remove_duplicates_warnings(dummy_duplicated_entrypoints):
    +
    +    with pytest.warns(RuntimeWarning) as record:
    +        _ = plugins.remove_duplicates(dummy_duplicated_entrypoints)
    +
    +    assert len(record) == 2
    +    message0 = str(record[0].message)
    +    message1 = str(record[1].message)
    +    assert "entrypoints" in message0
    +    assert "entrypoints" in message1
    +
    +
    +@mock.patch("pkg_resources.EntryPoint.load", mock.MagicMock(return_value=None))
    +def test_create_engines_dict():
    +    specs = [
    +        "engine1 = xarray.tests.test_plugins:backend_1",
    +        "engine2 = xarray.tests.test_plugins:backend_2",
    +    ]
    +    entrypoints = [pkg_resources.EntryPoint.parse(spec) for spec in specs]
    +    engines = plugins.create_engines_dict(entrypoints)
    +    assert len(engines) == 2
    +    assert engines.keys() == set(("engine1", "engine2"))
    +
    +
    +def test_set_missing_parameters():
    +    backend_1 = common.BackendEntrypoint(dummy_open_dataset)
    +    backend_2 = common.BackendEntrypoint(dummy_open_dataset, ("filename_or_obj",))
    +    engines = {"engine_1": backend_1, "engine_2": backend_2}
    +    plugins.set_missing_parameters(engines)
    +
    +    assert len(engines) == 2
    +    engine_1 = engines["engine_1"]
    +    assert engine_1.open_dataset_parameters == ("filename_or_obj", "decoder")
    +    engine_2 = engines["engine_2"]
    +    assert engine_2.open_dataset_parameters == ("filename_or_obj",)
    +
    +
    +def test_set_missing_parameters_raise_error():
    +
    +    backend = common.BackendEntrypoint(dummy_open_dataset_args)
    +    with pytest.raises(TypeError):
    +        plugins.set_missing_parameters({"engine": backend})
    +
    +    backend = common.BackendEntrypoint(
    +        dummy_open_dataset_args, ("filename_or_obj", "decoder")
    +    )
    +    plugins.set_missing_parameters({"engine": backend})
    +
    +    backend = common.BackendEntrypoint(dummy_open_dataset_kwargs)
    +    with pytest.raises(TypeError):
    +        plugins.set_missing_parameters({"engine": backend})
    +
    +    backend = plugins.BackendEntrypoint(
    +        dummy_open_dataset_kwargs, ("filename_or_obj", "decoder")
    +    )
    +    plugins.set_missing_parameters({"engine": backend})
    +
    +
    +@mock.patch("pkg_resources.EntryPoint.load", mock.MagicMock(return_value=dummy_cfgrib))
    +def test_build_engines():
    +    dummy_cfgrib_pkg_entrypoint = pkg_resources.EntryPoint.parse(
    +        "cfgrib = xarray.tests.test_plugins:backend_1"
    +    )
    +    backend_entrypoints = plugins.build_engines([dummy_cfgrib_pkg_entrypoint])
    +    assert backend_entrypoints["cfgrib"] is dummy_cfgrib
    +    assert backend_entrypoints["cfgrib"].open_dataset_parameters == (
    +        "filename_or_obj",
    +        "decoder",
    +    )
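
    The new `test_plugins.py` drives the entrypoint machinery with
    `pkg_resources.EntryPoint.parse`. For context, a hypothetical third-party
    backend would register itself roughly like this; the package name, module
    path and the `xarray.backends` group name are assumptions, not part of
    this diff:

    ```python
    # hypothetical setup.py of a third-party backend package
    from setuptools import setup

    setup(
        name="my-xarray-backend",
        packages=["my_backend"],
        entry_points={
            # group name assumed from xarray's plugin mechanism
            "xarray.backends": [
                # the target should resolve to a BackendEntrypoint instance
                "my_engine = my_backend.engine:MY_BACKEND_ENTRYPOINT",
            ],
        },
    )
    ```
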
    diff --git a/xarray/tests/test_sparse.py b/xarray/tests/test_sparse.py
    index f3c09ba6a5f..49b6a58694e 100644
    --- a/xarray/tests/test_sparse.py
    +++ b/xarray/tests/test_sparse.py
    @@ -62,7 +62,13 @@ def __init__(self, meth, *args, **kwargs):
             self.kwargs = kwargs
     
         def __call__(self, obj):
    -        return getattr(obj, self.meth)(*self.args, **self.kwargs)
    +
    +        # np.sum cannot be passed via parametrize when using pytest-xdist,
    +        # so the function name is passed and resolved to the numpy function here
    +        kwargs = self.kwargs.copy()
    +        if "func" in self.kwargs:
    +            kwargs["func"] = getattr(np, kwargs["func"])
    +
    +        return getattr(obj, self.meth)(*self.args, **kwargs)
     
         def __repr__(self):
             return f"obj.{self.meth}(*{self.args}, **{self.kwargs})"
    @@ -94,7 +100,7 @@ def test_variable_property(prop):
             (do("any"), False),
             (do("astype", dtype=int), True),
             (do("clip", min=0, max=1), True),
    -        (do("coarsen", windows={"x": 2}, func=np.sum), True),
    +        (do("coarsen", windows={"x": 2}, func="sum"), True),
             (do("compute"), True),
             (do("conj"), True),
             (do("copy"), True),
    @@ -191,7 +197,7 @@ def test_variable_property(prop):
                 marks=xfail(reason="Only implemented for NumPy arrays (via bottleneck)"),
             ),
             param(
    -            do("reduce", func=np.sum, dim="x"),
    +            do("reduce", func="sum", dim="x"),
                 True,
                 marks=xfail(reason="Coercion to dense"),
             ),
    @@ -359,7 +365,7 @@ def test_dataarray_property(prop):
             (do("sel", x=[0, 1, 2]), True),
             (do("shift"), True),
             (do("sortby", "x", ascending=False), True),
    -        (do("stack", z={"x", "y"}), True),
    +        (do("stack", z=["x", "y"]), True),
             (do("transpose"), True),
             # TODO
             # set_index
    @@ -450,7 +456,7 @@ def test_dataarray_property(prop):
                 marks=xfail(reason="Missing implementation for np.nanmedian"),
             ),
             (do("notnull"), True),
    -        (do("pipe", np.sum, axis=1), True),
    +        (do("pipe", func="sum", axis=1), True),
             (do("prod"), False),
             param(
                 do("quantile", q=0.5),
    @@ -463,7 +469,7 @@ def test_dataarray_property(prop):
                 marks=xfail(reason="Only implemented for NumPy arrays (via bottleneck)"),
             ),
             param(
    -            do("reduce", np.sum, dim="x"),
    +            do("reduce", func="sum", dim="x"),
                 False,
                 marks=xfail(reason="Coercion to dense"),
             ),
    @@ -645,7 +651,7 @@ def test_stack(self):
             assert_equal(expected, stacked)
     
             roundtripped = stacked.unstack()
    -        assert arr.identical(roundtripped)
    +        assert_identical(arr, roundtripped)
     
         @pytest.mark.filterwarnings("ignore::PendingDeprecationWarning")
         def test_ufuncs(self):
    @@ -877,13 +883,12 @@ def test_dask_token():
     
     
     @requires_dask
    -def test_apply_ufunc_meta_to_blockwise():
    -    da = xr.DataArray(np.zeros((2, 3)), dims=["x", "y"]).chunk({"x": 2, "y": 1})
    -    sparse_meta = sparse.COO.from_numpy(np.zeros((0, 0)))
    +def test_apply_ufunc_check_meta_coherence():
    +    s = sparse.COO.from_numpy(np.array([0, 0, 1, 2]))
    +    a = DataArray(s)
    +    ac = a.chunk(2)
    +    sparse_meta = ac.data._meta
     
    -    # if dask computed meta, it would be np.ndarray
    -    expected = xr.apply_ufunc(
    -        lambda x: x, da, dask="parallelized", output_dtypes=[da.dtype], meta=sparse_meta
    -    ).data._meta
    +    result = xr.apply_ufunc(lambda x: x, ac, dask="parallelized").data._meta
     
    -    assert_sparse_equal(expected, sparse_meta)
    +    assert_sparse_equal(result, sparse_meta)
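
    A standalone sketch of what `test_apply_ufunc_check_meta_coherence`
    asserts: with `dask="parallelized"`, the result should keep the sparse
    `_meta` of its input rather than falling back to a NumPy meta (requires
    dask and sparse):

    ```python
    import numpy as np
    import sparse

    import xarray as xr

    s = sparse.COO.from_numpy(np.array([0, 0, 1, 2]))
    chunked = xr.DataArray(s).chunk(2)

    result = xr.apply_ufunc(lambda x: x, chunked, dask="parallelized")
    print(type(result.data._meta))  # expected: sparse.COO, not np.ndarray
    ```
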
    diff --git a/xarray/tests/test_testing.py b/xarray/tests/test_testing.py
    index 041b7341ade..30ea6aaaee9 100644
    --- a/xarray/tests/test_testing.py
    +++ b/xarray/tests/test_testing.py
    @@ -1,7 +1,129 @@
    +import numpy as np
    +import pytest
    +
     import xarray as xr
    +from xarray.core.npcompat import IS_NEP18_ACTIVE
    +
    +from . import has_dask
    +
    +try:
    +    from dask.array import from_array as dask_from_array
    +except ImportError:
    +    dask_from_array = lambda x: x
    +
    +try:
    +    import pint
    +
    +    unit_registry = pint.UnitRegistry(force_ndarray_like=True)
    +
    +    def quantity(x):
    +        return unit_registry.Quantity(x, "m")
    +
    +    has_pint = True
    +except ImportError:
    +
    +    def quantity(x):
    +        return x
    +
    +    has_pint = False
     
     
     def test_allclose_regression():
         x = xr.DataArray(1.01)
         y = xr.DataArray(1.02)
         xr.testing.assert_allclose(x, y, atol=0.01)
    +
    +
    +@pytest.mark.parametrize(
    +    "obj1,obj2",
    +    (
    +        pytest.param(
    +            xr.Variable("x", [1e-17, 2]), xr.Variable("x", [0, 3]), id="Variable"
    +        ),
    +        pytest.param(
    +            xr.DataArray([1e-17, 2], dims="x"),
    +            xr.DataArray([0, 3], dims="x"),
    +            id="DataArray",
    +        ),
    +        pytest.param(
    +            xr.Dataset({"a": ("x", [1e-17, 2]), "b": ("y", [-2e-18, 2])}),
    +            xr.Dataset({"a": ("x", [0, 2]), "b": ("y", [0, 1])}),
    +            id="Dataset",
    +        ),
    +    ),
    +)
    +def test_assert_allclose(obj1, obj2):
    +    with pytest.raises(AssertionError):
    +        xr.testing.assert_allclose(obj1, obj2)
    +
    +
    +@pytest.mark.filterwarnings("error")
    +@pytest.mark.parametrize(
    +    "duckarray",
    +    (
    +        pytest.param(np.array, id="numpy"),
    +        pytest.param(
    +            dask_from_array,
    +            id="dask",
    +            marks=pytest.mark.skipif(not has_dask, reason="requires dask"),
    +        ),
    +        pytest.param(
    +            quantity,
    +            id="pint",
    +            marks=pytest.mark.skipif(not has_pint, reason="requires pint"),
    +        ),
    +    ),
    +)
    +@pytest.mark.parametrize(
    +    ["obj1", "obj2"],
    +    (
    +        pytest.param([1e-10, 2], [0.0, 2.0], id="both arrays"),
    +        pytest.param([1e-17, 2], 0.0, id="second scalar"),
    +        pytest.param(0.0, [1e-17, 2], id="first scalar"),
    +    ),
    +)
    +def test_assert_duckarray_equal_failing(duckarray, obj1, obj2):
    +    # TODO: actually check the repr
    +    a = duckarray(obj1)
    +    b = duckarray(obj2)
    +    with pytest.raises(AssertionError):
    +        xr.testing.assert_duckarray_equal(a, b)
    +
    +
    +@pytest.mark.filterwarnings("error")
    +@pytest.mark.parametrize(
    +    "duckarray",
    +    (
    +        pytest.param(
    +            np.array,
    +            id="numpy",
    +            marks=pytest.mark.skipif(
    +                not IS_NEP18_ACTIVE,
    +                reason="NUMPY_EXPERIMENTAL_ARRAY_FUNCTION is not enabled",
    +            ),
    +        ),
    +        pytest.param(
    +            dask_from_array,
    +            id="dask",
    +            marks=pytest.mark.skipif(not has_dask, reason="requires dask"),
    +        ),
    +        pytest.param(
    +            quantity,
    +            id="pint",
    +            marks=pytest.mark.skipif(not has_pint, reason="requires pint"),
    +        ),
    +    ),
    +)
    +@pytest.mark.parametrize(
    +    ["obj1", "obj2"],
    +    (
    +        pytest.param([0, 2], [0.0, 2.0], id="both arrays"),
    +        pytest.param([0, 0], 0.0, id="second scalar"),
    +        pytest.param(0.0, [0, 0], id="first scalar"),
    +    ),
    +)
    +def test_assert_duckarray_equal(duckarray, obj1, obj2):
    +    a = duckarray(obj1)
    +    b = duckarray(obj2)
    +
    +    xr.testing.assert_duckarray_equal(a, b)
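
    For reference, the helper exercised above compares raw duck arrays (or
    scalars) element-wise, unlike `assert_equal`, which expects xarray
    objects. A minimal sketch with plain NumPy inputs:

    ```python
    import numpy as np

    import xarray as xr

    xr.testing.assert_duckarray_equal(np.array([0, 2]), np.array([0.0, 2.0]))
    xr.testing.assert_duckarray_equal(np.array([0, 0]), 0.0)

    try:
        xr.testing.assert_duckarray_equal(np.array([1e-10, 2]), np.array([0.0, 2.0]))
    except AssertionError:
        print("not equal, as expected")
    ```
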
    diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py
    index b2aa65c9a04..76dd830de23 100644
    --- a/xarray/tests/test_units.py
    +++ b/xarray/tests/test_units.py
    @@ -1,17 +1,16 @@
     import functools
     import operator
    -from distutils.version import LooseVersion
     
     import numpy as np
     import pandas as pd
     import pytest
     
     import xarray as xr
    -from xarray.core import formatting
    +from xarray.core import dtypes
     from xarray.core.npcompat import IS_NEP18_ACTIVE
    -from xarray.testing import assert_allclose, assert_identical
     
    -from .test_variable import _PAD_XR_NP_ARGS, VariableSubclassobjects
    +from . import assert_allclose, assert_duckarray_allclose, assert_equal, assert_identical
    +from .test_variable import _PAD_XR_NP_ARGS
     
     pint = pytest.importorskip("pint")
     DimensionalityError = pint.errors.DimensionalityError
    @@ -27,12 +26,7 @@
         pytest.mark.skipif(
             not IS_NEP18_ACTIVE, reason="NUMPY_EXPERIMENTAL_ARRAY_FUNCTION is not enabled"
         ),
    -    # TODO: remove this once pint has a released version with __array_function__
    -    pytest.mark.skipif(
    -        not hasattr(unit_registry.Quantity, "__array_function__"),
    -        reason="pint does not implement __array_function__ yet",
    -    ),
    -    # pytest.mark.filterwarnings("ignore:::pint[.*]"),
    +    pytest.mark.filterwarnings("error::pint.UnitStrippedWarning"),
     ]
     
     
    @@ -51,10 +45,23 @@ def dimensionality(obj):
     def compatible_mappings(first, second):
         return {
             key: is_compatible(unit1, unit2)
    -        for key, (unit1, unit2) in merge_mappings(first, second)
    +        for key, (unit1, unit2) in zip_mappings(first, second)
         }
     
     
    +def merge_mappings(base, *mappings):
    +    result = base.copy()
    +    for m in mappings:
    +        result.update(m)
    +
    +    return result
    +
    +
    +def zip_mappings(*mappings):
    +    for key in set(mappings[0]).intersection(*mappings[1:]):
    +        yield key, tuple(m[key] for m in mappings)
    +
    +
     def array_extract_units(obj):
         if isinstance(obj, (xr.Variable, xr.DataArray, xr.Dataset)):
             obj = obj.data
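
    This hunk adds a real `merge_mappings` and a `zip_mappings` helper (the
    old `merge_mappings`, removed in a later hunk, actually zipped). A tiny
    worked example of the difference, with the helpers copied here for
    illustration:

    ```python
    def merge_mappings(base, *mappings):
        result = base.copy()
        for m in mappings:
            result.update(m)
        return result


    def zip_mappings(*mappings):
        for key in set(mappings[0]).intersection(*mappings[1:]):
            yield key, tuple(m[key] for m in mappings)


    print(merge_mappings({"a": 1}, {"b": 2}, {"a": 3}))
    # {'a': 3, 'b': 2} -- later mappings win

    print(dict(zip_mappings({"a": "m", "b": "s"}, {"a": "km"})))
    # {'a': ('m', 'km')} -- only shared keys, values collected per mapping
    ```
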
    @@ -173,12 +180,7 @@ def attach_units(obj, units):
             new_obj = xr.Dataset(data_vars=data_vars, coords=coords, attrs=obj.attrs)
         elif isinstance(obj, xr.DataArray):
             # try the array name, "data" and None, then fall back to dimensionless
    -        data_units = (
    -            units.get(obj.name, None)
    -            or units.get("data", None)
    -            or units.get(None, None)
    -            or 1
    -        )
    +        data_units = units.get(obj.name, None) or units.get(None, None) or 1
     
             data = array_attach_units(obj.data, data_units)
     
    @@ -257,50 +259,11 @@ def assert_units_equal(a, b):
         assert extract_units(a) == extract_units(b)
     
     
    -def assert_equal_with_units(a, b):
    -    # works like xr.testing.assert_equal, but also explicitly checks units
    -    # so, it is more like assert_identical
    -    __tracebackhide__ = True
    -
    -    if isinstance(a, xr.Dataset) or isinstance(b, xr.Dataset):
    -        a_units = extract_units(a)
    -        b_units = extract_units(b)
    -
    -        a_without_units = strip_units(a)
    -        b_without_units = strip_units(b)
    -
    -        assert a_without_units.equals(b_without_units), formatting.diff_dataset_repr(
    -            a, b, "equals"
    -        )
    -        assert a_units == b_units
    -    else:
    -        a = a if not isinstance(a, (xr.DataArray, xr.Variable)) else a.data
    -        b = b if not isinstance(b, (xr.DataArray, xr.Variable)) else b.data
    -
    -        assert type(a) == type(b) or (
    -            isinstance(a, Quantity) and isinstance(b, Quantity)
    -        )
    -
    -        # workaround until pint implements allclose in __array_function__
    -        if isinstance(a, Quantity) or isinstance(b, Quantity):
    -            assert (
    -                hasattr(a, "magnitude") and hasattr(b, "magnitude")
    -            ) and np.allclose(a.magnitude, b.magnitude, equal_nan=True)
    -            assert (hasattr(a, "units") and hasattr(b, "units")) and a.units == b.units
    -        else:
    -            assert np.allclose(a, b, equal_nan=True)
    -
    -
    -@pytest.fixture(params=[float, int])
    +@pytest.fixture(params=[np.dtype(float), np.dtype(int)], ids=str)
     def dtype(request):
         return request.param
     
     
    -def merge_mappings(*mappings):
    -    for key in set(mappings[0]).intersection(*mappings[1:]):
    -        yield key, tuple(m[key] for m in mappings)
    -
    -
     def merge_args(default_args, new_args):
         from itertools import zip_longest
     
    @@ -312,7 +275,7 @@ def merge_args(default_args, new_args):
     
     
     class method:
    -    """ wrapper class to help with passing methods via parametrize
    +    """wrapper class to help with passing methods via parametrize
     
         This works a bit like using `partial(Class.method, arg, kwarg)`
         """
    @@ -329,19 +292,29 @@ def __call__(self, obj, *args, **kwargs):
             all_args = merge_args(self.args, args)
             all_kwargs = {**self.kwargs, **kwargs}
     
    +        xarray_classes = (
    +            xr.Variable,
    +            xr.DataArray,
    +            xr.Dataset,
    +            xr.core.groupby.GroupBy,
    +        )
    +
    +        if not isinstance(obj, xarray_classes):
    +            # remove typical xarray args like "dim"
    +            exclude_kwargs = ("dim", "dims")
    +            all_kwargs = {
    +                key: value
    +                for key, value in all_kwargs.items()
    +                if key not in exclude_kwargs
    +            }
    +
             func = getattr(obj, self.name, None)
    +
             if func is None or not isinstance(func, Callable):
                 # fall back to module level numpy functions if not a xarray object
                 if not isinstance(obj, (xr.Variable, xr.DataArray, xr.Dataset)):
                     numpy_func = getattr(np, self.name)
                     func = partial(numpy_func, obj)
    -                # remove typical xarray args like "dim"
    -                exclude_kwargs = ("dim", "dims")
    -                all_kwargs = {
    -                    key: value
    -                    for key, value in all_kwargs.items()
    -                    if key not in exclude_kwargs
    -                }
                 else:
                     raise AttributeError(f"{obj} has no method named '{self.name}'")
     
    @@ -352,7 +325,7 @@ def __repr__(self):
     
     
     class function:
    -    """ wrapper class for numpy functions
    +    """wrapper class for numpy functions
     
         Same as method, but the name is used for referencing numpy functions
         """
    @@ -386,14 +359,31 @@ def __repr__(self):
             return f"function_{self.name}"
     
     
    -def test_apply_ufunc_dataarray(dtype):
    +@pytest.mark.parametrize(
    +    "variant",
    +    (
    +        "data",
    +        pytest.param(
    +            "dims", marks=pytest.mark.skip(reason="indexes don't support units")
    +        ),
    +        "coords",
    +    ),
    +)
    +def test_apply_ufunc_dataarray(variant, dtype):
    +    variants = {
    +        "data": (unit_registry.m, 1, 1),
    +        "dims": (1, unit_registry.m, 1),
    +        "coords": (1, 1, unit_registry.m),
    +    }
    +    data_unit, dim_unit, coord_unit = variants.get(variant)
         func = functools.partial(
             xr.apply_ufunc, np.mean, input_core_dims=[["x"]], kwargs={"axis": -1}
         )
     
    -    array = np.linspace(0, 10, 20).astype(dtype) * unit_registry.m
    -    x = np.arange(20) * unit_registry.s
    -    data_array = xr.DataArray(data=array, dims="x", coords={"x": x})
    +    array = np.linspace(0, 10, 20).astype(dtype) * data_unit
    +    x = np.arange(20) * dim_unit
    +    u = np.linspace(-1, 1, 20) * coord_unit
    +    data_array = xr.DataArray(data=array, dims="x", coords={"x": x, "u": ("x", u)})
     
         expected = attach_units(func(strip_units(data_array)), extract_units(data_array))
         actual = func(data_array)
    @@ -402,20 +392,39 @@ def test_apply_ufunc_dataarray(dtype):
         assert_identical(expected, actual)
     
     
    -def test_apply_ufunc_dataset(dtype):
    +@pytest.mark.parametrize(
    +    "variant",
    +    (
    +        "data",
    +        pytest.param(
    +            "dims", marks=pytest.mark.skip(reason="indexes don't support units")
    +        ),
    +        "coords",
    +    ),
    +)
    +def test_apply_ufunc_dataset(variant, dtype):
    +    variants = {
    +        "data": (unit_registry.m, 1, 1),
    +        "dims": (1, unit_registry.m, 1),
    +        "coords": (1, 1, unit_registry.s),
    +    }
    +    data_unit, dim_unit, coord_unit = variants.get(variant)
    +
         func = functools.partial(
             xr.apply_ufunc, np.mean, input_core_dims=[["x"]], kwargs={"axis": -1}
         )
     
    -    array1 = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * unit_registry.m
    -    array2 = np.linspace(0, 10, 5).astype(dtype) * unit_registry.m
    +    array1 = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit
    +    array2 = np.linspace(0, 10, 5).astype(dtype) * data_unit
     
    -    x = np.arange(5) * unit_registry.s
    -    y = np.arange(10) * unit_registry.m
    +    x = np.arange(5) * dim_unit
    +    y = np.arange(10) * dim_unit
    +
    +    u = np.linspace(-1, 1, 10) * coord_unit
     
         ds = xr.Dataset(
             data_vars={"a": (("x", "y"), array1), "b": ("x", array2)},
    -        coords={"x": x, "y": y},
    +        coords={"x": x, "y": y, "u": ("y", u)},
         )
     
         expected = attach_units(func(strip_units(ds)), extract_units(ds))
    @@ -442,44 +451,61 @@ def test_apply_ufunc_dataset(dtype):
         "variant",
         (
             "data",
    -        pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")),
    +        pytest.param(
    +            "dims", marks=pytest.mark.skip(reason="indexes don't support units")
    +        ),
             "coords",
         ),
     )
    -@pytest.mark.parametrize("fill_value", (10, np.nan))
    -def test_align_dataarray(fill_value, variant, unit, error, dtype):
    +@pytest.mark.parametrize("value", (10, dtypes.NA))
    +def test_align_dataarray(value, variant, unit, error, dtype):
    +    if variant == "coords" and (
    +        value != dtypes.NA or isinstance(unit, unit_registry.Unit)
    +    ):
    +        pytest.xfail(
    +            reason=(
    +                "fill_value is used for both data variables and coords. "
    +                "See https://github.com/pydata/xarray/issues/4165"
    +            )
    +        )
    +
    +    fill_value = dtypes.get_fill_value(dtype) if value == dtypes.NA else value
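     +    # dtypes.NA is the missing-value sentinel; get_fill_value replaces it with the dtype's default fill value (NaN for these dtypes)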
    +
         original_unit = unit_registry.m
     
         variants = {
    -        "data": (unit, original_unit, original_unit),
    -        "dims": (original_unit, unit, original_unit),
    -        "coords": (original_unit, original_unit, unit),
    +        "data": ((original_unit, unit), (1, 1), (1, 1)),
    +        "dims": ((1, 1), (original_unit, unit), (1, 1)),
    +        "coords": ((1, 1), (1, 1), (original_unit, unit)),
         }
    -    data_unit, dim_unit, coord_unit = variants.get(variant)
    +    (
    +        (data_unit1, data_unit2),
    +        (dim_unit1, dim_unit2),
    +        (coord_unit1, coord_unit2),
    +    ) = variants.get(variant)
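     +    # the first unit of each pair is applied to the first object, the second to the object it is aligned with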
     
    -    array1 = np.linspace(0, 10, 2 * 5).reshape(2, 5).astype(dtype) * original_unit
    -    array2 = np.linspace(0, 8, 2 * 5).reshape(2, 5).astype(dtype) * data_unit
    -    x = np.arange(2) * original_unit
    +    array1 = np.linspace(0, 10, 2 * 5).reshape(2, 5).astype(dtype) * data_unit1
    +    array2 = np.linspace(0, 8, 2 * 5).reshape(2, 5).astype(dtype) * data_unit2
     
    -    y1 = np.arange(5) * original_unit
    -    y2 = np.arange(2, 7) * dim_unit
    -    y_a1 = np.array([3, 5, 7, 8, 9]) * original_unit
    -    y_a2 = np.array([7, 8, 9, 11, 13]) * coord_unit
    +    x = np.arange(2) * dim_unit1
    +    y1 = np.arange(5) * dim_unit1
    +    y2 = np.arange(2, 7) * dim_unit2
    +
    +    u1 = np.array([3, 5, 7, 8, 9]) * coord_unit1
    +    u2 = np.array([7, 8, 9, 11, 13]) * coord_unit2
     
         coords1 = {"x": x, "y": y1}
         coords2 = {"x": x, "y": y2}
         if variant == "coords":
    -        coords1["y_a"] = ("y", y_a1)
    -        coords2["y_a"] = ("y", y_a2)
    +        coords1["y_a"] = ("y", u1)
    +        coords2["y_a"] = ("y", u2)
     
         data_array1 = xr.DataArray(data=array1, coords=coords1, dims=("x", "y"))
         data_array2 = xr.DataArray(data=array2, coords=coords2, dims=("x", "y"))
     
    -    fill_value = fill_value * data_unit
    +    fill_value = fill_value * data_unit2
         func = function(xr.align, join="outer", fill_value=fill_value)
    -    if error is not None and not (
    -        np.isnan(fill_value) and not isinstance(fill_value, Quantity)
    -    ):
    +    if error is not None and (value != dtypes.NA or isinstance(fill_value, Quantity)):
             with pytest.raises(error):
                 func(data_array1, data_array2)
     
    @@ -487,7 +513,7 @@ def test_align_dataarray(fill_value, variant, unit, error, dtype):
     
         stripped_kwargs = {
             key: strip_units(
    -            convert_units(value, {None: original_unit if data_unit != 1 else None})
    +            convert_units(value, {None: data_unit1 if data_unit2 != 1 else None})
             )
             for key, value in func.kwargs.items()
         }
    @@ -529,45 +555,61 @@ def test_align_dataarray(fill_value, variant, unit, error, dtype):
         "variant",
         (
             "data",
    -        pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")),
    +        pytest.param(
    +            "dims", marks=pytest.mark.skip(reason="indexes don't support units")
    +        ),
             "coords",
         ),
     )
    -@pytest.mark.parametrize("fill_value", (np.float64(10), np.float64(np.nan)))
    -def test_align_dataset(fill_value, unit, variant, error, dtype):
    +@pytest.mark.parametrize("value", (10, dtypes.NA))
    +def test_align_dataset(value, unit, variant, error, dtype):
    +    if variant == "coords" and (
    +        value != dtypes.NA or isinstance(unit, unit_registry.Unit)
    +    ):
    +        pytest.xfail(
    +            reason=(
    +                "fill_value is used for both data variables and coords. "
    +                "See https://github.com/pydata/xarray/issues/4165"
    +            )
    +        )
    +
    +    fill_value = dtypes.get_fill_value(dtype) if value == dtypes.NA else value
    +
         original_unit = unit_registry.m
     
         variants = {
    -        "data": (unit, original_unit, original_unit),
    -        "dims": (original_unit, unit, original_unit),
    -        "coords": (original_unit, original_unit, unit),
    +        "data": ((original_unit, unit), (1, 1), (1, 1)),
    +        "dims": ((1, 1), (original_unit, unit), (1, 1)),
    +        "coords": ((1, 1), (1, 1), (original_unit, unit)),
         }
    -    data_unit, dim_unit, coord_unit = variants.get(variant)
    +    (
    +        (data_unit1, data_unit2),
    +        (dim_unit1, dim_unit2),
    +        (coord_unit1, coord_unit2),
    +    ) = variants.get(variant)
     
    -    array1 = np.linspace(0, 10, 2 * 5).reshape(2, 5).astype(dtype) * original_unit
    -    array2 = np.linspace(0, 10, 2 * 5).reshape(2, 5).astype(dtype) * data_unit
    +    array1 = np.linspace(0, 10, 2 * 5).reshape(2, 5).astype(dtype) * data_unit1
    +    array2 = np.linspace(0, 10, 2 * 5).reshape(2, 5).astype(dtype) * data_unit2
     
    -    x = np.arange(2) * original_unit
    +    x = np.arange(2) * dim_unit1
    +    y1 = np.arange(5) * dim_unit1
    +    y2 = np.arange(2, 7) * dim_unit2
     
    -    y1 = np.arange(5) * original_unit
    -    y2 = np.arange(2, 7) * dim_unit
    -    y_a1 = np.array([3, 5, 7, 8, 9]) * original_unit
    -    y_a2 = np.array([7, 8, 9, 11, 13]) * coord_unit
    +    u1 = np.array([3, 5, 7, 8, 9]) * coord_unit1
    +    u2 = np.array([7, 8, 9, 11, 13]) * coord_unit2
     
         coords1 = {"x": x, "y": y1}
         coords2 = {"x": x, "y": y2}
         if variant == "coords":
    -        coords1["y_a"] = ("y", y_a1)
    -        coords2["y_a"] = ("y", y_a2)
    +        coords1["u"] = ("y", u1)
    +        coords2["u"] = ("y", u2)
     
         ds1 = xr.Dataset(data_vars={"a": (("x", "y"), array1)}, coords=coords1)
         ds2 = xr.Dataset(data_vars={"a": (("x", "y"), array2)}, coords=coords2)
     
    -    fill_value = fill_value * data_unit
    +    fill_value = fill_value * data_unit2
         func = function(xr.align, join="outer", fill_value=fill_value)
    -    if error is not None and not (
    -        np.isnan(fill_value) and not isinstance(fill_value, Quantity)
    -    ):
    +    if error is not None and (value != dtypes.NA or isinstance(fill_value, Quantity)):
             with pytest.raises(error):
                 func(ds1, ds2)
     
    @@ -575,14 +617,16 @@ def test_align_dataset(fill_value, unit, variant, error, dtype):
     
         stripped_kwargs = {
             key: strip_units(
    -            convert_units(value, {None: original_unit if data_unit != 1 else None})
    +            convert_units(value, {None: data_unit1 if data_unit2 != 1 else None})
             )
             for key, value in func.kwargs.items()
         }
         units_a = extract_units(ds1)
         units_b = extract_units(ds2)
         expected_a, expected_b = func(
    -        strip_units(ds1), strip_units(convert_units(ds2, units_a)), **stripped_kwargs
    +        strip_units(ds1),
    +        strip_units(convert_units(ds2, units_a)),
    +        **stripped_kwargs,
         )
         expected_a = attach_units(expected_a, units_a)
         if isinstance(array2, Quantity):
    @@ -599,6 +643,7 @@ def test_align_dataset(fill_value, unit, variant, error, dtype):
     
     
     def test_broadcast_dataarray(dtype):
    +    # uses align internally so more thorough tests are not needed
         array1 = np.linspace(0, 10, 2) * unit_registry.Pa
         array2 = np.linspace(0, 10, 3) * unit_registry.Pa
     
    @@ -620,6 +665,7 @@ def test_broadcast_dataarray(dtype):
     
     
     def test_broadcast_dataset(dtype):
    +    # uses align internally so more thorough tests are not needed
         array1 = np.linspace(0, 10, 2) * unit_registry.Pa
         array2 = np.linspace(0, 10, 3) * unit_registry.Pa
     
    @@ -671,7 +717,9 @@ def test_broadcast_dataset(dtype):
         "variant",
         (
             "data",
    -        pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")),
    +        pytest.param(
    +            "dims", marks=pytest.mark.skip(reason="indexes don't support units")
    +        ),
             "coords",
         ),
     )
    @@ -679,31 +727,35 @@ def test_combine_by_coords(variant, unit, error, dtype):
         original_unit = unit_registry.m
     
         variants = {
    -        "data": (unit, original_unit, original_unit),
    -        "dims": (original_unit, unit, original_unit),
    -        "coords": (original_unit, original_unit, unit),
    +        "data": ((original_unit, unit), (1, 1), (1, 1)),
    +        "dims": ((1, 1), (original_unit, unit), (1, 1)),
    +        "coords": ((1, 1), (1, 1), (original_unit, unit)),
         }
    -    data_unit, dim_unit, coord_unit = variants.get(variant)
    -
    -    array1 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit
    -    array2 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit
    -    x = np.arange(1, 4) * 10 * original_unit
    -    y = np.arange(2) * original_unit
    -    z = np.arange(3) * original_unit
    -
    -    other_array1 = np.ones_like(array1) * data_unit
    -    other_array2 = np.ones_like(array2) * data_unit
    -    other_x = np.arange(1, 4) * 10 * dim_unit
    -    other_y = np.arange(2, 4) * dim_unit
    -    other_z = np.arange(3, 6) * coord_unit
    +    (
    +        (data_unit1, data_unit2),
    +        (dim_unit1, dim_unit2),
    +        (coord_unit1, coord_unit2),
    +    ) = variants.get(variant)
    +
    +    array1 = np.zeros(shape=(2, 3), dtype=dtype) * data_unit1
    +    array2 = np.zeros(shape=(2, 3), dtype=dtype) * data_unit1
    +    x = np.arange(1, 4) * 10 * dim_unit1
    +    y = np.arange(2) * dim_unit1
    +    u = np.arange(3) * coord_unit1
    +
    +    other_array1 = np.ones_like(array1) * data_unit2
    +    other_array2 = np.ones_like(array2) * data_unit2
    +    other_x = np.arange(1, 4) * 10 * dim_unit2
    +    other_y = np.arange(2, 4) * dim_unit2
    +    other_u = np.arange(3, 6) * coord_unit2
     
         ds = xr.Dataset(
             data_vars={"a": (("y", "x"), array1), "b": (("y", "x"), array2)},
    -        coords={"x": x, "y": y, "z": ("x", z)},
    +        coords={"x": x, "y": y, "u": ("x", u)},
         )
         other = xr.Dataset(
             data_vars={"a": (("y", "x"), other_array1), "b": (("y", "x"), other_array2)},
    -        coords={"x": other_x, "y": other_y, "z": ("x", other_z)},
    +        coords={"x": other_x, "y": other_y, "u": ("x", other_u)},
         )
     
         if error is not None:
    @@ -742,7 +794,9 @@ def test_combine_by_coords(variant, unit, error, dtype):
         "variant",
         (
             "data",
    -        pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")),
    +        pytest.param(
    +            "dims", marks=pytest.mark.skip(reason="indexes don't support units")
    +        ),
             "coords",
         ),
     )
    @@ -750,18 +804,22 @@ def test_combine_nested(variant, unit, error, dtype):
         original_unit = unit_registry.m
     
         variants = {
    -        "data": (unit, original_unit, original_unit),
    -        "dims": (original_unit, unit, original_unit),
    -        "coords": (original_unit, original_unit, unit),
    +        "data": ((original_unit, unit), (1, 1), (1, 1)),
    +        "dims": ((1, 1), (original_unit, unit), (1, 1)),
    +        "coords": ((1, 1), (1, 1), (original_unit, unit)),
         }
    -    data_unit, dim_unit, coord_unit = variants.get(variant)
    +    (
    +        (data_unit1, data_unit2),
    +        (dim_unit1, dim_unit2),
    +        (coord_unit1, coord_unit2),
    +    ) = variants.get(variant)
     
    -    array1 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit
    -    array2 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit
    +    array1 = np.zeros(shape=(2, 3), dtype=dtype) * data_unit1
    +    array2 = np.zeros(shape=(2, 3), dtype=dtype) * data_unit1
     
    -    x = np.arange(1, 4) * 10 * original_unit
    -    y = np.arange(2) * original_unit
    -    z = np.arange(3) * original_unit
    +    x = np.arange(1, 4) * 10 * dim_unit1
    +    y = np.arange(2) * dim_unit1
    +    z = np.arange(3) * coord_unit1
     
         ds1 = xr.Dataset(
             data_vars={"a": (("y", "x"), array1), "b": (("y", "x"), array2)},
    @@ -769,35 +827,35 @@ def test_combine_nested(variant, unit, error, dtype):
         )
         ds2 = xr.Dataset(
             data_vars={
    -            "a": (("y", "x"), np.ones_like(array1) * data_unit),
    -            "b": (("y", "x"), np.ones_like(array2) * data_unit),
    +            "a": (("y", "x"), np.ones_like(array1) * data_unit2),
    +            "b": (("y", "x"), np.ones_like(array2) * data_unit2),
             },
             coords={
    -            "x": np.arange(3) * dim_unit,
    -            "y": np.arange(2, 4) * dim_unit,
    -            "z": ("x", np.arange(-3, 0) * coord_unit),
    +            "x": np.arange(3) * dim_unit2,
    +            "y": np.arange(2, 4) * dim_unit2,
    +            "z": ("x", np.arange(-3, 0) * coord_unit2),
             },
         )
         ds3 = xr.Dataset(
             data_vars={
    -            "a": (("y", "x"), np.zeros_like(array1) * np.nan * data_unit),
    -            "b": (("y", "x"), np.zeros_like(array2) * np.nan * data_unit),
    +            "a": (("y", "x"), np.full_like(array1, fill_value=np.nan) * data_unit2),
    +            "b": (("y", "x"), np.full_like(array2, fill_value=np.nan) * data_unit2),
             },
             coords={
    -            "x": np.arange(3, 6) * dim_unit,
    -            "y": np.arange(4, 6) * dim_unit,
    -            "z": ("x", np.arange(3, 6) * coord_unit),
    +            "x": np.arange(3, 6) * dim_unit2,
    +            "y": np.arange(4, 6) * dim_unit2,
    +            "z": ("x", np.arange(3, 6) * coord_unit2),
             },
         )
         ds4 = xr.Dataset(
             data_vars={
    -            "a": (("y", "x"), -1 * np.ones_like(array1) * data_unit),
    -            "b": (("y", "x"), -1 * np.ones_like(array2) * data_unit),
    +            "a": (("y", "x"), -1 * np.ones_like(array1) * data_unit2),
    +            "b": (("y", "x"), -1 * np.ones_like(array2) * data_unit2),
             },
             coords={
    -            "x": np.arange(6, 9) * dim_unit,
    -            "y": np.arange(6, 8) * dim_unit,
    -            "z": ("x", np.arange(6, 9) * coord_unit),
    +            "x": np.arange(6, 9) * dim_unit2,
    +            "y": np.arange(6, 8) * dim_unit2,
    +            "z": ("x", np.arange(6, 9) * coord_unit2),
             },
         )
     
    @@ -842,22 +900,37 @@ def test_combine_nested(variant, unit, error, dtype):
         "variant",
         (
             "data",
    -        pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")),
    +        pytest.param(
    +            "dims", marks=pytest.mark.skip(reason="indexes don't support units")
    +        ),
    +        "coords",
         ),
     )
     def test_concat_dataarray(variant, unit, error, dtype):
         original_unit = unit_registry.m
     
    -    variants = {"data": (unit, original_unit), "dims": (original_unit, unit)}
    -    data_unit, dims_unit = variants.get(variant)
    +    variants = {
    +        "data": ((original_unit, unit), (1, 1), (1, 1)),
    +        "dims": ((1, 1), (original_unit, unit), (1, 1)),
    +        "coords": ((1, 1), (1, 1), (original_unit, unit)),
    +    }
    +    (
    +        (data_unit1, data_unit2),
    +        (dim_unit1, dim_unit2),
    +        (coord_unit1, coord_unit2),
    +    ) = variants.get(variant)
    +
    +    array1 = np.linspace(0, 5, 10).astype(dtype) * data_unit1
    +    array2 = np.linspace(-5, 0, 5).astype(dtype) * data_unit2
     
    -    array1 = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m
    -    array2 = np.linspace(-5, 0, 5).astype(dtype) * data_unit
    -    x1 = np.arange(5, 15) * original_unit
    -    x2 = np.arange(5) * dims_unit
    +    x1 = np.arange(5, 15) * dim_unit1
    +    x2 = np.arange(5) * dim_unit2
     
    -    arr1 = xr.DataArray(data=array1, coords={"x": x1}, dims="x")
    -    arr2 = xr.DataArray(data=array2, coords={"x": x2}, dims="x")
    +    u1 = np.linspace(1, 2, 10).astype(dtype) * coord_unit1
    +    u2 = np.linspace(0, 1, 5).astype(dtype) * coord_unit2
    +
    +    arr1 = xr.DataArray(data=array1, coords={"x": x1, "u": ("x", u1)}, dims="x")
    +    arr2 = xr.DataArray(data=array2, coords={"x": x2, "u": ("x", u2)}, dims="x")
     
         if error is not None:
             with pytest.raises(error):
    @@ -895,22 +968,37 @@ def test_concat_dataarray(variant, unit, error, dtype):
         "variant",
         (
             "data",
    -        pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")),
    +        pytest.param(
    +            "dims", marks=pytest.mark.skip(reason="indexes don't support units")
    +        ),
    +        "coords",
         ),
     )
     def test_concat_dataset(variant, unit, error, dtype):
         original_unit = unit_registry.m
     
    -    variants = {"data": (unit, original_unit), "dims": (original_unit, unit)}
    -    data_unit, dims_unit = variants.get(variant)
    +    variants = {
    +        "data": ((original_unit, unit), (1, 1), (1, 1)),
    +        "dims": ((1, 1), (original_unit, unit), (1, 1)),
    +        "coords": ((1, 1), (1, 1), (original_unit, unit)),
    +    }
    +    (
    +        (data_unit1, data_unit2),
    +        (dim_unit1, dim_unit2),
    +        (coord_unit1, coord_unit2),
    +    ) = variants.get(variant)
    +
    +    array1 = np.linspace(0, 5, 10).astype(dtype) * data_unit1
    +    array2 = np.linspace(-5, 0, 5).astype(dtype) * data_unit2
     
    -    array1 = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m
    -    array2 = np.linspace(-5, 0, 5).astype(dtype) * data_unit
    -    x1 = np.arange(5, 15) * original_unit
    -    x2 = np.arange(5) * dims_unit
    +    x1 = np.arange(5, 15) * dim_unit1
    +    x2 = np.arange(5) * dim_unit2
    +
    +    u1 = np.linspace(1, 2, 10).astype(dtype) * coord_unit1
    +    u2 = np.linspace(0, 1, 5).astype(dtype) * coord_unit2
     
    -    ds1 = xr.Dataset(data_vars={"a": ("x", array1)}, coords={"x": x1})
    -    ds2 = xr.Dataset(data_vars={"a": ("x", array2)}, coords={"x": x2})
    +    ds1 = xr.Dataset(data_vars={"a": ("x", array1)}, coords={"x": x1, "u": ("x", u1)})
    +    ds2 = xr.Dataset(data_vars={"a": ("x", array2)}, coords={"x": x2, "u": ("x", u2)})
     
         if error is not None:
             with pytest.raises(error):
    @@ -946,7 +1034,9 @@ def test_concat_dataset(variant, unit, error, dtype):
         "variant",
         (
             "data",
    -        pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")),
    +        pytest.param(
    +            "dims", marks=pytest.mark.skip(reason="indexes don't support units")
    +        ),
             "coords",
         ),
     )
    @@ -954,29 +1044,33 @@ def test_merge_dataarray(variant, unit, error, dtype):
         original_unit = unit_registry.m
     
         variants = {
    -        "data": (unit, original_unit, original_unit),
    -        "dims": (original_unit, unit, original_unit),
    -        "coords": (original_unit, original_unit, unit),
    +        "data": ((original_unit, unit), (1, 1), (1, 1)),
    +        "dims": ((1, 1), (original_unit, unit), (1, 1)),
    +        "coords": ((1, 1), (1, 1), (original_unit, unit)),
         }
    -    data_unit, dim_unit, coord_unit = variants.get(variant)
    -
    -    array1 = np.linspace(0, 1, 2 * 3).reshape(2, 3).astype(dtype) * original_unit
    -    x1 = np.arange(2) * original_unit
    -    y1 = np.arange(3) * original_unit
    -    u1 = np.linspace(10, 20, 2) * original_unit
    -    v1 = np.linspace(10, 20, 3) * original_unit
    -
    -    array2 = np.linspace(1, 2, 2 * 4).reshape(2, 4).astype(dtype) * data_unit
    -    x2 = np.arange(2, 4) * dim_unit
    -    z2 = np.arange(4) * original_unit
    -    u2 = np.linspace(20, 30, 2) * coord_unit
    -    w2 = np.linspace(10, 20, 4) * original_unit
    -
    -    array3 = np.linspace(0, 2, 3 * 4).reshape(3, 4).astype(dtype) * data_unit
    -    y3 = np.arange(3, 6) * dim_unit
    -    z3 = np.arange(4, 8) * dim_unit
    -    v3 = np.linspace(10, 20, 3) * coord_unit
    -    w3 = np.linspace(10, 20, 4) * coord_unit
    +    (
    +        (data_unit1, data_unit2),
    +        (dim_unit1, dim_unit2),
    +        (coord_unit1, coord_unit2),
    +    ) = variants.get(variant)
    +
    +    array1 = np.linspace(0, 1, 2 * 3).reshape(2, 3).astype(dtype) * data_unit1
    +    x1 = np.arange(2) * dim_unit1
    +    y1 = np.arange(3) * dim_unit1
    +    u1 = np.linspace(10, 20, 2) * coord_unit1
    +    v1 = np.linspace(10, 20, 3) * coord_unit1
    +
    +    array2 = np.linspace(1, 2, 2 * 4).reshape(2, 4).astype(dtype) * data_unit2
    +    x2 = np.arange(2, 4) * dim_unit2
    +    z2 = np.arange(4) * dim_unit1
    +    u2 = np.linspace(20, 30, 2) * coord_unit2
    +    w2 = np.linspace(10, 20, 4) * coord_unit1
    +
    +    array3 = np.linspace(0, 2, 3 * 4).reshape(3, 4).astype(dtype) * data_unit2
    +    y3 = np.arange(3, 6) * dim_unit2
    +    z3 = np.arange(4, 8) * dim_unit2
    +    v3 = np.linspace(10, 20, 3) * coord_unit2
    +    w3 = np.linspace(10, 20, 4) * coord_unit2
     
         arr1 = xr.DataArray(
             name="a",
    @@ -1003,31 +1097,22 @@ def test_merge_dataarray(variant, unit, error, dtype):
     
             return
     
    -    units = {name: original_unit for name in list("axyzuvw")}
    -
    -    convert_and_strip = lambda arr: strip_units(convert_units(arr, units))
    -    expected_units = {
    -        "a": original_unit,
    -        "u": original_unit,
    -        "v": original_unit,
    -        "w": original_unit,
    -        "x": original_unit,
    -        "y": original_unit,
    -        "z": original_unit,
    +    units = {
    +        "a": data_unit1,
    +        "u": coord_unit1,
    +        "v": coord_unit1,
    +        "w": coord_unit1,
    +        "x": dim_unit1,
    +        "y": dim_unit1,
    +        "z": dim_unit1,
         }
    +    convert_and_strip = lambda arr: strip_units(convert_units(arr, units))
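     +    # reference: convert every object to one consistent set of units, merge the bare values, then reattach those units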
     
    -    expected = convert_units(
    -        attach_units(
    -            xr.merge(
    -                [
    -                    convert_and_strip(arr1),
    -                    convert_and_strip(arr2),
    -                    convert_and_strip(arr3),
    -                ]
    -            ),
    -            units,
    +    expected = attach_units(
    +        xr.merge(
    +            [convert_and_strip(arr1), convert_and_strip(arr2), convert_and_strip(arr3)]
             ),
    -        expected_units,
    +        units,
         )
     
         actual = xr.merge([arr1, arr2, arr3])
    @@ -1053,7 +1138,9 @@ def test_merge_dataarray(variant, unit, error, dtype):
         "variant",
         (
             "data",
    -        pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")),
    +        pytest.param(
    +            "dims", marks=pytest.mark.skip(reason="indexes don't support units")
    +        ),
             "coords",
         ),
     )
    @@ -1061,43 +1148,47 @@ def test_merge_dataset(variant, unit, error, dtype):
         original_unit = unit_registry.m
     
         variants = {
    -        "data": (unit, original_unit, original_unit),
    -        "dims": (original_unit, unit, original_unit),
    -        "coords": (original_unit, original_unit, unit),
    +        "data": ((original_unit, unit), (1, 1), (1, 1)),
    +        "dims": ((1, 1), (original_unit, unit), (1, 1)),
    +        "coords": ((1, 1), (1, 1), (original_unit, unit)),
         }
    -    data_unit, dim_unit, coord_unit = variants.get(variant)
    +    (
    +        (data_unit1, data_unit2),
    +        (dim_unit1, dim_unit2),
    +        (coord_unit1, coord_unit2),
    +    ) = variants.get(variant)
     
    -    array1 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit
    -    array2 = np.zeros(shape=(2, 3), dtype=dtype) * original_unit
    +    array1 = np.zeros(shape=(2, 3), dtype=dtype) * data_unit1
    +    array2 = np.zeros(shape=(2, 3), dtype=dtype) * data_unit1
     
    -    x = np.arange(11, 14) * original_unit
    -    y = np.arange(2) * original_unit
    -    z = np.arange(3) * original_unit
    +    x = np.arange(11, 14) * dim_unit1
    +    y = np.arange(2) * dim_unit1
    +    u = np.arange(3) * coord_unit1
     
         ds1 = xr.Dataset(
             data_vars={"a": (("y", "x"), array1), "b": (("y", "x"), array2)},
    -        coords={"x": x, "y": y, "u": ("x", z)},
    +        coords={"x": x, "y": y, "u": ("x", u)},
         )
         ds2 = xr.Dataset(
             data_vars={
    -            "a": (("y", "x"), np.ones_like(array1) * data_unit),
    -            "b": (("y", "x"), np.ones_like(array2) * data_unit),
    +            "a": (("y", "x"), np.ones_like(array1) * data_unit2),
    +            "b": (("y", "x"), np.ones_like(array2) * data_unit2),
             },
             coords={
    -            "x": np.arange(3) * dim_unit,
    -            "y": np.arange(2, 4) * dim_unit,
    -            "u": ("x", np.arange(-3, 0) * coord_unit),
    +            "x": np.arange(3) * dim_unit2,
    +            "y": np.arange(2, 4) * dim_unit2,
    +            "u": ("x", np.arange(-3, 0) * coord_unit2),
             },
         )
         ds3 = xr.Dataset(
             data_vars={
    -            "a": (("y", "x"), np.full_like(array1, np.nan) * data_unit),
    -            "b": (("y", "x"), np.full_like(array2, np.nan) * data_unit),
    +            "a": (("y", "x"), np.full_like(array1, np.nan) * data_unit2),
    +            "b": (("y", "x"), np.full_like(array2, np.nan) * data_unit2),
             },
             coords={
    -            "x": np.arange(3, 6) * dim_unit,
    -            "y": np.arange(4, 6) * dim_unit,
    -            "u": ("x", np.arange(3, 6) * coord_unit),
    +            "x": np.arange(3, 6) * dim_unit2,
    +            "y": np.arange(4, 6) * dim_unit2,
    +            "u": ("x", np.arange(3, 6) * coord_unit2),
             },
         )
     
    @@ -1110,15 +1201,9 @@ def test_merge_dataset(variant, unit, error, dtype):
     
         units = extract_units(ds1)
         convert_and_strip = lambda ds: strip_units(convert_units(ds, units))
    -    expected_units = {name: original_unit for name in list("abxyzu")}
    -    expected = convert_units(
    -        attach_units(
    -            func(
    -                [convert_and_strip(ds1), convert_and_strip(ds2), convert_and_strip(ds3)]
    -            ),
    -            units,
    -        ),
    -        expected_units,
    +    expected = attach_units(
    +        func([convert_and_strip(ds1), convert_and_strip(ds2), convert_and_strip(ds3)]),
    +        units,
         )
         actual = func([ds1, ds2, ds3])
     
    @@ -1126,35 +1211,79 @@ def test_merge_dataset(variant, unit, error, dtype):
         assert_allclose(expected, actual)
     
     
    +@pytest.mark.parametrize(
    +    "variant",
    +    (
    +        "data",
    +        pytest.param(
    +            "dims", marks=pytest.mark.skip(reason="indexes don't support units")
    +        ),
    +        "coords",
    +    ),
    +)
     @pytest.mark.parametrize("func", (xr.zeros_like, xr.ones_like))
    -def test_replication_dataarray(func, dtype):
    -    array = np.linspace(0, 10, 20).astype(dtype) * unit_registry.s
    -    data_array = xr.DataArray(data=array, dims="x")
    +def test_replication_dataarray(func, variant, dtype):
    +    unit = unit_registry.m
    +
    +    variants = {
    +        "data": (unit, 1, 1),
    +        "dims": (1, unit, 1),
    +        "coords": (1, 1, unit),
    +    }
    +    data_unit, dim_unit, coord_unit = variants.get(variant)
     
    -    numpy_func = getattr(np, func.__name__)
    -    units = extract_units(numpy_func(data_array))
    -    expected = attach_units(func(data_array), units)
    +    array = np.linspace(0, 10, 20).astype(dtype) * data_unit
    +    x = np.arange(20) * dim_unit
    +    u = np.linspace(0, 1, 20) * coord_unit
    +
    +    data_array = xr.DataArray(data=array, dims="x", coords={"x": x, "u": ("x", u)})
    +    units = extract_units(data_array)
    +    units.pop(data_array.name)
    +
    +    expected = attach_units(func(strip_units(data_array)), units)
         actual = func(data_array)
     
         assert_units_equal(expected, actual)
         assert_identical(expected, actual)
     
     
    +@pytest.mark.parametrize(
    +    "variant",
    +    (
    +        "data",
    +        pytest.param(
    +            "dims", marks=pytest.mark.skip(reason="indexes don't support units")
    +        ),
    +        "coords",
    +    ),
    +)
     @pytest.mark.parametrize("func", (xr.zeros_like, xr.ones_like))
    -def test_replication_dataset(func, dtype):
    -    array1 = np.linspace(0, 10, 20).astype(dtype) * unit_registry.s
    -    array2 = np.linspace(5, 10, 10).astype(dtype) * unit_registry.Pa
    -    x = np.arange(20).astype(dtype) * unit_registry.m
    -    y = np.arange(10).astype(dtype) * unit_registry.m
    -    z = y.to(unit_registry.mm)
    +def test_replication_dataset(func, variant, dtype):
    +    unit = unit_registry.m
    +
    +    variants = {
    +        "data": ((unit_registry.m, unit_registry.Pa), 1, 1),
    +        "dims": ((1, 1), unit, 1),
    +        "coords": ((1, 1), 1, unit),
    +    }
    +    (data_unit1, data_unit2), dim_unit, coord_unit = variants.get(variant)
    +
    +    array1 = np.linspace(0, 10, 20).astype(dtype) * data_unit1
    +    array2 = np.linspace(5, 10, 10).astype(dtype) * data_unit2
    +    x = np.arange(20).astype(dtype) * dim_unit
    +    y = np.arange(10).astype(dtype) * dim_unit
    +    u = np.linspace(0, 1, 10) * coord_unit
     
         ds = xr.Dataset(
             data_vars={"a": ("x", array1), "b": ("y", array2)},
    -        coords={"x": x, "y": y, "z": ("y", z)},
    +        coords={"x": x, "y": y, "u": ("y", u)},
         )
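     +    # zeros_like / ones_like are expected to return unitless data, so keep only the coordinate units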
    +    units = {
    +        name: unit
    +        for name, unit in extract_units(ds).items()
    +        if name not in ds.data_vars
    +    }
     
    -    numpy_func = getattr(np, func.__name__)
    -    units = extract_units(ds.map(numpy_func))
         expected = attach_units(func(strip_units(ds)), units)
     
         actual = func(ds)
    @@ -1163,37 +1292,40 @@ def test_replication_dataset(func, dtype):
         assert_identical(expected, actual)
     
     
    -@pytest.mark.xfail(
    -    reason=(
    -        "pint is undecided on how `full_like` should work, so incorrect errors "
    -        "may be expected: hgrecco/pint#882"
    -    )
    -)
     @pytest.mark.parametrize(
    -    "unit,error",
    +    "variant",
         (
    -        pytest.param(1, DimensionalityError, id="no_unit"),
    +        "data",
             pytest.param(
    -            unit_registry.dimensionless, DimensionalityError, id="dimensionless"
    +            "dims", marks=pytest.mark.skip(reason="indexes don't support units")
    +        ),
    +        pytest.param(
    +            "coords",
    +            marks=pytest.mark.xfail(reason="can't copy quantity into non-quantity"),
             ),
    -        pytest.param(unit_registry.m, DimensionalityError, id="incompatible_unit"),
    -        pytest.param(unit_registry.ms, None, id="compatible_unit"),
    -        pytest.param(unit_registry.s, None, id="identical_unit"),
         ),
    -    ids=repr,
     )
    -def test_replication_full_like_dataarray(unit, error, dtype):
    -    array = np.linspace(0, 5, 10) * unit_registry.s
    -    data_array = xr.DataArray(data=array, dims="x")
    +def test_replication_full_like_dataarray(variant, dtype):
    +    # since full_like will strip units and then use the units of the
    +    # fill value, we don't need to try multiple units
    +    unit = unit_registry.m
     
    -    fill_value = -1 * unit
    -    if error is not None:
    -        with pytest.raises(error):
    -            xr.full_like(data_array, fill_value=fill_value)
    +    variants = {
    +        "data": (unit, 1, 1),
    +        "dims": (1, unit, 1),
    +        "coords": (1, 1, unit),
    +    }
    +    data_unit, dim_unit, coord_unit = variants.get(variant)
     
    -        return
    +    array = np.linspace(0, 5, 10) * data_unit
    +    x = np.arange(10) * dim_unit
    +    u = np.linspace(0, 1, 10) * coord_unit
    +    data_array = xr.DataArray(data=array, dims="x", coords={"x": x, "u": ("x", u)})
     
    -    units = {**extract_units(data_array), **{None: unit if unit != 1 else None}}
    +    fill_value = -1 * unit_registry.degK
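     +    # an unrelated unit on purpose: the result's data should take the fill value's units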
    +
    +    units = extract_units(data_array)
    +    units[data_array.name] = fill_value.units
         expected = attach_units(
             xr.full_like(strip_units(data_array), fill_value=strip_units(fill_value)), units
         )
    @@ -1203,47 +1335,46 @@ def test_replication_full_like_dataarray(unit, error, dtype):
         assert_identical(expected, actual)
     
     
    -@pytest.mark.xfail(
    -    reason=(
    -        "pint is undecided on how `full_like` should work, so incorrect errors "
    -        "may be expected: hgrecco/pint#882"
    -    )
    -)
     @pytest.mark.parametrize(
    -    "unit,error",
    +    "variant",
         (
    -        pytest.param(1, DimensionalityError, id="no_unit"),
    +        "data",
             pytest.param(
    -            unit_registry.dimensionless, DimensionalityError, id="dimensionless"
    +            "dims", marks=pytest.mark.skip(reason="indexes don't support units")
    +        ),
    +        pytest.param(
    +            "coords",
    +            marks=pytest.mark.xfail(reason="can't copy quantity into non-quantity"),
             ),
    -        pytest.param(unit_registry.m, DimensionalityError, id="incompatible_unit"),
    -        pytest.param(unit_registry.ms, None, id="compatible_unit"),
    -        pytest.param(unit_registry.s, None, id="identical_unit"),
         ),
    -    ids=repr,
     )
    -def test_replication_full_like_dataset(unit, error, dtype):
    -    array1 = np.linspace(0, 10, 20).astype(dtype) * unit_registry.s
    -    array2 = np.linspace(5, 10, 10).astype(dtype) * unit_registry.Pa
    -    x = np.arange(20).astype(dtype) * unit_registry.m
    -    y = np.arange(10).astype(dtype) * unit_registry.m
    -    z = y.to(unit_registry.mm)
    +def test_replication_full_like_dataset(variant, dtype):
    +    unit = unit_registry.m
    +
    +    variants = {
    +        "data": ((unit_registry.s, unit_registry.Pa), 1, 1),
    +        "dims": ((1, 1), unit, 1),
    +        "coords": ((1, 1), 1, unit),
    +    }
    +    (data_unit1, data_unit2), dim_unit, coord_unit = variants.get(variant)
    +
    +    array1 = np.linspace(0, 10, 20).astype(dtype) * data_unit1
    +    array2 = np.linspace(5, 10, 10).astype(dtype) * data_unit2
    +    x = np.arange(20).astype(dtype) * dim_unit
    +    y = np.arange(10).astype(dtype) * dim_unit
    +
    +    u = np.linspace(0, 1, 10) * coord_unit
     
         ds = xr.Dataset(
             data_vars={"a": ("x", array1), "b": ("y", array2)},
    -        coords={"x": x, "y": y, "z": ("y", z)},
    +        coords={"x": x, "y": y, "u": ("y", u)},
         )
     
    -    fill_value = -1 * unit
    -    if error is not None:
    -        with pytest.raises(error):
    -            xr.full_like(ds, fill_value=fill_value)
    -
    -        return
    +    fill_value = -1 * unit_registry.degK
     
         units = {
             **extract_units(ds),
    -        **{name: unit if unit != 1 else None for name in ds.data_vars},
    +        **{name: unit_registry.degK for name in ds.data_vars},
         }
         expected = attach_units(
             xr.full_like(strip_units(ds), fill_value=strip_units(fill_value)), units
    @@ -1314,10 +1445,9 @@ def test_where_dataarray(fill_value, unit, error, dtype):
     def test_where_dataset(fill_value, unit, error, dtype):
         array1 = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m
         array2 = np.linspace(-5, 0, 10).astype(dtype) * unit_registry.m
    -    x = np.arange(10) * unit_registry.s
     
    -    ds = xr.Dataset(data_vars={"a": ("x", array1), "b": ("x", array2)}, coords={"x": x})
    -    cond = x < 5 * unit_registry.s
    +    ds = xr.Dataset(data_vars={"a": ("x", array1), "b": ("x", array2)})
    +    cond = array1 < 2 * unit_registry.m
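     +    # build the condition from the data since the dataset no longer has the unit-carrying "x" index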
         fill_value = fill_value * unit
     
         if error is not None and not (
    @@ -1364,62 +1494,14 @@ def test_dot_dataarray(dtype):
         assert_identical(expected, actual)
     
     
    -def delete_attrs(*to_delete):
    -    def wrapper(cls):
    -        for item in to_delete:
    -            setattr(cls, item, None)
    -
    -        return cls
    -
    -    return wrapper
    -
    -
    -@delete_attrs(
    -    "test_getitem_with_mask",
    -    "test_getitem_with_mask_nd_indexer",
    -    "test_index_0d_string",
    -    "test_index_0d_datetime",
    -    "test_index_0d_timedelta64",
    -    "test_0d_time_data",
    -    "test_index_0d_not_a_time",
    -    "test_datetime64_conversion",
    -    "test_timedelta64_conversion",
    -    "test_pandas_period_index",
    -    "test_1d_math",
    -    "test_1d_reduce",
    -    "test_array_interface",
    -    "test___array__",
    -    "test_copy_index",
    -    "test_concat_number_strings",
    -    "test_concat_fixed_len_str",
    -    "test_concat_mixed_dtypes",
    -    "test_pandas_datetime64_with_tz",
    -    "test_pandas_data",
    -    "test_multiindex",
    -)
    -class TestVariable(VariableSubclassobjects):
    -    @staticmethod
    -    def cls(dims, data, *args, **kwargs):
    -        return xr.Variable(
    -            dims, unit_registry.Quantity(data, unit_registry.m), *args, **kwargs
    -        )
    -
    -    def example_1d_objects(self):
    -        for data in [
    -            range(3),
    -            0.5 * np.arange(3),
    -            0.5 * np.arange(3, dtype=np.float32),
    -            np.array(["a", "b", "c"], dtype=object),
    -        ]:
    -            yield (self.cls("x", data), data)
    -
    +class TestVariable:
         @pytest.mark.parametrize(
             "func",
             (
                 method("all"),
                 method("any"),
    -            method("argmax"),
    -            method("argmin"),
    +            method("argmax", dim="x"),
    +            method("argmin", dim="x"),
                 method("argsort"),
                 method("cumprod"),
                 method("cumsum"),
    @@ -1427,10 +1509,7 @@ def example_1d_objects(self):
                 method("mean"),
                 method("median"),
                 method("min"),
    -            pytest.param(
    -                method("prod"),
    -                marks=pytest.mark.xfail(reason="not implemented by pint"),
    -            ),
    +            method("prod"),
                 method("std"),
                 method("sum"),
                 method("var"),
    @@ -1438,17 +1517,32 @@ def example_1d_objects(self):
             ids=repr,
         )
         def test_aggregation(self, func, dtype):
    +        if func.name == "prod" and dtype.kind == "f":
     +            pytest.xfail(reason="nanprod is not supported yet")
    +
             array = np.linspace(0, 1, 10).astype(dtype) * (
                 unit_registry.m if func.name != "cumprod" else unit_registry.dimensionless
             )
             variable = xr.Variable("x", array)
     
    -        units = extract_units(func(array))
    +        numpy_kwargs = func.kwargs.copy()
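     +        # translate the xarray "dim" kwarg into the "axis" argument expected by the bare numpy function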
    +        if "dim" in func.kwargs:
    +            numpy_kwargs["axis"] = variable.get_axis_num(numpy_kwargs.pop("dim"))
    +
    +        units = extract_units(func(array, **numpy_kwargs))
             expected = attach_units(func(strip_units(variable)), units)
             actual = func(variable)
     
             assert_units_equal(expected, actual)
    -        xr.testing.assert_identical(expected, actual)
    +        assert_allclose(expected, actual)
    +
    +    def test_aggregate_complex(self):
    +        variable = xr.Variable("x", [1, 2j, np.nan] * unit_registry.m)
    +        expected = xr.Variable((), (0.5 + 1j) * unit_registry.m)
    +        actual = variable.mean()
    +
    +        assert_units_equal(expected, actual)
    +        assert_allclose(expected, actual)
     
         @pytest.mark.parametrize(
             "func",
    @@ -1506,7 +1600,7 @@ def test_numpy_methods(self, func, unit, error, dtype):
             actual = func(variable, *args, **kwargs)
     
             assert_units_equal(expected, actual)
    -        xr.testing.assert_allclose(expected, actual)
    +        assert_allclose(expected, actual)
     
         @pytest.mark.parametrize(
             "func", (method("item", 5), method("searchsorted", 5)), ids=repr
    @@ -1566,7 +1660,7 @@ def test_raw_numpy_methods(self, func, unit, error, dtype):
             actual = func(variable, *args, **kwargs)
     
             assert_units_equal(expected, actual)
    -        np.testing.assert_allclose(expected, actual)
    +        assert_duckarray_allclose(expected, actual)
     
         @pytest.mark.parametrize(
             "func", (method("isnull"), method("notnull"), method("count")), ids=repr
    @@ -1589,7 +1683,7 @@ def test_missing_value_detection(self, func):
             actual = func(variable)
     
             assert_units_equal(expected, actual)
    -        xr.testing.assert_identical(expected, actual)
    +        assert_identical(expected, actual)
     
         @pytest.mark.parametrize(
             "unit,error",
    @@ -1635,7 +1729,7 @@ def test_missing_value_fillna(self, unit, error):
             actual = variable.fillna(value=fill_value)
     
             assert_units_equal(expected, actual)
    -        xr.testing.assert_identical(expected, actual)
    +        assert_identical(expected, actual)
     
         @pytest.mark.parametrize(
             "unit",
    @@ -1643,7 +1737,10 @@ def test_missing_value_fillna(self, unit, error):
                 pytest.param(1, id="no_unit"),
                 pytest.param(unit_registry.dimensionless, id="dimensionless"),
                 pytest.param(unit_registry.s, id="incompatible_unit"),
    -            pytest.param(unit_registry.cm, id="compatible_unit",),
    +            pytest.param(
    +                unit_registry.cm,
    +                id="compatible_unit",
    +            ),
                 pytest.param(unit_registry.m, id="identical_unit"),
             ),
         )
    @@ -1660,7 +1757,7 @@ def test_missing_value_fillna(self, unit, error):
                 method("equals"),
                 pytest.param(
                     method("identical"),
    -                marks=pytest.mark.skip(reason="behaviour of identical is unclear"),
    +                marks=pytest.mark.skip(reason="behavior of identical is undecided"),
                 ),
             ),
             ids=repr,
    @@ -1746,7 +1843,7 @@ def test_isel(self, indices, dtype):
             actual = variable.isel(x=indices)
     
             assert_units_equal(expected, actual)
    -        xr.testing.assert_identical(expected, actual)
    +        assert_identical(expected, actual)
     
         @pytest.mark.parametrize(
             "unit,error",
    @@ -1804,7 +1901,7 @@ def test_1d_math(self, func, unit, error, dtype):
             actual = func(variable, y)
     
             assert_units_equal(expected, actual)
    -        xr.testing.assert_allclose(expected, actual)
    +        assert_allclose(expected, actual)
     
         @pytest.mark.parametrize(
             "unit,error",
    @@ -1853,43 +1950,33 @@ def test_masking(self, func, unit, error, dtype):
             actual = func(variable, cond, other)
     
             assert_units_equal(expected, actual)
    -        xr.testing.assert_identical(expected, actual)
    +        assert_identical(expected, actual)
     
    -    def test_squeeze(self, dtype):
    +    @pytest.mark.parametrize("dim", ("x", "y", "z", "t", "all"))
    +    def test_squeeze(self, dim, dtype):
             shape = (2, 1, 3, 1, 1, 2)
             names = list("abcdef")
    +        dim_lengths = dict(zip(names, shape))
             array = np.ones(shape=shape) * unit_registry.m
             variable = xr.Variable(names, array)
     
    +        kwargs = {"dim": dim} if dim != "all" and dim_lengths.get(dim, 0) == 1 else {}
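     +        # pass "dim" only for an existing length-1 dimension; otherwise squeeze all size-1 dimensions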
             expected = attach_units(
    -            strip_units(variable).squeeze(), extract_units(variable)
    +            strip_units(variable).squeeze(**kwargs), extract_units(variable)
             )
    -        actual = variable.squeeze()
    +        actual = variable.squeeze(**kwargs)
     
             assert_units_equal(expected, actual)
    -        xr.testing.assert_identical(expected, actual)
    -
    -        names = tuple(name for name, size in zip(names, shape) if shape == 1)
    -        for name in names:
    -            expected = attach_units(
    -                strip_units(variable).squeeze(dim=name), extract_units(variable)
    -            )
    -            actual = variable.squeeze(dim=name)
    -
    -            assert_units_equal(expected, actual)
    -            xr.testing.assert_identical(expected, actual)
    +        assert_identical(expected, actual)
     
         @pytest.mark.parametrize(
             "func",
             (
                 method("coarsen", windows={"y": 2}, func=np.mean),
    -            pytest.param(
    -                method("quantile", q=[0.25, 0.75]),
    -                marks=pytest.mark.xfail(reason="nanquantile not implemented"),
    -            ),
    +            method("quantile", q=[0.25, 0.75]),
                 pytest.param(
                     method("rank", dim="x"),
    -                marks=pytest.mark.xfail(reason="rank not implemented for non-ndarray"),
    +                marks=pytest.mark.skip(reason="rank not implemented for non-ndarray"),
                 ),
                 method("roll", {"x": 2}),
                 pytest.param(
    @@ -1913,7 +2000,7 @@ def test_computation(self, func, dtype):
             actual = func(variable)
     
             assert_units_equal(expected, actual)
    -        xr.testing.assert_identical(expected, actual)
    +        assert_identical(expected, actual)
     
         @pytest.mark.parametrize(
             "unit,error",
    @@ -1959,7 +2046,7 @@ def test_stack(self, dtype):
             actual = variable.stack(z=("x", "y"))
     
             assert_units_equal(expected, actual)
    -        xr.testing.assert_identical(expected, actual)
    +        assert_identical(expected, actual)
     
         def test_unstack(self, dtype):
             array = np.linspace(0, 5, 3 * 10).astype(dtype) * unit_registry.m
    @@ -1971,7 +2058,7 @@ def test_unstack(self, dtype):
             actual = variable.unstack(z={"x": 3, "y": 10})
     
             assert_units_equal(expected, actual)
    -        xr.testing.assert_identical(expected, actual)
    +        assert_identical(expected, actual)
     
         @pytest.mark.parametrize(
             "unit,error",
    @@ -2011,7 +2098,7 @@ def test_concat(self, unit, error, dtype):
             actual = xr.Variable.concat([variable, other], dim="y")
     
             assert_units_equal(expected, actual)
    -        xr.testing.assert_identical(expected, actual)
    +        assert_identical(expected, actual)
     
         def test_set_dims(self, dtype):
             array = np.linspace(0, 5, 3 * 10).reshape(3, 10).astype(dtype) * unit_registry.m
    @@ -2024,7 +2111,7 @@ def test_set_dims(self, dtype):
             actual = variable.set_dims(dims)
     
             assert_units_equal(expected, actual)
    -        xr.testing.assert_identical(expected, actual)
    +        assert_identical(expected, actual)
     
         def test_copy(self, dtype):
             array = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m
    @@ -2037,7 +2124,7 @@ def test_copy(self, dtype):
             actual = variable.copy(data=other)
     
             assert_units_equal(expected, actual)
    -        xr.testing.assert_identical(expected, actual)
    +        assert_identical(expected, actual)
     
         @pytest.mark.parametrize(
             "unit",
    @@ -2078,45 +2165,39 @@ def test_no_conflicts(self, unit, dtype):
     
             assert expected == actual
     
    +    @pytest.mark.parametrize(
    +        "mode",
    +        [
    +            "constant",
    +            "mean",
    +            "median",
    +            "reflect",
    +            "edge",
    +            "linear_ramp",
    +            "maximum",
    +            "minimum",
    +            "symmetric",
    +            "wrap",
    +        ],
    +    )
         @pytest.mark.parametrize("xr_arg, np_arg", _PAD_XR_NP_ARGS)
    -    def test_pad_constant_values(self, dtype, xr_arg, np_arg):
    -        data = np.arange(4 * 3 * 2).reshape(4, 3, 2).astype(dtype) * unit_registry.m
    +    def test_pad(self, mode, xr_arg, np_arg):
    +        data = np.arange(4 * 3 * 2).reshape(4, 3, 2) * unit_registry.m
             v = xr.Variable(["x", "y", "z"], data)
     
    -        actual = v.pad(**xr_arg, mode="constant")
    -        expected = xr.Variable(
    -            v.dims,
    -            np.pad(
    -                v.data.astype(float), np_arg, mode="constant", constant_values=np.nan,
    -            ),
    +        expected = attach_units(
    +            strip_units(v).pad(mode=mode, **xr_arg),
    +            extract_units(v),
             )
    -        xr.testing.assert_identical(expected, actual)
    -        assert_units_equal(expected, actual)
    -        assert isinstance(actual._data, type(v._data))
    +        actual = v.pad(mode=mode, **xr_arg)
     
    -        # for the boolean array, we pad False
    -        data = np.full_like(data, False, dtype=bool).reshape(4, 3, 2)
    -        v = xr.Variable(["x", "y", "z"], data)
    -        actual = v.pad(**xr_arg, mode="constant", constant_values=data.flat[0])
    -        expected = xr.Variable(
    -            v.dims,
    -            np.pad(v.data, np_arg, mode="constant", constant_values=v.data.flat[0]),
    -        )
    -        xr.testing.assert_identical(actual, expected)
             assert_units_equal(expected, actual)
    +        assert_equal(actual, expected)
     
         @pytest.mark.parametrize(
             "unit,error",
             (
    -            pytest.param(
    -                1,
    -                DimensionalityError,
    -                id="no_unit",
    -                marks=pytest.mark.xfail(
    -                    LooseVersion(pint.__version__) < LooseVersion("0.10.2"),
    -                    reason="bug in pint's implementation of np.pad",
    -                ),
    -            ),
    +            pytest.param(1, DimensionalityError, id="no_unit"),
                 pytest.param(
                     unit_registry.dimensionless, DimensionalityError, id="dimensionless"
                 ),
    @@ -2149,20 +2230,19 @@ def test_pad_unit_constant_value(self, unit, error, dtype):
             actual = func(variable, constant_values=fill_value)
     
             assert_units_equal(expected, actual)
    -        xr.testing.assert_identical(expected, actual)
    +        assert_identical(expected, actual)
     
     
     class TestDataArray:
    -    @pytest.mark.filterwarnings("error:::pint[.*]")
         @pytest.mark.parametrize(
             "variant",
             (
                 pytest.param(
                     "with_dims",
    -                marks=pytest.mark.xfail(reason="units in indexes are not supported"),
    +                marks=pytest.mark.skip(reason="indexes don't support units"),
                 ),
    -            pytest.param("with_coords"),
    -            pytest.param("without_coords"),
    +            "with_coords",
    +            "without_coords",
             ),
         )
         def test_init(self, variant, dtype):
    @@ -2188,7 +2268,6 @@ def test_init(self, variant, dtype):
                 }.values()
             )
     
    -    @pytest.mark.filterwarnings("error:::pint[.*]")
         @pytest.mark.parametrize(
             "func", (pytest.param(str, id="str"), pytest.param(repr, id="repr"))
         )
    @@ -2197,7 +2276,7 @@ def test_init(self, variant, dtype):
             (
                 pytest.param(
                     "with_dims",
    -                marks=pytest.mark.xfail(reason="units in indexes are not supported"),
    +                marks=pytest.mark.skip(reason="indexes don't support units"),
                 ),
                 pytest.param("with_coords"),
                 pytest.param("without_coords"),
    @@ -2224,79 +2303,75 @@ def test_repr(self, func, variant, dtype):
         @pytest.mark.parametrize(
             "func",
             (
    +            function("all"),
    +            function("any"),
                 pytest.param(
    -                function("all"),
    -                marks=pytest.mark.xfail(reason="not implemented by pint yet"),
    +                function("argmax"),
    +                marks=pytest.mark.skip(
    +                    reason="calling np.argmax as a function on xarray objects is not "
    +                    "supported"
    +                ),
                 ),
                 pytest.param(
    -                function("any"),
    -                marks=pytest.mark.xfail(reason="not implemented by pint yet"),
    +                function("argmin"),
    +                marks=pytest.mark.skip(
    +                    reason="calling np.argmin as a function on xarray objects is not "
    +                    "supported"
    +                ),
                 ),
    -            function("argmax"),
    -            function("argmin"),
                 function("max"),
                 function("mean"),
                 pytest.param(
                     function("median"),
    -                marks=pytest.mark.xfail(reason="not implemented by xarray"),
    +                marks=pytest.mark.skip(
    +                    reason="median does not work with dataarrays yet"
    +                ),
                 ),
                 function("min"),
    -            pytest.param(
    -                function("prod"),
    -                marks=pytest.mark.xfail(reason="not implemented by pint yet"),
    -            ),
    +            function("prod"),
                 function("sum"),
                 function("std"),
                 function("var"),
                 function("cumsum"),
    -            pytest.param(
    -                function("cumprod"),
    -                marks=pytest.mark.xfail(reason="not implemented by pint yet"),
    -            ),
    -            pytest.param(
    -                method("all"),
    -                marks=pytest.mark.xfail(reason="not implemented by pint yet"),
    -            ),
    -            pytest.param(
    -                method("any"),
    -                marks=pytest.mark.xfail(reason="not implemented by pint yet"),
    -            ),
    -            method("argmax"),
    -            method("argmin"),
    +            function("cumprod"),
    +            method("all"),
    +            method("any"),
    +            method("argmax", dim="x"),
    +            method("argmin", dim="x"),
                 method("max"),
                 method("mean"),
                 method("median"),
                 method("min"),
    -            pytest.param(
    -                method("prod"),
    -                marks=pytest.mark.xfail(
    -                    reason="comparison of quantity with ndarrays in nanops not implemented"
    -                ),
    -            ),
    +            method("prod"),
                 method("sum"),
                 method("std"),
                 method("var"),
                 method("cumsum"),
    -            pytest.param(
    -                method("cumprod"),
    -                marks=pytest.mark.xfail(reason="pint does not implement cumprod yet"),
    -            ),
    +            method("cumprod"),
             ),
             ids=repr,
         )
         def test_aggregation(self, func, dtype):
    +        if func.name == "prod" and dtype.kind == "f":
     +            pytest.xfail(reason="nanprod is not supported yet")
    +
             array = np.arange(10).astype(dtype) * (
                 unit_registry.m if func.name != "cumprod" else unit_registry.dimensionless
             )
             data_array = xr.DataArray(data=array, dims="x")
     
    +        numpy_kwargs = func.kwargs.copy()
    +        if "dim" in numpy_kwargs:
    +            numpy_kwargs["axis"] = data_array.get_axis_num(numpy_kwargs.pop("dim"))
    +
             # units differ based on the applied function, so we need to
             # first compute the units
             units = extract_units(func(array))
             expected = attach_units(func(strip_units(data_array)), units)
             actual = func(data_array)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_allclose(expected, actual)
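
The aggregation tests above all follow the same round trip: take the expected unit from applying the bare function to the raw quantity, run the operation on a unit-stripped copy, re-attach the unit, and compare that against letting xarray dispatch to pint directly (translating a `dim` kwarg to a numpy `axis` via `get_axis_num` where needed). The old single `assert_equal_with_units` is split throughout into an explicit `assert_units_equal` plus a value comparison. A minimal standalone sketch of that round trip, using pint directly instead of the module's `extract_units`/`strip_units`/`attach_units` helpers (`ureg` and the values are illustrative and assume a pint version with `__array_function__` support):

```python
import numpy as np
import pint
import xarray as xr

ureg = pint.UnitRegistry()

array = np.arange(10) * ureg.m
data_array = xr.DataArray(data=array, dims="x")

# a dim kwarg would be translated to the matching numpy axis, e.g.
# axis = data_array.get_axis_num("x")  # -> 0

# unit of the reduction, taken from applying the bare function to the quantity
expected_units = np.sum(array).units

# reference result: operate on the magnitudes only, then re-attach the unit
expected = np.sum(array.magnitude) * expected_units

# actual result: let xarray dispatch the reduction to the pint quantity
actual = data_array.sum()

assert actual.data.units == expected_units
np.testing.assert_allclose(actual.data.magnitude, expected.magnitude)
```
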
     
         @pytest.mark.parametrize(
             "func",
    @@ -2314,7 +2389,8 @@ def test_unary_operations(self, func, dtype):
             expected = attach_units(func(strip_units(data_array)), units)
             actual = func(data_array)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
         @pytest.mark.parametrize(
             "func",
    @@ -2333,14 +2409,18 @@ def test_binary_operations(self, func, dtype):
             expected = attach_units(func(strip_units(data_array)), units)
             actual = func(data_array)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
         @pytest.mark.parametrize(
             "comparison",
             (
                 pytest.param(operator.lt, id="less_than"),
                 pytest.param(operator.ge, id="greater_equal"),
    -            pytest.param(operator.eq, id="equal"),
    +            pytest.param(
    +                operator.eq,
    +                id="equal",
    +            ),
             ),
         )
         @pytest.mark.parametrize(
    @@ -2383,7 +2463,8 @@ def test_comparison_operations(self, comparison, unit, error, dtype):
                 strip_units(convert_units(to_compare_with, expected_units)),
             )
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
         @pytest.mark.parametrize(
             "units,error",
    @@ -2411,9 +2492,10 @@ def test_univariate_ufunc(self, units, error, dtype):
             )
             actual = func(data_array)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
    -    @pytest.mark.xfail(reason="xarray's `np.maximum` strips units")
     +    @pytest.mark.xfail(reason="needs the type registration system for __array_ufunc__")
         @pytest.mark.parametrize(
             "unit,error",
             (
    @@ -2422,7 +2504,12 @@ def test_univariate_ufunc(self, units, error, dtype):
                     unit_registry.dimensionless, DimensionalityError, id="dimensionless"
                 ),
                 pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
    -            pytest.param(unit_registry.mm, None, id="compatible_unit"),
    +            pytest.param(
    +                unit_registry.mm,
    +                None,
    +                id="compatible_unit",
    +                marks=pytest.mark.xfail(reason="pint converts to the wrong units"),
    +            ),
                 pytest.param(unit_registry.m, None, id="identical_unit"),
             ),
         )
    @@ -2433,7 +2520,7 @@ def test_bivariate_ufunc(self, unit, error, dtype):
     
             if error is not None:
                 with pytest.raises(error):
    -                np.maximum(data_array, 0 * unit)
    +                np.maximum(data_array, 1 * unit)
     
                 return
     
    @@ -2441,16 +2528,18 @@ def test_bivariate_ufunc(self, unit, error, dtype):
             expected = attach_units(
                 np.maximum(
                     strip_units(data_array),
    -                strip_units(convert_units(0 * unit, expected_units)),
    +                strip_units(convert_units(1 * unit, expected_units)),
                 ),
                 expected_units,
             )
     
    -        actual = np.maximum(data_array, 0 * unit)
    -        assert_equal_with_units(expected, actual)
    +        actual = np.maximum(data_array, 1 * unit)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
    -        actual = np.maximum(0 * unit, data_array)
    -        assert_equal_with_units(expected, actual)
    +        actual = np.maximum(1 * unit, data_array)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
         @pytest.mark.parametrize("property", ("T", "imag", "real"))
         def test_numpy_properties(self, property, dtype):
    @@ -2466,7 +2555,8 @@ def test_numpy_properties(self, property, dtype):
             )
             actual = getattr(data_array, property)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
         @pytest.mark.parametrize(
             "func",
    @@ -2481,16 +2571,86 @@ def test_numpy_methods(self, func, dtype):
             expected = attach_units(strip_units(data_array), units)
             actual = func(data_array)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
    +
    +    def test_item(self, dtype):
    +        array = np.arange(10).astype(dtype) * unit_registry.m
    +        data_array = xr.DataArray(data=array)
    +
    +        func = method("item", 2)
    +
    +        expected = func(strip_units(data_array)) * unit_registry.m
    +        actual = func(data_array)
    +
    +        assert_duckarray_allclose(expected, actual)
    +
    +    @pytest.mark.parametrize(
    +        "unit,error",
    +        (
    +            pytest.param(1, DimensionalityError, id="no_unit"),
    +            pytest.param(
    +                unit_registry.dimensionless, DimensionalityError, id="dimensionless"
    +            ),
    +            pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
    +            pytest.param(unit_registry.cm, None, id="compatible_unit"),
    +            pytest.param(unit_registry.m, None, id="identical_unit"),
    +        ),
    +    )
    +    @pytest.mark.parametrize(
    +        "func",
    +        (
    +            method("searchsorted", 5),
    +            pytest.param(
    +                function("searchsorted", 5),
    +                marks=pytest.mark.xfail(
    +                    reason="xarray does not implement __array_function__"
    +                ),
    +            ),
    +        ),
    +        ids=repr,
    +    )
    +    def test_searchsorted(self, func, unit, error, dtype):
    +        array = np.arange(10).astype(dtype) * unit_registry.m
    +        data_array = xr.DataArray(data=array)
    +
    +        scalar_types = (int, float)
    +        args = list(value * unit for value in func.args)
    +        kwargs = {
    +            key: (value * unit if isinstance(value, scalar_types) else value)
    +            for key, value in func.kwargs.items()
    +        }
    +
    +        if error is not None:
    +            with pytest.raises(error):
    +                func(data_array, *args, **kwargs)
    +
    +            return
    +
    +        units = extract_units(data_array)
    +        expected_units = extract_units(func(array, *args, **kwargs))
    +        stripped_args = [strip_units(convert_units(value, units)) for value in args]
    +        stripped_kwargs = {
    +            key: strip_units(convert_units(value, units))
    +            for key, value in kwargs.items()
    +        }
    +        expected = attach_units(
    +            func(strip_units(data_array), *stripped_args, **stripped_kwargs),
    +            expected_units,
    +        )
    +        actual = func(data_array, *args, **kwargs)
    +
    +        assert_units_equal(expected, actual)
    +        np.testing.assert_allclose(expected, actual)
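
`test_searchsorted` (like `test_numpy_methods_with_args` below) attaches the parametrized unit to the scalar arguments and then converts them into the data's own units before stripping, so the unitless reference computation compares consistent magnitudes; units that cannot be converted are expected to raise `DimensionalityError`. A small sketch of just that conversion step with plain pint (`ureg` and the values are hypothetical):

```python
import numpy as np
import pint

ureg = pint.UnitRegistry()

data = np.arange(10) * ureg.m   # magnitudes 0..9, in metres
arg = 5 * ureg.cm               # the "compatible_unit" case

# convert the argument into the data's units before stripping it
stripped_arg = arg.to(data.units).magnitude                 # 0.05
reference = np.searchsorted(data.magnitude, stripped_arg)   # -> 1

# the "incompatible_unit" case cannot be converted and raises
try:
    (5 * ureg.s).to(data.units)
except pint.DimensionalityError:
    pass
```
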
     
         @pytest.mark.parametrize(
             "func",
             (
                 method("clip", min=3, max=8),
                 pytest.param(
    -                method("searchsorted", v=5),
    +                function("clip", a_min=3, a_max=8),
                     marks=pytest.mark.xfail(
    -                    reason="searchsorted somehow requires a undocumented `keys` argument"
    +                    reason="xarray does not implement __array_function__"
                     ),
                 ),
             ),
    @@ -2513,28 +2673,32 @@ def test_numpy_methods_with_args(self, func, unit, error, dtype):
             data_array = xr.DataArray(data=array)
     
             scalar_types = (int, float)
    +        args = list(value * unit for value in func.args)
             kwargs = {
                 key: (value * unit if isinstance(value, scalar_types) else value)
                 for key, value in func.kwargs.items()
             }
             if error is not None:
                 with pytest.raises(error):
    -                func(data_array, **kwargs)
    +                func(data_array, *args, **kwargs)
     
                 return
     
             units = extract_units(data_array)
    -        expected_units = extract_units(func(array, **kwargs))
    +        expected_units = extract_units(func(array, *args, **kwargs))
    +        stripped_args = [strip_units(convert_units(value, units)) for value in args]
             stripped_kwargs = {
                 key: strip_units(convert_units(value, units))
                 for key, value in kwargs.items()
             }
             expected = attach_units(
    -            func(strip_units(data_array), **stripped_kwargs), expected_units
    +            func(strip_units(data_array), *stripped_args, **stripped_kwargs),
    +            expected_units,
             )
    -        actual = func(data_array, **kwargs)
    +        actual = func(data_array, *args, **kwargs)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
         @pytest.mark.parametrize(
             "func", (method("isnull"), method("notnull"), method("count")), ids=repr
    @@ -2551,15 +2715,13 @@ def test_missing_value_detection(self, func, dtype):
                 )
                 * unit_registry.degK
             )
    -        x = np.arange(array.shape[0]) * unit_registry.m
    -        y = np.arange(array.shape[1]) * unit_registry.m
    -
    -        data_array = xr.DataArray(data=array, coords={"x": x, "y": y}, dims=("x", "y"))
    +        data_array = xr.DataArray(data=array)
     
             expected = func(strip_units(data_array))
             actual = func(data_array)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
         @pytest.mark.xfail(reason="ffill and bfill lose units in data")
         @pytest.mark.parametrize("func", (method("ffill"), method("bfill")), ids=repr)
    @@ -2576,7 +2738,8 @@ def test_missing_value_filling(self, func, dtype):
             )
             actual = func(data_array, dim="x")
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
         @pytest.mark.parametrize(
             "unit,error",
    @@ -2586,12 +2749,7 @@ def test_missing_value_filling(self, func, dtype):
                     unit_registry.dimensionless, DimensionalityError, id="dimensionless"
                 ),
                 pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
    -            pytest.param(
    -                unit_registry.cm,
    -                None,
    -                id="compatible_unit",
    -                marks=pytest.mark.xfail(reason="fillna converts to value's unit"),
    -            ),
    +            pytest.param(unit_registry.cm, None, id="compatible_unit"),
                 pytest.param(unit_registry.m, None, id="identical_unit"),
             ),
         )
    @@ -2629,7 +2787,8 @@ def test_fillna(self, fill_value, unit, error, dtype):
             )
             actual = func(data_array, value=value)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
         def test_dropna(self, dtype):
             array = (
    @@ -2643,18 +2802,13 @@ def test_dropna(self, dtype):
             expected = attach_units(strip_units(data_array).dropna(dim="x"), units)
             actual = data_array.dropna(dim="x")
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
         @pytest.mark.parametrize(
             "unit",
             (
    -            pytest.param(
    -                1,
    -                id="no_unit",
    -                marks=pytest.mark.xfail(
    -                    reason="pint's isin implementation does not work well with mixed args"
    -                ),
    -            ),
    +            pytest.param(1, id="no_unit"),
                 pytest.param(unit_registry.dimensionless, id="dimensionless"),
                 pytest.param(unit_registry.s, id="incompatible_unit"),
                 pytest.param(unit_registry.cm, id="compatible_unit"),
    @@ -2677,22 +2831,11 @@ def test_isin(self, unit, dtype):
             ) & array.check(unit)
             actual = data_array.isin(values)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
         @pytest.mark.parametrize(
    -        "variant",
    -        (
    -            pytest.param(
    -                "masking",
    -                marks=pytest.mark.xfail(reason="array(nan) is not a quantity"),
    -            ),
    -            "replacing_scalar",
    -            "replacing_array",
    -            pytest.param(
    -                "dropping",
    -                marks=pytest.mark.xfail(reason="array(nan) is not a quantity"),
    -            ),
    -        ),
    +        "variant", ("masking", "replacing_scalar", "replacing_array", "dropping")
         )
         @pytest.mark.parametrize(
             "unit,error",
    @@ -2742,22 +2885,24 @@ def test_where(self, variant, unit, error, dtype):
             )
             actual = data_array.where(**kwargs)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
    -    @pytest.mark.xfail(reason="interpolate strips units")
    -    def test_interpolate_na(self, dtype):
    +    @pytest.mark.xfail(reason="uses numpy.vectorize")
    +    def test_interpolate_na(self):
             array = (
                 np.array([-1.03, 0.1, 1.4, np.nan, 2.3, np.nan, np.nan, 9.1])
                 * unit_registry.m
             )
             x = np.arange(len(array))
    -        data_array = xr.DataArray(data=array, coords={"x": x}, dims="x").astype(dtype)
    +        data_array = xr.DataArray(data=array, coords={"x": x}, dims="x")
     
             units = extract_units(data_array)
             expected = attach_units(strip_units(data_array).interpolate_na(dim="x"), units)
             actual = data_array.interpolate_na(dim="x")
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
         @pytest.mark.parametrize(
             "unit,error",
    @@ -2771,13 +2916,11 @@ def test_interpolate_na(self, dtype):
                     unit_registry.cm,
                     None,
                     id="compatible_unit",
    -                marks=pytest.mark.xfail(reason="depends on reindex"),
                 ),
                 pytest.param(
                     unit_registry.m,
                     None,
                     id="identical_unit",
    -                marks=pytest.mark.xfail(reason="depends on reindex"),
                 ),
             ),
         )
    @@ -2807,7 +2950,8 @@ def test_combine_first(self, unit, error, dtype):
             )
             actual = data_array.combine_first(other)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
         @pytest.mark.parametrize(
             "unit",
    @@ -2824,12 +2968,22 @@ def test_combine_first(self, unit, error, dtype):
             (
                 "data",
                 pytest.param(
    -                "dims", marks=pytest.mark.xfail(reason="units in indexes not supported")
    +                "dims", marks=pytest.mark.skip(reason="indexes don't support units")
                 ),
                 "coords",
             ),
         )
    -    @pytest.mark.parametrize("func", (method("equals"), method("identical")), ids=repr)
    +    @pytest.mark.parametrize(
    +        "func",
    +        (
    +            method("equals"),
    +            pytest.param(
    +                method("identical"),
    +                marks=pytest.mark.skip(reason="the behavior of identical is undecided"),
    +            ),
    +        ),
    +        ids=repr,
    +    )
         def test_comparisons(self, func, variation, unit, dtype):
             def is_compatible(a, b):
                 a = a if a is not None else 1
    @@ -2886,24 +3040,55 @@ def is_compatible(a, b):
                 pytest.param(unit_registry.m, id="identical_unit"),
             ),
         )
    -    def test_broadcast_like(self, unit, dtype):
    -        array1 = np.linspace(1, 2, 2 * 1).reshape(2, 1).astype(dtype) * unit_registry.Pa
    -        array2 = np.linspace(0, 1, 2 * 3).reshape(2, 3).astype(dtype) * unit_registry.Pa
    +    @pytest.mark.parametrize(
    +        "variant",
    +        (
    +            "data",
    +            pytest.param(
    +                "dims", marks=pytest.mark.skip(reason="indexes don't support units")
    +            ),
    +            "coords",
    +        ),
    +    )
    +    def test_broadcast_like(self, variant, unit, dtype):
    +        original_unit = unit_registry.m
    +
    +        variants = {
    +            "data": ((original_unit, unit), (1, 1), (1, 1)),
    +            "dims": ((1, 1), (original_unit, unit), (1, 1)),
    +            "coords": ((1, 1), (1, 1), (original_unit, unit)),
    +        }
    +        (
    +            (data_unit1, data_unit2),
    +            (dim_unit1, dim_unit2),
    +            (coord_unit1, coord_unit2),
    +        ) = variants.get(variant)
    +
    +        array1 = np.linspace(1, 2, 2 * 1).reshape(2, 1).astype(dtype) * data_unit1
    +        array2 = np.linspace(0, 1, 2 * 3).reshape(2, 3).astype(dtype) * data_unit2
    +
    +        x1 = np.arange(2) * dim_unit1
    +        x2 = np.arange(2) * dim_unit2
    +        y1 = np.array([0]) * dim_unit1
    +        y2 = np.arange(3) * dim_unit2
     
    -        x1 = np.arange(2) * unit_registry.m
    -        x2 = np.arange(2) * unit
    -        y1 = np.array([0]) * unit_registry.m
    -        y2 = np.arange(3) * unit
    +        u1 = np.linspace(0, 1, 2) * coord_unit1
    +        u2 = np.linspace(0, 1, 2) * coord_unit2
     
    -        arr1 = xr.DataArray(data=array1, coords={"x": x1, "y": y1}, dims=("x", "y"))
    -        arr2 = xr.DataArray(data=array2, coords={"x": x2, "y": y2}, dims=("x", "y"))
    +        arr1 = xr.DataArray(
    +            data=array1, coords={"x": x1, "y": y1, "u": ("x", u1)}, dims=("x", "y")
    +        )
    +        arr2 = xr.DataArray(
    +            data=array2, coords={"x": x2, "y": y2, "u": ("x", u2)}, dims=("x", "y")
    +        )
     
             expected = attach_units(
                 strip_units(arr1).broadcast_like(strip_units(arr2)), extract_units(arr1)
             )
             actual = arr1.broadcast_like(arr2)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
         @pytest.mark.parametrize(
             "unit",
    @@ -2933,56 +3118,89 @@ def test_broadcast_equals(self, unit, dtype):
     
             assert expected == actual
     
    +    def test_pad(self, dtype):
    +        array = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m
    +
    +        data_array = xr.DataArray(data=array, dims="x")
    +        units = extract_units(data_array)
    +
    +        expected = attach_units(strip_units(data_array).pad(x=(2, 3)), units)
    +        actual = data_array.pad(x=(2, 3))
    +
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
    +
    +    @pytest.mark.parametrize(
    +        "variant",
    +        (
    +            "data",
    +            pytest.param(
    +                "dims", marks=pytest.mark.skip(reason="indexes don't support units")
    +            ),
    +            "coords",
    +        ),
    +    )
         @pytest.mark.parametrize(
             "func",
             (
                 method("pipe", lambda da: da * 10),
    -            method("assign_coords", y2=("y", np.arange(10) * unit_registry.mm)),
    +            method("assign_coords", w=("y", np.arange(10) * unit_registry.mm)),
                 method("assign_attrs", attr1="value"),
    -            method("rename", x2="x_mm"),
    -            method("swap_dims", {"x": "x2"}),
    -            method(
    -                "expand_dims",
    -                dim={"z": np.linspace(10, 20, 12) * unit_registry.s},
    -                axis=1,
    +            method("rename", u="v"),
    +            pytest.param(
    +                method("swap_dims", {"x": "u"}),
    +                marks=pytest.mark.skip(reason="indexes don't support units"),
    +            ),
    +            pytest.param(
    +                method(
    +                    "expand_dims",
    +                    dim={"z": np.linspace(10, 20, 12) * unit_registry.s},
    +                    axis=1,
    +                ),
    +                marks=pytest.mark.skip(reason="indexes don't support units"),
                 ),
                 method("drop_vars", "x"),
    -            method("reset_coords", names="x2"),
    +            method("reset_coords", names="u"),
                 method("copy"),
                 method("astype", np.float32),
    -            method("item", 1),
             ),
             ids=repr,
         )
    -    def test_content_manipulation(self, func, dtype):
    -        quantity = (
    -            np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype)
    -            * unit_registry.pascal
    -        )
    -        x = np.arange(quantity.shape[0]) * unit_registry.m
    -        y = np.arange(quantity.shape[1]) * unit_registry.m
    -        x2 = x.to(unit_registry.mm)
    +    def test_content_manipulation(self, func, variant, dtype):
    +        unit = unit_registry.m
    +
    +        variants = {
    +            "data": (unit, 1, 1),
    +            "dims": (1, unit, 1),
    +            "coords": (1, 1, unit),
    +        }
    +        data_unit, dim_unit, coord_unit = variants.get(variant)
    +
    +        quantity = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit
    +        x = np.arange(quantity.shape[0]) * dim_unit
    +        y = np.arange(quantity.shape[1]) * dim_unit
    +        u = np.linspace(0, 1, quantity.shape[0]) * coord_unit
     
             data_array = xr.DataArray(
    -            name="data",
    +            name="a",
                 data=quantity,
    -            coords={"x": x, "x2": ("x", x2), "y": y},
    +            coords={"x": x, "u": ("x", u), "y": y},
                 dims=("x", "y"),
             )
     
             stripped_kwargs = {
                 key: array_strip_units(value) for key, value in func.kwargs.items()
             }
    -        units = {**{"x_mm": x2.units, "x2": x2.units}, **extract_units(data_array)}
    +        units = extract_units(data_array)
    +        units["u"] = getattr(u, "units", None)
    +        units["v"] = getattr(u, "units", None)
     
             expected = attach_units(func(strip_units(data_array), **stripped_kwargs), units)
             actual = func(data_array)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
    -    @pytest.mark.parametrize(
    -        "func", (pytest.param(method("copy", data=np.arange(20))),), ids=repr
    -    )
         @pytest.mark.parametrize(
             "unit",
             (
    @@ -2991,20 +3209,20 @@ def test_content_manipulation(self, func, dtype):
                 pytest.param(unit_registry.degK, id="with_unit"),
             ),
         )
    -    def test_content_manipulation_with_units(self, func, unit, dtype):
    +    def test_copy(self, unit, dtype):
             quantity = np.linspace(0, 10, 20, dtype=dtype) * unit_registry.pascal
    -        x = np.arange(len(quantity)) * unit_registry.m
    +        new_data = np.arange(20)
     
    -        data_array = xr.DataArray(data=quantity, coords={"x": x}, dims="x")
    -
    -        kwargs = {key: value * unit for key, value in func.kwargs.items()}
    +        data_array = xr.DataArray(data=quantity, dims="x")
     
             expected = attach_units(
    -            func(strip_units(data_array)), {None: unit, "x": x.units}
    +            strip_units(data_array).copy(data=new_data), {None: unit}
             )
     
    -        actual = func(data_array, **kwargs)
    -        assert_equal_with_units(expected, actual)
    +        actual = data_array.copy(data=new_data * unit)
    +
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
         @pytest.mark.parametrize(
             "indices",
    @@ -3014,19 +3232,20 @@ def test_content_manipulation_with_units(self, func, unit, dtype):
             ),
         )
         def test_isel(self, indices, dtype):
    +        # TODO: maybe test for units in indexes?
             array = np.arange(10).astype(dtype) * unit_registry.s
    -        x = np.arange(len(array)) * unit_registry.m
     
    -        data_array = xr.DataArray(data=array, coords={"x": x}, dims="x")
    +        data_array = xr.DataArray(data=array, dims="x")
     
             expected = attach_units(
                 strip_units(data_array).isel(x=indices), extract_units(data_array)
             )
             actual = data_array.isel(x=indices)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
    -    @pytest.mark.xfail(reason="indexes don't support units")
    +    @pytest.mark.skip(reason="indexes don't support units")
         @pytest.mark.parametrize(
             "raw_values",
             (
    @@ -3067,9 +3286,11 @@ def test_sel(self, raw_values, unit, error, dtype):
                 extract_units(data_array),
             )
             actual = data_array.sel(x=values)
    -        assert_equal_with_units(expected, actual)
     
    -    @pytest.mark.xfail(reason="indexes don't support units")
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
    +
    +    @pytest.mark.skip(reason="indexes don't support units")
         @pytest.mark.parametrize(
             "raw_values",
             (
    @@ -3110,9 +3331,11 @@ def test_loc(self, raw_values, unit, error, dtype):
                 extract_units(data_array),
             )
             actual = data_array.loc[{"x": values}]
    -        assert_equal_with_units(expected, actual)
     
    -    @pytest.mark.xfail(reason="indexes don't support units")
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
    +
    +    @pytest.mark.skip(reason="indexes don't support units")
         @pytest.mark.parametrize(
             "raw_values",
             (
    @@ -3153,8 +3376,11 @@ def test_drop_sel(self, raw_values, unit, error, dtype):
                 extract_units(data_array),
             )
             actual = data_array.drop_sel(x=values)
    -        assert_equal_with_units(expected, actual)
     
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
    +
    +    @pytest.mark.parametrize("dim", ("x", "y", "z", "t", "all"))
         @pytest.mark.parametrize(
             "shape",
             (
    @@ -3165,32 +3391,22 @@ def test_drop_sel(self, raw_values, unit, error, dtype):
                 pytest.param((1, 10, 1, 20), id="first_and_last_dimension_squeezable"),
             ),
         )
    -    def test_squeeze(self, shape, dtype):
    +    def test_squeeze(self, shape, dim, dtype):
              names = "xyzt"
     +        dim_lengths = dict(zip(names, shape))
    -        coords = {
    -            name: np.arange(length).astype(dtype)
    -            * (unit_registry.m if name != "t" else unit_registry.s)
    -            for name, length in zip(names, shape)
    -        }
             array = np.arange(10 * 20).astype(dtype).reshape(shape) * unit_registry.J
    -        data_array = xr.DataArray(
    -            data=array, coords=coords, dims=tuple(names[: len(shape)])
    -        )
    +        data_array = xr.DataArray(data=array, dims=tuple(names[: len(shape)]))
    +
    +        kwargs = {"dim": dim} if dim != "all" and dim_lengths.get(dim, 0) == 1 else {}
     
             expected = attach_units(
    -            strip_units(data_array).squeeze(), extract_units(data_array)
    +            strip_units(data_array).squeeze(**kwargs), extract_units(data_array)
             )
    -        actual = data_array.squeeze()
    -        assert_equal_with_units(expected, actual)
    +        actual = data_array.squeeze(**kwargs)
     
    -        # try squeezing the dimensions separately
    -        names = tuple(dim for dim, coord in coords.items() if len(coord) == 1)
    -        for index, name in enumerate(names):
    -            expected = attach_units(
    -                strip_units(data_array).squeeze(dim=name), extract_units(data_array)
    -            )
    -            actual = data_array.squeeze(dim=name)
    -            assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
         @pytest.mark.parametrize(
             "func",
    @@ -3198,63 +3414,52 @@ def test_squeeze(self, shape, dtype):
             ids=repr,
         )
         def test_head_tail_thin(self, func, dtype):
    +        # TODO: works like isel. Maybe also test units in indexes?
             array = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK
     
    -        coords = {
    -            "x": np.arange(10) * unit_registry.m,
    -            "y": np.arange(5) * unit_registry.m,
    -        }
    -
    -        data_array = xr.DataArray(data=array, coords=coords, dims=("x", "y"))
    +        data_array = xr.DataArray(data=array, dims=("x", "y"))
     
             expected = attach_units(
                 func(strip_units(data_array)), extract_units(data_array)
             )
             actual = func(data_array)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
    -    @pytest.mark.xfail(reason="indexes don't support units")
    +    @pytest.mark.parametrize("variant", ("data", "coords"))
         @pytest.mark.parametrize(
    -        "unit,error",
    +        "func",
             (
    -            pytest.param(1, DimensionalityError, id="no_unit"),
                 pytest.param(
    -                unit_registry.dimensionless, DimensionalityError, id="dimensionless"
    +                method("interp"), marks=pytest.mark.xfail(reason="uses scipy")
                 ),
    -            pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
    -            pytest.param(unit_registry.cm, None, id="compatible_unit"),
    -            pytest.param(unit_registry.m, None, id="identical_unit"),
    +            method("reindex"),
             ),
    +        ids=repr,
         )
    -    def test_interp(self, unit, error):
    -        array = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK
    -        new_coords = (np.arange(10) + 0.5) * unit
    -        coords = {
    -            "x": np.arange(10) * unit_registry.m,
    -            "y": np.arange(5) * unit_registry.m,
    +    def test_interp_reindex(self, variant, func, dtype):
    +        variants = {
    +            "data": (unit_registry.m, 1),
    +            "coords": (1, unit_registry.m),
             }
    +        data_unit, coord_unit = variants.get(variant)
     
    -        data_array = xr.DataArray(array, coords=coords, dims=("x", "y"))
    -
    -        if error is not None:
    -            with pytest.raises(error):
    -                data_array.interp(x=new_coords)
    +        array = np.linspace(1, 2, 10).astype(dtype) * data_unit
    +        y = np.arange(10) * coord_unit
     
    -            return
    +        x = np.arange(10)
    +        new_x = np.arange(10) + 0.5
    +        data_array = xr.DataArray(array, coords={"x": x, "y": ("x", y)}, dims="x")
     
             units = extract_units(data_array)
    -        expected = attach_units(
    -            strip_units(data_array).interp(
    -                x=strip_units(convert_units(new_coords, {None: unit_registry.m}))
    -            ),
    -            units,
    -        )
    -        actual = data_array.interp(x=new_coords)
    +        expected = attach_units(func(strip_units(data_array), x=new_x), units)
    +        actual = func(data_array, x=new_x)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_allclose(expected, actual)
     
    -    @pytest.mark.xfail(reason="indexes strip units")
    +    @pytest.mark.skip(reason="indexes don't support units")
         @pytest.mark.parametrize(
             "unit,error",
             (
    @@ -3267,81 +3472,70 @@ def test_interp(self, unit, error):
                 pytest.param(unit_registry.m, None, id="identical_unit"),
             ),
         )
    -    def test_interp_like(self, unit, error):
    -        array = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK
    -        coords = {
    -            "x": (np.arange(10) + 0.3) * unit_registry.m,
    -            "y": (np.arange(5) + 0.3) * unit_registry.m,
    -        }
    -
    -        data_array = xr.DataArray(array, coords=coords, dims=("x", "y"))
    -        other = xr.DataArray(
    -            data=np.empty((20, 10)) * unit_registry.degK,
    -            coords={"x": np.arange(20) * unit, "y": np.arange(10) * unit},
    -            dims=("x", "y"),
    -        )
    +    @pytest.mark.parametrize(
    +        "func",
    +        (method("interp"), method("reindex")),
    +        ids=repr,
    +    )
    +    def test_interp_reindex_indexing(self, func, unit, error, dtype):
    +        array = np.linspace(1, 2, 10).astype(dtype)
    +        x = np.arange(10) * unit_registry.m
    +        new_x = (np.arange(10) + 0.5) * unit
    +        data_array = xr.DataArray(array, coords={"x": x}, dims="x")
     
             if error is not None:
                 with pytest.raises(error):
    -                data_array.interp_like(other)
    +                func(data_array, x=new_x)
     
                 return
     
             units = extract_units(data_array)
             expected = attach_units(
    -            strip_units(data_array).interp_like(
    -                strip_units(convert_units(other, units))
    +            func(
    +                strip_units(data_array),
    +                x=strip_units(convert_units(new_x, {None: unit_registry.m})),
                 ),
                 units,
             )
    -        actual = data_array.interp_like(other)
    +        actual = func(data_array, x=new_x)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
    -    @pytest.mark.xfail(reason="indexes don't support units")
    +    @pytest.mark.parametrize("variant", ("data", "coords"))
         @pytest.mark.parametrize(
    -        "unit,error",
    +        "func",
             (
    -            pytest.param(1, DimensionalityError, id="no_unit"),
                 pytest.param(
    -                unit_registry.dimensionless, DimensionalityError, id="dimensionless"
    +                method("interp_like"), marks=pytest.mark.xfail(reason="uses scipy")
                 ),
    -            pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
    -            pytest.param(unit_registry.cm, None, id="compatible_unit"),
    -            pytest.param(unit_registry.m, None, id="identical_unit"),
    +            method("reindex_like"),
             ),
    +        ids=repr,
         )
    -    def test_reindex(self, unit, error, dtype):
    -        array = (
    -            np.linspace(1, 2, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK
    -        )
    -        new_coords = (np.arange(10) + 0.5) * unit
    -        coords = {
    -            "x": np.arange(10) * unit_registry.m,
    -            "y": np.arange(5) * unit_registry.m,
    +    def test_interp_reindex_like(self, variant, func, dtype):
    +        variants = {
    +            "data": (unit_registry.m, 1),
    +            "coords": (1, unit_registry.m),
             }
    +        data_unit, coord_unit = variants.get(variant)
     
    -        data_array = xr.DataArray(array, coords=coords, dims=("x", "y"))
    -        func = method("reindex")
    +        array = np.linspace(1, 2, 10).astype(dtype) * data_unit
    +        coord = np.arange(10) * coord_unit
     
    -        if error is not None:
    -            with pytest.raises(error):
    -                func(data_array, x=new_coords)
    -
    -            return
    +        x = np.arange(10)
    +        new_x = np.arange(-2, 2) + 0.5
    +        data_array = xr.DataArray(array, coords={"x": x, "y": ("x", coord)}, dims="x")
    +        other = xr.DataArray(np.empty_like(new_x), coords={"x": new_x}, dims="x")
     
    -        expected = attach_units(
    -            func(
    -                strip_units(data_array),
    -                x=strip_units(convert_units(new_coords, {None: unit_registry.m})),
    -            ),
    -            {None: unit_registry.degK},
    -        )
    -        actual = func(data_array, x=new_coords)
    +        units = extract_units(data_array)
    +        expected = attach_units(func(strip_units(data_array), other), units)
    +        actual = func(data_array, other)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_allclose(expected, actual)
     
    -    @pytest.mark.xfail(reason="indexes don't support units")
    +    @pytest.mark.skip(reason="indexes don't support units")
         @pytest.mark.parametrize(
             "unit,error",
             (
    @@ -3354,38 +3548,37 @@ def test_reindex(self, unit, error, dtype):
                 pytest.param(unit_registry.m, None, id="identical_unit"),
             ),
         )
    -    def test_reindex_like(self, unit, error, dtype):
    -        array = (
    -            np.linspace(1, 2, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK
    -        )
    -        coords = {
    -            "x": (np.arange(10) + 0.3) * unit_registry.m,
    -            "y": (np.arange(5) + 0.3) * unit_registry.m,
    -        }
    +    @pytest.mark.parametrize(
    +        "func",
    +        (method("interp_like"), method("reindex_like")),
    +        ids=repr,
    +    )
    +    def test_interp_reindex_like_indexing(self, func, unit, error, dtype):
    +        array = np.linspace(1, 2, 10).astype(dtype)
    +        x = np.arange(10) * unit_registry.m
    +        new_x = (np.arange(-2, 2) + 0.5) * unit
     
    -        data_array = xr.DataArray(array, coords=coords, dims=("x", "y"))
    -        other = xr.DataArray(
    -            data=np.empty((20, 10)) * unit_registry.degK,
    -            coords={"x": np.arange(20) * unit, "y": np.arange(10) * unit},
    -            dims=("x", "y"),
    -        )
    +        data_array = xr.DataArray(array, coords={"x": x}, dims="x")
    +        other = xr.DataArray(np.empty_like(new_x), {"x": new_x}, dims="x")
     
             if error is not None:
                 with pytest.raises(error):
    -                data_array.reindex_like(other)
    +                func(data_array, other)
     
                 return
     
             units = extract_units(data_array)
             expected = attach_units(
    -            strip_units(data_array).reindex_like(
    -                strip_units(convert_units(other, units))
    +            func(
    +                strip_units(data_array),
    +                strip_units(convert_units(other, {None: unit_registry.m})),
                 ),
                 units,
             )
    -        actual = data_array.reindex_like(other)
    +        actual = func(data_array, other)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
         @pytest.mark.parametrize(
             "func",
    @@ -3407,9 +3600,10 @@ def test_stacking_stacked(self, func, dtype):
             expected = attach_units(func(strip_units(stacked)), {"data": unit_registry.m})
             actual = func(stacked)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
    -    @pytest.mark.xfail(reason="indexes don't support units")
    +    @pytest.mark.skip(reason="indexes don't support units")
         def test_to_unstacked_dataset(self, dtype):
             array = (
                 np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype)
    @@ -3430,7 +3624,8 @@ def test_to_unstacked_dataset(self, dtype):
             ).rename({elem.magnitude: elem for elem in x})
             actual = func(data_array)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
         @pytest.mark.parametrize(
             "func",
    @@ -3438,8 +3633,10 @@ def test_to_unstacked_dataset(self, dtype):
                 method("transpose", "y", "x", "z"),
                 method("stack", a=("x", "y")),
                 method("set_index", x="x2"),
    +            method("shift", x=2),
                 pytest.param(
    -                method("shift", x=2), marks=pytest.mark.xfail(reason="strips units")
    +                method("rank", dim="x"),
    +                marks=pytest.mark.skip(reason="rank not implemented for non-ndarray"),
                 ),
                 method("roll", x=2, roll_coords=False),
                 method("sortby", "x2"),
    @@ -3466,54 +3663,73 @@ def test_stacking_reordering(self, func, dtype):
             expected = attach_units(func(strip_units(data_array)), {None: unit_registry.m})
             actual = func(data_array)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
    +    @pytest.mark.parametrize(
    +        "variant",
    +        (
    +            "data",
    +            pytest.param(
    +                "dims", marks=pytest.mark.skip(reason="indexes don't support units")
    +            ),
    +            "coords",
    +        ),
    +    )
         @pytest.mark.parametrize(
             "func",
             (
                 method("diff", dim="x"),
                 method("differentiate", coord="x"),
                 method("integrate", coord="x"),
    -            pytest.param(
    -                method("quantile", q=[0.25, 0.75]),
    -                marks=pytest.mark.xfail(reason="nanquantile not implemented"),
    -            ),
    +            method("quantile", q=[0.25, 0.75]),
                 method("reduce", func=np.sum, dim="x"),
    -            pytest.param(
    -                lambda x: x.dot(x),
    -                id="method_dot",
    -                marks=pytest.mark.xfail(
    -                    reason="pint does not implement the dot method"
    -                ),
    -            ),
    +            pytest.param(lambda x: x.dot(x), id="method_dot"),
             ),
             ids=repr,
         )
    -    def test_computation(self, func, dtype):
    -        array = (
    -            np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * unit_registry.m
    -        )
    +    def test_computation(self, func, variant, dtype):
    +        unit = unit_registry.m
     
    -        x = np.arange(array.shape[0]) * unit_registry.m
    -        y = np.arange(array.shape[1]) * unit_registry.s
    +        variants = {
    +            "data": (unit, 1, 1),
    +            "dims": (1, unit, 1),
    +            "coords": (1, 1, unit),
    +        }
    +        data_unit, dim_unit, coord_unit = variants.get(variant)
     
    -        data_array = xr.DataArray(data=array, coords={"x": x, "y": y}, dims=("x", "y"))
    +        array = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit
    +
    +        x = np.arange(array.shape[0]) * dim_unit
    +        y = np.arange(array.shape[1]) * dim_unit
    +
    +        u = np.linspace(0, 1, array.shape[0]) * coord_unit
    +
    +        data_array = xr.DataArray(
    +            data=array, coords={"x": x, "y": y, "u": ("x", u)}, dims=("x", "y")
    +        )
     
             # we want to make sure the output unit is correct
    -        units = {
    -            **extract_units(data_array),
    -            **(
    -                {}
    -                if isinstance(func, (function, method))
    -                else extract_units(func(array.reshape(-1)))
    -            ),
    -        }
    +        units = extract_units(data_array)
    +        if not isinstance(func, (function, method)):
    +            units.update(extract_units(func(array.reshape(-1))))
     
             expected = attach_units(func(strip_units(data_array)), units)
             actual = func(data_array)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
    +    @pytest.mark.parametrize(
    +        "variant",
    +        (
    +            "data",
    +            pytest.param(
    +                "dims", marks=pytest.mark.skip(reason="indexes don't support units")
    +            ),
    +            "coords",
    +        ),
    +    )
         @pytest.mark.parametrize(
             "func",
             (
    @@ -3522,30 +3738,47 @@ def test_computation(self, func, dtype):
                 method("coarsen", y=2),
                 pytest.param(
                     method("rolling", y=3),
    -                marks=pytest.mark.xfail(reason="rolling strips units"),
    +                marks=pytest.mark.xfail(
    +                    reason="numpy.lib.stride_tricks.as_strided converts to ndarray"
    +                ),
                 ),
                 pytest.param(
                     method("rolling_exp", y=3),
    -                marks=pytest.mark.xfail(reason="units not supported by numbagg"),
    +                marks=pytest.mark.xfail(
    +                    reason="numbagg functions are not supported by pint"
    +                ),
                 ),
    +            method("weighted", xr.DataArray(data=np.linspace(0, 1, 10), dims="y")),
             ),
             ids=repr,
         )
    -    def test_computation_objects(self, func, dtype):
    -        array = (
    -            np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * unit_registry.m
    -        )
    +    def test_computation_objects(self, func, variant, dtype):
    +        unit = unit_registry.m
    +
    +        variants = {
    +            "data": (unit, 1, 1),
    +            "dims": (1, unit, 1),
    +            "coords": (1, 1, unit),
    +        }
    +        data_unit, dim_unit, coord_unit = variants.get(variant)
    +
    +        array = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit
     
    -        x = np.array([0, 0, 1, 2, 2]) * unit_registry.m
    -        y = np.arange(array.shape[1]) * 3 * unit_registry.s
    +        x = np.array([0, 0, 1, 2, 2]) * dim_unit
    +        y = np.arange(array.shape[1]) * 3 * dim_unit
     
    -        data_array = xr.DataArray(data=array, coords={"x": x, "y": y}, dims=("x", "y"))
    +        u = np.linspace(0, 1, 5) * coord_unit
    +
    +        data_array = xr.DataArray(
    +            data=array, coords={"x": x, "y": y, "u": ("x", u)}, dims=("x", "y")
    +        )
             units = extract_units(data_array)
     
             expected = attach_units(func(strip_units(data_array)).mean(), units)
             actual = func(data_array).mean()
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_allclose(expected, actual)
     
         def test_resample(self, dtype):
             array = np.linspace(0, 5, 10).astype(dtype) * unit_registry.m
    @@ -3559,30 +3792,48 @@ def test_resample(self, dtype):
             expected = attach_units(func(strip_units(data_array)).mean(), units)
             actual = func(data_array).mean()
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
    +    @pytest.mark.parametrize(
    +        "variant",
    +        (
    +            "data",
    +            pytest.param(
    +                "dims", marks=pytest.mark.skip(reason="indexes don't support units")
    +            ),
    +            "coords",
    +        ),
    +    )
         @pytest.mark.parametrize(
             "func",
             (
    -            method("assign_coords", z=(["x"], np.arange(5) * unit_registry.s)),
    +            method("assign_coords", z=("x", np.arange(5) * unit_registry.s)),
                 method("first"),
                 method("last"),
    -            pytest.param(
    -                method("quantile", q=[0.25, 0.5, 0.75], dim="x"),
    -                marks=pytest.mark.xfail(reason="nanquantile not implemented"),
    -            ),
    +            method("quantile", q=[0.25, 0.5, 0.75], dim="x"),
             ),
             ids=repr,
         )
    -    def test_grouped_operations(self, func, dtype):
    -        array = (
    -            np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * unit_registry.m
    -        )
    +    def test_grouped_operations(self, func, variant, dtype):
    +        unit = unit_registry.m
     
    -        x = np.arange(array.shape[0]) * unit_registry.m
    -        y = np.arange(array.shape[1]) * 3 * unit_registry.s
    +        variants = {
    +            "data": (unit, 1, 1),
    +            "dims": (1, unit, 1),
    +            "coords": (1, 1, unit),
    +        }
    +        data_unit, dim_unit, coord_unit = variants.get(variant)
    +        array = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit
     
    -        data_array = xr.DataArray(data=array, coords={"x": x, "y": y}, dims=("x", "y"))
    +        x = np.arange(array.shape[0]) * dim_unit
    +        y = np.arange(array.shape[1]) * 3 * dim_unit
    +
    +        u = np.linspace(0, 1, array.shape[0]) * coord_unit
    +
    +        data_array = xr.DataArray(
    +            data=array, coords={"x": x, "y": y, "u": ("x", u)}, dims=("x", "y")
    +        )
             units = {**extract_units(data_array), **{"z": unit_registry.s, "q": None}}
     
             stripped_kwargs = {
    @@ -3598,18 +3849,19 @@ def test_grouped_operations(self, func, dtype):
             )
             actual = func(data_array.groupby("y"))
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_identical(expected, actual)
     
     
     class TestDataset:
         @pytest.mark.parametrize(
             "unit,error",
             (
    -            pytest.param(1, DimensionalityError, id="no_unit"),
    +            pytest.param(1, xr.MergeError, id="no_unit"),
                 pytest.param(
    -                unit_registry.dimensionless, DimensionalityError, id="dimensionless"
    +                unit_registry.dimensionless, xr.MergeError, id="dimensionless"
                 ),
    -            pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
    +            pytest.param(unit_registry.s, xr.MergeError, id="incompatible_unit"),
                 pytest.param(unit_registry.mm, None, id="compatible_unit"),
                 pytest.param(unit_registry.m, None, id="same_unit"),
             ),
    @@ -3618,11 +3870,10 @@ class TestDataset:
             "shared",
             (
                 "nothing",
    -            pytest.param("dims", marks=pytest.mark.xfail(reason="indexes strip units")),
                 pytest.param(
    -                "coords",
    -                marks=pytest.mark.xfail(reason="reindex does not work with pint yet"),
    +                "dims", marks=pytest.mark.skip(reason="indexes don't support units")
                 ),
    +            "coords",
             ),
         )
         def test_init(self, shared, unit, error, dtype):
    @@ -3630,60 +3881,53 @@ def test_init(self, shared, unit, error, dtype):
             scaled_unit = unit_registry.mm
     
             a = np.linspace(0, 1, 10).astype(dtype) * unit_registry.Pa
    -        b = np.linspace(-1, 0, 12).astype(dtype) * unit_registry.Pa
    -
    -        raw_x = np.arange(a.shape[0])
    -        x = raw_x * original_unit
    -        x2 = x.to(scaled_unit)
    -
    -        raw_y = np.arange(b.shape[0])
    -        y = raw_y * unit
    -        y_units = unit if isinstance(y, unit_registry.Quantity) else None
    -        if isinstance(y, unit_registry.Quantity):
    -            if y.check(scaled_unit):
    -                y2 = y.to(scaled_unit)
    -            else:
    -                y2 = y * 1000
    -            y2_units = y2.units
    -        else:
    -            y2 = y * 1000
    -            y2_units = None
    +        b = np.linspace(-1, 0, 10).astype(dtype) * unit_registry.degK
    +
    +        values_a = np.arange(a.shape[0])
    +        dim_a = values_a * original_unit
    +        coord_a = dim_a.to(scaled_unit)
    +
    +        values_b = np.arange(b.shape[0])
    +        dim_b = values_b * unit
    +        coord_b = (
    +            dim_b.to(scaled_unit)
    +            if unit_registry.is_compatible_with(dim_b, scaled_unit)
    +            and unit != scaled_unit
    +            else dim_b * 1000
    +        )
     
             variants = {
    -            "nothing": ({"x": x, "x2": ("x", x2)}, {"y": y, "y2": ("y", y2)}),
    -            "dims": (
    -                {"x": x, "x2": ("x", strip_units(x2))},
    -                {"x": y, "y2": ("x", strip_units(y2))},
    +            "nothing": ({}, {}),
    +            "dims": ({"x": dim_a}, {"x": dim_b}),
    +            "coords": (
    +                {"x": values_a, "y": ("x", coord_a)},
    +                {"x": values_b, "y": ("x", coord_b)},
                 ),
    -            "coords": ({"x": raw_x, "y": ("x", x2)}, {"x": raw_y, "y": ("x", y2)}),
             }
             coords_a, coords_b = variants.get(shared)
     
             dims_a, dims_b = ("x", "y") if shared == "nothing" else ("x", "x")
     
    -        arr1 = xr.DataArray(data=a, coords=coords_a, dims=dims_a)
    -        arr2 = xr.DataArray(data=b, coords=coords_b, dims=dims_b)
    +        a = xr.DataArray(data=a, coords=coords_a, dims=dims_a)
    +        b = xr.DataArray(data=b, coords=coords_b, dims=dims_b)
    +
             if error is not None and shared != "nothing":
                 with pytest.raises(error):
    -                xr.Dataset(data_vars={"a": arr1, "b": arr2})
    +                xr.Dataset(data_vars={"a": a, "b": b})
     
                 return
     
    -        actual = xr.Dataset(data_vars={"a": arr1, "b": arr2})
    +        actual = xr.Dataset(data_vars={"a": a, "b": b})
     
    -        expected_units = {
    -            "a": a.units,
    -            "b": b.units,
    -            "x": x.units,
    -            "x2": x2.units,
    -            "y": y_units,
    -            "y2": y2_units,
    -        }
    +        units = merge_mappings(
    +            extract_units(a.rename("a")), extract_units(b.rename("b"))
    +        )
             expected = attach_units(
    -            xr.Dataset(data_vars={"a": strip_units(arr1), "b": strip_units(arr2)}),
    -            expected_units,
    +            xr.Dataset(data_vars={"a": strip_units(a), "b": strip_units(b)}), units
             )
    -        assert_equal_with_units(actual, expected)
    +
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
     
         @pytest.mark.parametrize(
             "func", (pytest.param(str, id="str"), pytest.param(repr, id="repr"))
    @@ -3691,148 +3935,141 @@ def test_init(self, shared, unit, error, dtype):
         @pytest.mark.parametrize(
             "variant",
             (
    +            "data",
                 pytest.param(
    -                "with_dims",
    -                marks=pytest.mark.xfail(reason="units in indexes are not supported"),
    +                "dims",
    +                marks=pytest.mark.skip(reason="indexes don't support units"),
                 ),
    -            pytest.param("with_coords"),
    -            pytest.param("without_coords"),
    +            "coords",
             ),
         )
    -    @pytest.mark.filterwarnings("error:::pint[.*]")
         def test_repr(self, func, variant, dtype):
    -        array1 = np.linspace(1, 2, 10, dtype=dtype) * unit_registry.Pa
    -        array2 = np.linspace(0, 1, 10, dtype=dtype) * unit_registry.degK
    +        unit1, unit2 = (
    +            (unit_registry.Pa, unit_registry.degK) if variant == "data" else (1, 1)
    +        )
    +
    +        array1 = np.linspace(1, 2, 10, dtype=dtype) * unit1
    +        array2 = np.linspace(0, 1, 10, dtype=dtype) * unit2
     
             x = np.arange(len(array1)) * unit_registry.s
             y = x.to(unit_registry.ms)
     
             variants = {
    -            "with_dims": {"x": x},
    -            "with_coords": {"y": ("x", y)},
    -            "without_coords": {},
    +            "dims": {"x": x},
    +            "coords": {"y": ("x", y)},
    +            "data": {},
             }
     
    -        data_array = xr.Dataset(
    +        ds = xr.Dataset(
                 data_vars={"a": ("x", array1), "b": ("x", array2)},
                 coords=variants.get(variant),
             )
     
             # FIXME: this just checks that the repr does not raise
             # warnings or errors, but does not check the result
    -        func(data_array)
    +        func(ds)
     
         @pytest.mark.parametrize(
             "func",
             (
    +            function("all"),
    +            function("any"),
                 pytest.param(
    -                function("all"),
    -                marks=pytest.mark.xfail(reason="not implemented by pint"),
    +                function("argmax"),
    +                marks=pytest.mark.skip(
    +                    reason="calling np.argmax as a function on xarray objects is not "
    +                    "supported"
    +                ),
                 ),
                 pytest.param(
    -                function("any"),
    -                marks=pytest.mark.xfail(reason="not implemented by pint"),
    +                function("argmin"),
    +                marks=pytest.mark.skip(
    +                    reason="calling np.argmin as a function on xarray objects is not "
    +                    "supported"
    +                ),
                 ),
    -            function("argmax"),
    -            function("argmin"),
                 function("max"),
                 function("min"),
                 function("mean"),
                 pytest.param(
                     function("median"),
    -                marks=pytest.mark.xfail(
    -                    reason="np.median does not work with dataset yet"
    -                ),
    +                marks=pytest.mark.xfail(reason="median does not work with dataset yet"),
                 ),
                 function("sum"),
    -            pytest.param(
    -                function("prod"),
    -                marks=pytest.mark.xfail(reason="not implemented by pint"),
    -            ),
    +            function("prod"),
                 function("std"),
                 function("var"),
                 function("cumsum"),
    -            pytest.param(
    -                function("cumprod"),
    -                marks=pytest.mark.xfail(reason="fails within xarray"),
    -            ),
    -            pytest.param(
    -                method("all"), marks=pytest.mark.xfail(reason="not implemented by pint")
    -            ),
    -            pytest.param(
    -                method("any"), marks=pytest.mark.xfail(reason="not implemented by pint")
    -            ),
    -            method("argmax"),
    -            method("argmin"),
    +            function("cumprod"),
    +            method("all"),
    +            method("any"),
    +            method("argmax", dim="x"),
    +            method("argmin", dim="x"),
                 method("max"),
                 method("min"),
                 method("mean"),
                 method("median"),
                 method("sum"),
    -            pytest.param(
    -                method("prod"),
    -                marks=pytest.mark.xfail(reason="not implemented by pint"),
    -            ),
    +            method("prod"),
                 method("std"),
                 method("var"),
                 method("cumsum"),
    -            pytest.param(
    -                method("cumprod"), marks=pytest.mark.xfail(reason="fails within xarray")
    -            ),
    +            method("cumprod"),
             ),
             ids=repr,
         )
         def test_aggregation(self, func, dtype):
    -        unit_a = (
    -            unit_registry.Pa if func.name != "cumprod" else unit_registry.dimensionless
    -        )
    -        unit_b = (
    -            unit_registry.kg / unit_registry.m ** 3
    +        if func.name == "prod" and dtype.kind == "f":
    +            pytest.xfail(reason="nanprod is not supported yet")
    +
    +        unit_a, unit_b = (
    +            (unit_registry.Pa, unit_registry.degK)
                 if func.name != "cumprod"
    -            else unit_registry.dimensionless
    -        )
    -        a = xr.DataArray(data=np.linspace(0, 1, 10).astype(dtype) * unit_a, dims="x")
    -        b = xr.DataArray(data=np.linspace(-1, 0, 10).astype(dtype) * unit_b, dims="x")
    -        x = xr.DataArray(data=np.arange(10).astype(dtype) * unit_registry.m, dims="x")
    -        y = xr.DataArray(
    -            data=np.arange(10, 20).astype(dtype) * unit_registry.s, dims="x"
    +            else (unit_registry.dimensionless, unit_registry.dimensionless)
             )
     
    -        ds = xr.Dataset(data_vars={"a": a, "b": b}, coords={"x": x, "y": y})
    +        a = np.linspace(0, 1, 10).astype(dtype) * unit_a
    +        b = np.linspace(-1, 0, 10).astype(dtype) * unit_b
    +
    +        ds = xr.Dataset({"a": ("x", a), "b": ("x", b)})
    +
    +        if "dim" in func.kwargs:
    +            numpy_kwargs = func.kwargs.copy()
    +            dim = numpy_kwargs.pop("dim")
    +
    +            axis_a = ds.a.get_axis_num(dim)
    +            axis_b = ds.b.get_axis_num(dim)
    +
    +            numpy_kwargs_a = numpy_kwargs.copy()
    +            numpy_kwargs_a["axis"] = axis_a
    +            numpy_kwargs_b = numpy_kwargs.copy()
    +            numpy_kwargs_b["axis"] = axis_b
    +        else:
    +            numpy_kwargs_a = {}
    +            numpy_kwargs_b = {}
    +
    +        units_a = array_extract_units(func(a, **numpy_kwargs_a))
    +        units_b = array_extract_units(func(b, **numpy_kwargs_b))
    +        units = {"a": units_a, "b": units_b}
     
             actual = func(ds)
    -        expected = attach_units(
    -            func(strip_units(ds)),
    -            {
    -                "a": extract_units(func(a)).get(None),
    -                "b": extract_units(func(b)).get(None),
    -            },
    -        )
    +        expected = attach_units(func(strip_units(ds)), units)
     
    -        assert_equal_with_units(actual, expected)
    +        assert_units_equal(expected, actual)
    +        assert_allclose(expected, actual)
     
         @pytest.mark.parametrize("property", ("imag", "real"))
         def test_numpy_properties(self, property, dtype):
    -        ds = xr.Dataset(
    -            data_vars={
    -                "a": xr.DataArray(
    -                    data=np.linspace(0, 1, 10) * unit_registry.Pa, dims="x"
    -                ),
    -                "b": xr.DataArray(
    -                    data=np.linspace(-1, 0, 15) * unit_registry.Pa, dims="y"
    -                ),
    -            },
    -            coords={
    -                "x": np.arange(10) * unit_registry.m,
    -                "y": np.arange(15) * unit_registry.s,
    -            },
    -        )
    +        a = np.linspace(0, 1, 10) * unit_registry.Pa
    +        b = np.linspace(-1, 0, 15) * unit_registry.degK
    +        ds = xr.Dataset({"a": ("x", a), "b": ("y", b)})
             units = extract_units(ds)
     
             actual = getattr(ds, property)
             expected = attach_units(getattr(strip_units(ds), property), units)
     
    -        assert_equal_with_units(actual, expected)
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
     
         @pytest.mark.parametrize(
             "func",
    @@ -3846,31 +4083,19 @@ def test_numpy_properties(self, property, dtype):
             ids=repr,
         )
         def test_numpy_methods(self, func, dtype):
    -        ds = xr.Dataset(
    -            data_vars={
    -                "a": xr.DataArray(
    -                    data=np.linspace(1, -1, 10) * unit_registry.Pa, dims="x"
    -                ),
    -                "b": xr.DataArray(
    -                    data=np.linspace(-1, 1, 15) * unit_registry.Pa, dims="y"
    -                ),
    -            },
    -            coords={
    -                "x": np.arange(10) * unit_registry.m,
    -                "y": np.arange(15) * unit_registry.s,
    -            },
    -        )
    -        units = {
    -            "a": array_extract_units(func(ds.a)),
    -            "b": array_extract_units(func(ds.b)),
    -            "x": unit_registry.m,
    -            "y": unit_registry.s,
    -        }
    +        a = np.linspace(1, -1, 10) * unit_registry.Pa
    +        b = np.linspace(-1, 1, 15) * unit_registry.degK
    +        ds = xr.Dataset({"a": ("x", a), "b": ("y", b)})
    +
    +        units_a = array_extract_units(func(a))
    +        units_b = array_extract_units(func(b))
    +        units = {"a": units_a, "b": units_b}
     
             actual = func(ds)
             expected = attach_units(func(strip_units(ds)), units)
     
    -        assert_equal_with_units(actual, expected)
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
     
         @pytest.mark.parametrize("func", (method("clip", min=3, max=8),), ids=repr)
         @pytest.mark.parametrize(
    @@ -3887,21 +4112,13 @@ def test_numpy_methods(self, func, dtype):
         )
         def test_numpy_methods_with_args(self, func, unit, error, dtype):
             data_unit = unit_registry.m
    -        ds = xr.Dataset(
    -            data_vars={
    -                "a": xr.DataArray(data=np.arange(10) * data_unit, dims="x"),
    -                "b": xr.DataArray(data=np.arange(15) * data_unit, dims="y"),
    -            },
    -            coords={
    -                "x": np.arange(10) * unit_registry.m,
    -                "y": np.arange(15) * unit_registry.s,
    -            },
    -        )
    +        a = np.linspace(0, 10, 15) * unit_registry.m
    +        b = np.linspace(-2, 12, 20) * unit_registry.m
    +        ds = xr.Dataset({"a": ("x", a), "b": ("y", b)})
             units = extract_units(ds)
     
             kwargs = {
    -            key: (value * unit if isinstance(value, (int, float)) else value)
    -            for key, value in func.kwargs.items()
    +            key: array_attach_units(value, unit) for key, value in func.kwargs.items()
             }
     
             if error is not None:
    @@ -3918,7 +4135,8 @@ def test_numpy_methods_with_args(self, func, unit, error, dtype):
             actual = func(ds, **kwargs)
             expected = attach_units(func(strip_units(ds), **stripped_kwargs), units)
     
    -        assert_equal_with_units(actual, expected)
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
     
         @pytest.mark.parametrize(
             "func", (method("isnull"), method("notnull"), method("count")), ids=repr
    @@ -3948,22 +4166,13 @@ def test_missing_value_detection(self, func, dtype):
                 * unit_registry.Pa
             )
     
    -        x = np.arange(array1.shape[0]) * unit_registry.m
    -        y = np.arange(array1.shape[1]) * unit_registry.m
    -        z = np.arange(array2.shape[0]) * unit_registry.m
    -
    -        ds = xr.Dataset(
    -            data_vars={
    -                "a": xr.DataArray(data=array1, dims=("x", "y")),
    -                "b": xr.DataArray(data=array2, dims=("z", "x")),
    -            },
    -            coords={"x": x, "y": y, "z": z},
    -        )
    +        ds = xr.Dataset({"a": (("x", "y"), array1), "b": (("z", "x"), array2)})
     
             expected = func(strip_units(ds))
             actual = func(ds)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
     
         @pytest.mark.xfail(reason="ffill and bfill lose the unit")
         @pytest.mark.parametrize("func", (method("ffill"), method("bfill")), ids=repr)
    @@ -3977,23 +4186,14 @@ def test_missing_value_filling(self, func, dtype):
                 * unit_registry.Pa
             )
     
    -        x = np.arange(len(array1))
    -
    -        ds = xr.Dataset(
    -            data_vars={
    -                "a": xr.DataArray(data=array1, dims="x"),
    -                "b": xr.DataArray(data=array2, dims="x"),
    -            },
    -            coords={"x": x},
    -        )
    +        ds = xr.Dataset({"a": ("x", array1), "b": ("y", array2)})
    +        units = extract_units(ds)
     
    -        expected = attach_units(
    -            func(strip_units(ds), dim="x"),
    -            {"a": unit_registry.degK, "b": unit_registry.Pa},
    -        )
    +        expected = attach_units(func(strip_units(ds), dim="x"), units)
             actual = func(ds, dim="x")
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
     
         @pytest.mark.parametrize(
             "unit,error",
    @@ -4007,9 +4207,6 @@ def test_missing_value_filling(self, func, dtype):
                     unit_registry.cm,
                     None,
                     id="compatible_unit",
    -                marks=pytest.mark.xfail(
    -                    reason="where converts the array, not the fill value"
    -                ),
                 ),
                 pytest.param(unit_registry.m, None, id="identical_unit"),
             ),
    @@ -4031,30 +4228,26 @@ def test_fillna(self, fill_value, unit, error, dtype):
                 np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype)
                 * unit_registry.m
             )
    -        ds = xr.Dataset(
    -            data_vars={
    -                "a": xr.DataArray(data=array1, dims="x"),
    -                "b": xr.DataArray(data=array2, dims="x"),
    -            }
    -        )
    +        ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)})
    +        value = fill_value * unit
    +        units = extract_units(ds)
     
             if error is not None:
                 with pytest.raises(error):
    -                ds.fillna(value=fill_value * unit)
    +                ds.fillna(value=value)
     
                 return
     
    -        actual = ds.fillna(value=fill_value * unit)
    +        actual = ds.fillna(value=value)
             expected = attach_units(
                 strip_units(ds).fillna(
    -                value=strip_units(
    -                    convert_units(fill_value * unit, {None: unit_registry.m})
    -                )
    +                value=strip_units(convert_units(value, {None: unit_registry.m}))
                 ),
    -            {"a": unit_registry.m, "b": unit_registry.m},
    +            units,
             )
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
     
         def test_dropna(self, dtype):
             array1 = (
    @@ -4065,22 +4258,14 @@ def test_dropna(self, dtype):
                 np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype)
                 * unit_registry.Pa
             )
    -        x = np.arange(len(array1))
    -        ds = xr.Dataset(
    -            data_vars={
    -                "a": xr.DataArray(data=array1, dims="x"),
    -                "b": xr.DataArray(data=array2, dims="x"),
    -            },
    -            coords={"x": x},
    -        )
    +        ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)})
    +        units = extract_units(ds)
     
    -        expected = attach_units(
    -            strip_units(ds).dropna(dim="x"),
    -            {"a": unit_registry.degK, "b": unit_registry.Pa},
    -        )
    +        expected = attach_units(strip_units(ds).dropna(dim="x"), units)
             actual = ds.dropna(dim="x")
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
     
         @pytest.mark.parametrize(
             "unit",
    @@ -4101,34 +4286,28 @@ def test_isin(self, unit, dtype):
                 np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype)
                 * unit_registry.m
             )
    -        x = np.arange(len(array1))
    -        ds = xr.Dataset(
    -            data_vars={
    -                "a": xr.DataArray(data=array1, dims="x"),
    -                "b": xr.DataArray(data=array2, dims="x"),
    -            },
    -            coords={"x": x},
    -        )
    +        ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)})
     
             raw_values = np.array([1.4, np.nan, 2.3]).astype(dtype)
             values = raw_values * unit
     
    -        if (
    -            isinstance(values, unit_registry.Quantity)
    -            and values.check(unit_registry.m)
    -            and unit != unit_registry.m
    -        ):
    -            raw_values = values.to(unit_registry.m).magnitude
    +        converted_values = (
    +            convert_units(values, {None: unit_registry.m})
    +            if is_compatible(unit, unit_registry.m)
    +            else values
    +        )
     
    -        expected = strip_units(ds).isin(raw_values)
    -        if not isinstance(values, unit_registry.Quantity) or not values.check(
    -            unit_registry.m
    -        ):
    +        expected = strip_units(ds).isin(strip_units(converted_values))
    +        # TODO: use `unit_registry.is_compatible_with(unit, unit_registry.m)` instead.
    +        # Needs `pint>=0.12.1`, though, so we probably should wait until that is released.
    +        if not is_compatible(unit, unit_registry.m):
                 expected.a[:] = False
                 expected.b[:] = False
    +
             actual = ds.isin(values)
     
    -        assert_equal_with_units(actual, expected)
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
     
         @pytest.mark.parametrize(
             "variant", ("masking", "replacing_scalar", "replacing_array", "dropping")
    @@ -4150,13 +4329,8 @@ def test_where(self, variant, unit, error, dtype):
             array1 = np.linspace(0, 1, 10).astype(dtype) * original_unit
             array2 = np.linspace(-1, 0, 10).astype(dtype) * original_unit
     
    -        ds = xr.Dataset(
    -            data_vars={
    -                "a": xr.DataArray(data=array1, dims="x"),
    -                "b": xr.DataArray(data=array2, dims="x"),
    -            },
    -            coords={"x": np.arange(len(array1))},
    -        )
    +        ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)})
    +        units = extract_units(ds)
     
             condition = ds < 0.5 * original_unit
             other = np.linspace(-2, -1, 10).astype(dtype) * unit
    @@ -4180,13 +4354,14 @@ def test_where(self, variant, unit, error, dtype):
     
             expected = attach_units(
                 strip_units(ds).where(**kwargs_without_units),
    -            {"a": original_unit, "b": original_unit},
    +            units,
             )
             actual = ds.where(**kwargs)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
     
    -    @pytest.mark.xfail(reason="interpolate strips units")
    +    @pytest.mark.xfail(reason="interpolate_na uses numpy.vectorize")
         def test_interpolate_na(self, dtype):
             array1 = (
                 np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype)
    @@ -4196,24 +4371,18 @@ def test_interpolate_na(self, dtype):
                 np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype)
                 * unit_registry.Pa
             )
    -        x = np.arange(len(array1))
    -        ds = xr.Dataset(
    -            data_vars={
    -                "a": xr.DataArray(data=array1, dims="x"),
    -                "b": xr.DataArray(data=array2, dims="x"),
    -            },
    -            coords={"x": x},
    -        )
    +        ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)})
    +        units = extract_units(ds)
     
             expected = attach_units(
                 strip_units(ds).interpolate_na(dim="x"),
    -            {"a": unit_registry.degK, "b": unit_registry.Pa},
    +            units,
             )
             actual = ds.interpolate_na(dim="x")
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
     
    -    @pytest.mark.xfail(reason="wrong argument order for `where`")
         @pytest.mark.parametrize(
             "unit,error",
             (
    @@ -4226,31 +4395,42 @@ def test_interpolate_na(self, dtype):
                 pytest.param(unit_registry.m, None, id="same_unit"),
             ),
         )
    -    def test_combine_first(self, unit, error, dtype):
    +    @pytest.mark.parametrize(
    +        "variant",
    +        (
    +            "data",
    +            pytest.param(
    +                "dims",
    +                marks=pytest.mark.skip(reason="indexes don't support units"),
    +            ),
    +        ),
    +    )
    +    def test_combine_first(self, variant, unit, error, dtype):
    +        variants = {
    +            "data": (unit_registry.m, unit, 1, 1),
    +            "dims": (1, 1, unit_registry.m, unit),
    +        }
    +        data_unit, other_data_unit, dims_unit, other_dims_unit = variants.get(variant)
    +
             array1 = (
    -            np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype)
    -            * unit_registry.m
    +            np.array([1.4, np.nan, 2.3, np.nan, np.nan, 9.1]).astype(dtype) * data_unit
             )
             array2 = (
    -            np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype)
    -            * unit_registry.m
    +            np.array([4.3, 9.8, 7.5, np.nan, 8.2, np.nan]).astype(dtype) * data_unit
             )
    -        x = np.arange(len(array1))
    +        x = np.arange(len(array1)) * dims_unit
             ds = xr.Dataset(
    -            data_vars={
    -                "a": xr.DataArray(data=array1, dims="x"),
    -                "b": xr.DataArray(data=array2, dims="x"),
    -            },
    +            data_vars={"a": ("x", array1), "b": ("x", array2)},
                 coords={"x": x},
             )
    -        other_array1 = np.ones_like(array1) * unit
    -        other_array2 = -1 * np.ones_like(array2) * unit
    +        units = extract_units(ds)
    +
    +        other_array1 = np.ones_like(array1) * other_data_unit
    +        other_array2 = np.full_like(array2, fill_value=-1) * other_data_unit
    +        other_x = (np.arange(array1.shape[0]) + 5) * other_dims_unit
             other = xr.Dataset(
    -            data_vars={
    -                "a": xr.DataArray(data=other_array1, dims="x"),
    -                "b": xr.DataArray(data=other_array2, dims="x"),
    -            },
    -            coords={"x": np.arange(array1.shape[0])},
    +            data_vars={"a": ("x", other_array1), "b": ("x", other_array2)},
    +            coords={"x": other_x},
             )
     
             if error is not None:
    @@ -4260,16 +4440,13 @@ def test_combine_first(self, unit, error, dtype):
                 return
     
             expected = attach_units(
    -            strip_units(ds).combine_first(
    -                strip_units(
    -                    convert_units(other, {"a": unit_registry.m, "b": unit_registry.m})
    -                )
    -            ),
    -            {"a": unit_registry.m, "b": unit_registry.m},
    +            strip_units(ds).combine_first(strip_units(convert_units(other, units))),
    +            units,
             )
             actual = ds.combine_first(other)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
     
         @pytest.mark.parametrize(
             "unit",
    @@ -4282,59 +4459,77 @@ def test_combine_first(self, unit, error, dtype):
             ),
         )
         @pytest.mark.parametrize(
    -        "variation",
    +        "variant",
             (
                 "data",
                 pytest.param(
    -                "dims", marks=pytest.mark.xfail(reason="units in indexes not supported")
    +                "dims", marks=pytest.mark.skip(reason="indexes don't support units")
                 ),
                 "coords",
             ),
         )
    -    @pytest.mark.parametrize("func", (method("equals"), method("identical")), ids=repr)
    -    def test_comparisons(self, func, variation, unit, dtype):
    -        def is_compatible(a, b):
    -            a = a if a is not None else 1
    -            b = b if b is not None else 1
    -            quantity = np.arange(5) * a
    -
    -            return a == b or quantity.check(b)
    -
    +    @pytest.mark.parametrize(
    +        "func",
    +        (
    +            method("equals"),
    +            pytest.param(
    +                method("identical"),
    +                marks=pytest.mark.skip("behaviour of identical is unclear"),
    +            ),
    +        ),
    +        ids=repr,
    +    )
    +    def test_comparisons(self, func, variant, unit, dtype):
             array1 = np.linspace(0, 5, 10).astype(dtype)
             array2 = np.linspace(-5, 0, 10).astype(dtype)
     
             coord = np.arange(len(array1)).astype(dtype)
     
    -        original_unit = unit_registry.m
    -        quantity1 = array1 * original_unit
    -        quantity2 = array2 * original_unit
    -        x = coord * original_unit
    -        y = coord * original_unit
    +        variants = {
    +            "data": (unit_registry.m, 1, 1),
    +            "dims": (1, unit_registry.m, 1),
    +            "coords": (1, 1, unit_registry.m),
    +        }
    +        data_unit, dim_unit, coord_unit = variants.get(variant)
     
    -        units = {"data": (unit, 1, 1), "dims": (1, unit, 1), "coords": (1, 1, unit)}
    -        data_unit, dim_unit, coord_unit = units.get(variation)
    +        a = array1 * data_unit
    +        b = array2 * data_unit
    +        x = coord * dim_unit
    +        y = coord * coord_unit
     
             ds = xr.Dataset(
    -            data_vars={
    -                "a": xr.DataArray(data=quantity1, dims="x"),
    -                "b": xr.DataArray(data=quantity2, dims="x"),
    -            },
    +            data_vars={"a": ("x", a), "b": ("x", b)},
                 coords={"x": x, "y": ("x", y)},
             )
    +        units = extract_units(ds)
    +
    +        other_variants = {
    +            "data": (unit, 1, 1),
    +            "dims": (1, unit, 1),
    +            "coords": (1, 1, unit),
    +        }
    +        other_data_unit, other_dim_unit, other_coord_unit = other_variants.get(variant)
     
             other_units = {
    -            "a": data_unit if quantity1.check(data_unit) else None,
    -            "b": data_unit if quantity2.check(data_unit) else None,
    -            "x": dim_unit if x.check(dim_unit) else None,
    -            "y": coord_unit if y.check(coord_unit) else None,
    +            "a": other_data_unit,
    +            "b": other_data_unit,
    +            "x": other_dim_unit,
    +            "y": other_coord_unit,
             }
    -        other = attach_units(strip_units(convert_units(ds, other_units)), other_units)
     
    -        units = extract_units(ds)
    +        to_convert = {
    +            key: unit if is_compatible(unit, reference) else None
    +            for key, (unit, reference) in zip_mappings(units, other_units)
    +        }
    +        # convert units where possible, then attach all units to the converted dataset
    +        other = attach_units(strip_units(convert_units(ds, to_convert)), other_units)
             other_units = extract_units(other)
     
    +        # make sure all units are compatible and only then try to
    +        # convert and compare values
             equal_ds = all(
    -            is_compatible(units[name], other_units[name]) for name in units.keys()
    +            is_compatible(unit, other_unit)
    +            for _, (unit, other_unit) in zip_mappings(units, other_units)
             ) and (strip_units(ds).equals(strip_units(convert_units(other, units))))
             equal_units = units == other_units
             expected = equal_ds and (func.name != "identical" or equal_units)
    @@ -4343,6 +4538,9 @@ def is_compatible(a, b):
     
             assert expected == actual
     
    +    # TODO: eventually use another decorator / wrapper function that
    +    # applies a filter to the parametrize combinations:
    +    # we only need a single test for data
         @pytest.mark.parametrize(
             "unit",
             (
    @@ -4353,14 +4551,30 @@ def is_compatible(a, b):
                 pytest.param(unit_registry.m, id="identical_unit"),
             ),
         )
    -    def test_broadcast_like(self, unit, dtype):
    -        array1 = np.linspace(1, 2, 2 * 1).reshape(2, 1).astype(dtype) * unit_registry.Pa
    -        array2 = np.linspace(0, 1, 2 * 3).reshape(2, 3).astype(dtype) * unit_registry.Pa
    +    @pytest.mark.parametrize(
    +        "variant",
    +        (
    +            "data",
    +            pytest.param(
    +                "dims",
    +                marks=pytest.mark.skip(reason="indexes don't support units"),
    +            ),
    +        ),
    +    )
    +    def test_broadcast_like(self, variant, unit, dtype):
    +        variants = {
    +            "data": ((unit_registry.m, unit), (1, 1)),
    +            "dims": ((1, 1), (unit_registry.m, unit)),
    +        }
    +        (data_unit1, data_unit2), (dim_unit1, dim_unit2) = variants.get(variant)
     
    -        x1 = np.arange(2) * unit_registry.m
    -        x2 = np.arange(2) * unit
    -        y1 = np.array([0]) * unit_registry.m
    -        y2 = np.arange(3) * unit
    +        array1 = np.linspace(1, 2, 2 * 1).reshape(2, 1).astype(dtype) * data_unit1
    +        array2 = np.linspace(0, 1, 2 * 3).reshape(2, 3).astype(dtype) * data_unit2
    +
    +        x1 = np.arange(2) * dim_unit1
    +        x2 = np.arange(2) * dim_unit2
    +        y1 = np.array([0]) * dim_unit1
    +        y2 = np.arange(3) * dim_unit2
     
             ds1 = xr.Dataset(
                 data_vars={"a": (("x", "y"), array1)}, coords={"x": x1, "y": y1}
    @@ -4374,7 +4588,8 @@ def test_broadcast_like(self, unit, dtype):
             )
             actual = ds1.broadcast_like(ds2)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
     
         @pytest.mark.parametrize(
             "unit",
    @@ -4387,102 +4602,122 @@ def test_broadcast_like(self, unit, dtype):
             ),
         )
         def test_broadcast_equals(self, unit, dtype):
    +        # TODO: does this use indexes?
             left_array1 = np.ones(shape=(2, 3), dtype=dtype) * unit_registry.m
             left_array2 = np.zeros(shape=(3, 6), dtype=dtype) * unit_registry.m
     
             right_array1 = np.ones(shape=(2,)) * unit
    -        right_array2 = np.ones(shape=(3,)) * unit
    +        right_array2 = np.zeros(shape=(3,)) * unit
     
             left = xr.Dataset(
    -            data_vars={
    -                "a": xr.DataArray(data=left_array1, dims=("x", "y")),
    -                "b": xr.DataArray(data=left_array2, dims=("y", "z")),
    -            }
    -        )
    -        right = xr.Dataset(
    -            data_vars={
    -                "a": xr.DataArray(data=right_array1, dims="x"),
    -                "b": xr.DataArray(data=right_array2, dims="y"),
    -            }
    +            {"a": (("x", "y"), left_array1), "b": (("y", "z"), left_array2)},
             )
    +        right = xr.Dataset({"a": ("x", right_array1), "b": ("y", right_array2)})
     
    -        units = {
    -            **extract_units(left),
    -            **({} if left_array1.check(unit) else {"a": None, "b": None}),
    -        }
    -        expected = strip_units(left).broadcast_equals(
    -            strip_units(convert_units(right, units))
    -        ) & left_array1.check(unit)
    +        units = merge_mappings(
    +            extract_units(left),
    +            {} if is_compatible(left_array1, unit) else {"a": None, "b": None},
    +        )
    +        expected = is_compatible(left_array1, unit) and strip_units(
    +            left
    +        ).broadcast_equals(strip_units(convert_units(right, units)))
             actual = left.broadcast_equals(right)
     
             assert expected == actual
     
    +    def test_pad(self, dtype):
    +        a = np.linspace(0, 5, 10).astype(dtype) * unit_registry.Pa
    +        b = np.linspace(-5, 0, 10).astype(dtype) * unit_registry.degK
    +
    +        ds = xr.Dataset({"a": ("x", a), "b": ("x", b)})
    +        units = extract_units(ds)
    +
    +        expected = attach_units(strip_units(ds).pad(x=(2, 3)), units)
    +        actual = ds.pad(x=(2, 3))
    +
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
    +
         @pytest.mark.parametrize(
             "func",
             (method("unstack"), method("reset_index", "v"), method("reorder_levels")),
             ids=repr,
         )
    -    def test_stacking_stacked(self, func, dtype):
    -        array1 = (
    -            np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * unit_registry.m
    -        )
    +    @pytest.mark.parametrize(
    +        "variant",
    +        (
    +            "data",
    +            pytest.param(
    +                "dims",
    +                marks=pytest.mark.skip(reason="indexes don't support units"),
    +            ),
    +        ),
    +    )
    +    def test_stacking_stacked(self, variant, func, dtype):
    +        variants = {
    +            "data": (unit_registry.m, 1),
    +            "dims": (1, unit_registry.m),
    +        }
    +        data_unit, dim_unit = variants.get(variant)
    +
    +        array1 = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit
             array2 = (
                 np.linspace(-10, 0, 5 * 10 * 15).reshape(5, 10, 15).astype(dtype)
    -            * unit_registry.m
    +            * data_unit
             )
     
    -        x = np.arange(array1.shape[0])
    -        y = np.arange(array1.shape[1])
    -        z = np.arange(array2.shape[2])
    +        x = np.arange(array1.shape[0]) * dim_unit
    +        y = np.arange(array1.shape[1]) * dim_unit
    +        z = np.arange(array2.shape[2]) * dim_unit
     
             ds = xr.Dataset(
    -            data_vars={
    -                "a": xr.DataArray(data=array1, dims=("x", "y")),
    -                "b": xr.DataArray(data=array2, dims=("x", "y", "z")),
    -            },
    +            data_vars={"a": (("x", "y"), array1), "b": (("x", "y", "z"), array2)},
                 coords={"x": x, "y": y, "z": z},
             )
    +        units = extract_units(ds)
     
             stacked = ds.stack(v=("x", "y"))
     
    -        expected = attach_units(
    -            func(strip_units(stacked)), {"a": unit_registry.m, "b": unit_registry.m}
    -        )
    +        expected = attach_units(func(strip_units(stacked)), units)
             actual = func(stacked)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
     
    -    @pytest.mark.xfail(reason="does not work with quantities yet")
    +    @pytest.mark.xfail(
    +        reason="stacked dimension's labels have to be hashable, but is a numpy.array"
    +    )
         def test_to_stacked_array(self, dtype):
    -        labels = np.arange(5).astype(dtype) * unit_registry.s
    -        arrays = {name: np.linspace(0, 1, 10) * unit_registry.m for name in labels}
    +        labels = range(5) * unit_registry.s
    +        arrays = {
    +            name: np.linspace(0, 1, 10).astype(dtype) * unit_registry.m
    +            for name in labels
    +        }
     
    -        ds = xr.Dataset(
    -            data_vars={
    -                name: xr.DataArray(data=array, dims="x")
    -                for name, array in arrays.items()
    -            }
    -        )
    +        ds = xr.Dataset({name: ("x", array) for name, array in arrays.items()})
    +        units = {None: unit_registry.m, "y": unit_registry.s}
     
             func = method("to_stacked_array", "z", variable_dim="y", sample_dims=["x"])
     
             actual = func(ds).rename(None)
             expected = attach_units(
                 func(strip_units(ds)).rename(None),
    -            {None: unit_registry.m, "y": unit_registry.s},
    +            units,
             )
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
     
         @pytest.mark.parametrize(
             "func",
             (
                 method("transpose", "y", "x", "z1", "z2"),
    -            method("stack", a=("x", "y")),
    +            method("stack", u=("x", "y")),
                 method("set_index", x="x2"),
    +            method("shift", x=2),
                 pytest.param(
    -                method("shift", x=2),
    -                marks=pytest.mark.xfail(reason="tries to concatenate nan arrays"),
    +                method("rank", dim="x"),
    +                marks=pytest.mark.skip(reason="rank not implemented for non-ndarray"),
                 ),
                 method("roll", x=2, roll_coords=False),
                 method("sortby", "x2"),
    @@ -4508,20 +4743,19 @@ def test_stacking_reordering(self, func, dtype):
     
             ds = xr.Dataset(
                 data_vars={
    -                "a": xr.DataArray(data=array1, dims=("x", "y", "z1")),
    -                "b": xr.DataArray(data=array2, dims=("x", "y", "z2")),
    +                "a": (("x", "y", "z1"), array1),
    +                "b": (("x", "y", "z2"), array2),
                 },
                 coords={"x": x, "y": y, "z1": z1, "z2": z2, "x2": ("x", x2)},
             )
    +        units = extract_units(ds)
     
    -        expected = attach_units(
    -            func(strip_units(ds)), {"a": unit_registry.Pa, "b": unit_registry.degK}
    -        )
    +        expected = attach_units(func(strip_units(ds)), units)
             actual = func(ds)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
     
    -    @pytest.mark.xfail(reason="indexes strip units")
         @pytest.mark.parametrize(
             "indices",
             (
    @@ -4533,24 +4767,16 @@ def test_isel(self, indices, dtype):
             array1 = np.arange(10).astype(dtype) * unit_registry.s
             array2 = np.linspace(0, 1, 10).astype(dtype) * unit_registry.Pa
     
    -        x = np.arange(len(array1)) * unit_registry.m
    -        ds = xr.Dataset(
    -            data_vars={
    -                "a": xr.DataArray(data=array1, dims="x"),
    -                "b": xr.DataArray(data=array2, dims="x"),
    -            },
    -            coords={"x": x},
    -        )
    +        ds = xr.Dataset(data_vars={"a": ("x", array1), "b": ("x", array2)})
    +        units = extract_units(ds)
     
    -        expected = attach_units(
    -            strip_units(ds).isel(x=indices),
    -            {"a": unit_registry.s, "b": unit_registry.Pa, "x": unit_registry.m},
    -        )
    +        expected = attach_units(strip_units(ds).isel(x=indices), units)
             actual = ds.isel(x=indices)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
     
    -    @pytest.mark.xfail(reason="indexes don't support units")
    +    @pytest.mark.skip(reason="indexes don't support units")
         @pytest.mark.parametrize(
             "raw_values",
             (
    @@ -4565,7 +4791,7 @@ def test_isel(self, indices, dtype):
                 pytest.param(1, KeyError, id="no_units"),
                 pytest.param(unit_registry.dimensionless, KeyError, id="dimensionless"),
                 pytest.param(unit_registry.degree, KeyError, id="incompatible_unit"),
    -            pytest.param(unit_registry.dm, KeyError, id="compatible_unit"),
    +            pytest.param(unit_registry.mm, KeyError, id="compatible_unit"),
                 pytest.param(unit_registry.m, None, id="identical_unit"),
             ),
         )
    @@ -4584,22 +4810,26 @@ def test_sel(self, raw_values, unit, error, dtype):
     
             values = raw_values * unit
     
    -        if error is not None and not (
    -            isinstance(raw_values, (int, float)) and x.check(unit)
    -        ):
    +        # TODO: if we choose dm as a compatible unit, single-value keys
    +        # can be found. Should we check that?
    +        if error is not None:
                 with pytest.raises(error):
                     ds.sel(x=values)
     
                 return
     
             expected = attach_units(
    -            strip_units(ds).sel(x=strip_units(convert_units(values, {None: x.units}))),
    -            {"a": array1.units, "b": array2.units, "x": x.units},
    +            strip_units(ds).sel(
    +                x=strip_units(convert_units(values, {None: unit_registry.m}))
    +            ),
    +            extract_units(ds),
             )
             actual = ds.sel(x=values)
    -        assert_equal_with_units(expected, actual)
     
    -    @pytest.mark.xfail(reason="indexes don't support units")
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
    +
    +    @pytest.mark.skip(reason="indexes don't support units")
         @pytest.mark.parametrize(
             "raw_values",
             (
    @@ -4614,7 +4844,7 @@ def test_sel(self, raw_values, unit, error, dtype):
                 pytest.param(1, KeyError, id="no_units"),
                 pytest.param(unit_registry.dimensionless, KeyError, id="dimensionless"),
                 pytest.param(unit_registry.degree, KeyError, id="incompatible_unit"),
    -            pytest.param(unit_registry.dm, KeyError, id="compatible_unit"),
    +            pytest.param(unit_registry.mm, KeyError, id="compatible_unit"),
                 pytest.param(unit_registry.m, None, id="identical_unit"),
             ),
         )
    @@ -4633,9 +4863,9 @@ def test_drop_sel(self, raw_values, unit, error, dtype):
     
             values = raw_values * unit
     
    -        if error is not None and not (
    -            isinstance(raw_values, (int, float)) and x.check(unit)
    -        ):
    +        # TODO: if we choose dm as a compatible unit, single-value keys
    +        # can be found. Should we check that?
    +        if error is not None:
                 with pytest.raises(error):
                     ds.drop_sel(x=values)
     
    @@ -4643,14 +4873,16 @@ def test_drop_sel(self, raw_values, unit, error, dtype):
     
             expected = attach_units(
                 strip_units(ds).drop_sel(
    -                x=strip_units(convert_units(values, {None: x.units}))
    +                x=strip_units(convert_units(values, {None: unit_registry.m}))
                 ),
                 extract_units(ds),
             )
             actual = ds.drop_sel(x=values)
    -        assert_equal_with_units(expected, actual)
     
    -    @pytest.mark.xfail(reason="indexes don't support units")
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
    +
    +    @pytest.mark.skip(reason="indexes don't support units")
         @pytest.mark.parametrize(
             "raw_values",
             (
    @@ -4665,7 +4897,7 @@ def test_drop_sel(self, raw_values, unit, error, dtype):
                 pytest.param(1, KeyError, id="no_units"),
                 pytest.param(unit_registry.dimensionless, KeyError, id="dimensionless"),
                 pytest.param(unit_registry.degree, KeyError, id="incompatible_unit"),
    -            pytest.param(unit_registry.dm, KeyError, id="compatible_unit"),
    +            pytest.param(unit_registry.mm, KeyError, id="compatible_unit"),
                 pytest.param(unit_registry.m, None, id="identical_unit"),
             ),
         )
    @@ -4684,9 +4916,9 @@ def test_loc(self, raw_values, unit, error, dtype):
     
             values = raw_values * unit
     
    -        if error is not None and not (
    -            isinstance(raw_values, (int, float)) and x.check(unit)
    -        ):
    +        # TODO: if we choose dm as a compatible unit, single-value keys
    +        # can be found. Should we check that?
    +        if error is not None:
                 with pytest.raises(error):
                     ds.loc[{"x": values}]
     
    @@ -4694,12 +4926,14 @@ def test_loc(self, raw_values, unit, error, dtype):
     
             expected = attach_units(
                 strip_units(ds).loc[
    -                {"x": strip_units(convert_units(values, {None: x.units}))}
    +                {"x": strip_units(convert_units(values, {None: unit_registry.m}))}
                 ],
    -            {"a": array1.units, "b": array2.units, "x": x.units},
    +            extract_units(ds),
             )
             actual = ds.loc[{"x": values}]
    -        assert_equal_with_units(expected, actual)
    +
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
     
         @pytest.mark.parametrize(
             "func",
    @@ -4710,14 +4944,34 @@ def test_loc(self, raw_values, unit, error, dtype):
             ),
             ids=repr,
         )
    -    def test_head_tail_thin(self, func, dtype):
    -        array1 = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK
    -        array2 = np.linspace(1, 2, 10 * 8).reshape(10, 8) * unit_registry.Pa
    +    @pytest.mark.parametrize(
    +        "variant",
    +        (
    +            "data",
    +            pytest.param(
    +                "dims", marks=pytest.mark.skip(reason="indexes don't support units")
    +            ),
    +            "coords",
    +        ),
    +    )
    +    def test_head_tail_thin(self, func, variant, dtype):
    +        variants = {
    +            "data": ((unit_registry.degK, unit_registry.Pa), 1, 1),
    +            "dims": ((1, 1), unit_registry.m, 1),
    +            "coords": ((1, 1), 1, unit_registry.m),
    +        }
    +        (unit_a, unit_b), dim_unit, coord_unit = variants.get(variant)
    +
    +        array1 = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_a
    +        array2 = np.linspace(1, 2, 10 * 8).reshape(10, 8) * unit_b
     
             coords = {
    -            "x": np.arange(10) * unit_registry.m,
    -            "y": np.arange(5) * unit_registry.m,
    -            "z": np.arange(8) * unit_registry.m,
    +            "x": np.arange(10) * dim_unit,
    +            "y": np.arange(5) * dim_unit,
    +            "z": np.arange(8) * dim_unit,
    +            "u": ("x", np.linspace(0, 1, 10) * coord_unit),
    +            "v": ("y", np.linspace(1, 2, 5) * coord_unit),
    +            "w": ("z", np.linspace(-1, 0, 8) * coord_unit),
             }
     
             ds = xr.Dataset(
    @@ -4731,8 +4985,10 @@ def test_head_tail_thin(self, func, dtype):
             expected = attach_units(func(strip_units(ds)), extract_units(ds))
             actual = func(ds)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
     
    +    @pytest.mark.parametrize("dim", ("x", "y", "z", "t", "all"))
         @pytest.mark.parametrize(
             "shape",
             (
    @@ -4743,13 +4999,9 @@ def test_head_tail_thin(self, func, dtype):
                 pytest.param((1, 10, 1, 20), id="first and last dimension squeezable"),
             ),
         )
    -    def test_squeeze(self, shape, dtype):
    +    def test_squeeze(self, shape, dim, dtype):
             names = "xyzt"
    -        coords = {
    -            name: np.arange(length).astype(dtype)
    -            * (unit_registry.m if name != "t" else unit_registry.s)
    -            for name, length in zip(names, shape)
    -        }
    +        dim_lengths = dict(zip(names, shape))
             array1 = (
                 np.linspace(0, 1, 10 * 20).astype(dtype).reshape(shape) * unit_registry.degK
             )
    @@ -4759,74 +5011,59 @@ def test_squeeze(self, shape, dtype):
     
             ds = xr.Dataset(
                 data_vars={
    -                "a": xr.DataArray(data=array1, dims=tuple(names[: len(shape)])),
    -                "b": xr.DataArray(data=array2, dims=tuple(names[: len(shape)])),
    +                "a": (tuple(names[: len(shape)]), array1),
    +                "b": (tuple(names[: len(shape)]), array2),
                 },
    -            coords=coords,
             )
             units = extract_units(ds)
     
    -        expected = attach_units(strip_units(ds).squeeze(), units)
    +        kwargs = {"dim": dim} if dim != "all" and dim_lengths.get(dim, 0) == 1 else {}
     
    -        actual = ds.squeeze()
    -        assert_equal_with_units(actual, expected)
    +        expected = attach_units(strip_units(ds).squeeze(**kwargs), units)
     
    -        # try squeezing the dimensions separately
    -        names = tuple(dim for dim, coord in coords.items() if len(coord) == 1)
    -        for name in names:
    -            expected = attach_units(strip_units(ds).squeeze(dim=name), units)
    -            actual = ds.squeeze(dim=name)
    -            assert_equal_with_units(actual, expected)
    +        actual = ds.squeeze(**kwargs)
     
    -    @pytest.mark.xfail(reason="ignores units")
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
    +
    +    @pytest.mark.parametrize("variant", ("data", "coords"))
         @pytest.mark.parametrize(
    -        "unit,error",
    +        "func",
             (
    -            pytest.param(1, DimensionalityError, id="no_unit"),
                 pytest.param(
    -                unit_registry.dimensionless, DimensionalityError, id="dimensionless"
    +                method("interp"), marks=pytest.mark.xfail(reason="uses scipy")
                 ),
    -            pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
    -            pytest.param(unit_registry.cm, None, id="compatible_unit"),
    -            pytest.param(unit_registry.m, None, id="identical_unit"),
    +            method("reindex"),
             ),
    +        ids=repr,
         )
    -    def test_interp(self, unit, error):
    -        array1 = np.linspace(1, 2, 10 * 5).reshape(10, 5) * unit_registry.degK
    -        array2 = np.linspace(1, 2, 10 * 8).reshape(10, 8) * unit_registry.Pa
    -
    -        coords = {
    -            "x": np.arange(10) * unit_registry.m,
    -            "y": np.arange(5) * unit_registry.m,
    -            "z": np.arange(8) * unit_registry.s,
    +    def test_interp_reindex(self, func, variant, dtype):
    +        variants = {
    +            "data": (unit_registry.m, 1),
    +            "coords": (1, unit_registry.m),
             }
    +        data_unit, coord_unit = variants.get(variant)
     
    -        ds = xr.Dataset(
    -            data_vars={
    -                "a": xr.DataArray(data=array1, dims=("x", "y")),
    -                "b": xr.DataArray(data=array2, dims=("x", "z")),
    -            },
    -            coords=coords,
    -        )
    +        array1 = np.linspace(-1, 0, 10).astype(dtype) * data_unit
    +        array2 = np.linspace(0, 1, 10).astype(dtype) * data_unit
     
    -        new_coords = (np.arange(10) + 0.5) * unit
    +        y = np.arange(10) * coord_unit
     
    -        if error is not None:
    -            with pytest.raises(error):
    -                ds.interp(x=new_coords)
    +        x = np.arange(10)
    +        new_x = np.arange(8) + 0.5
     
    -            return
    -
    -        units = extract_units(ds)
    -        expected = attach_units(
    -            strip_units(ds).interp(x=strip_units(convert_units(new_coords, units))),
    -            units,
    +        ds = xr.Dataset(
    +            {"a": ("x", array1), "b": ("x", array2)}, coords={"x": x, "y": ("x", y)}
             )
    -        actual = ds.interp(x=new_coords)
    +        units = extract_units(ds)
    +
    +        expected = attach_units(func(strip_units(ds), x=new_x), units)
    +        actual = func(ds, x=new_x)
     
    -        assert_equal_with_units(actual, expected)
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
     
    -    @pytest.mark.xfail(reason="ignores units")
    +    @pytest.mark.skip(reason="indexes don't support units")
         @pytest.mark.parametrize(
             "unit,error",
             (
    @@ -4839,108 +5076,69 @@ def test_interp(self, unit, error):
                 pytest.param(unit_registry.m, None, id="identical_unit"),
             ),
         )
    -    def test_interp_like(self, unit, error, dtype):
    -        array1 = (
    -            np.linspace(0, 10, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK
    -        )
    -        array2 = (
    -            np.linspace(10, 20, 10 * 8).reshape(10, 8).astype(dtype) * unit_registry.Pa
    -        )
    -
    -        coords = {
    -            "x": np.arange(10) * unit_registry.m,
    -            "y": np.arange(5) * unit_registry.m,
    -            "z": np.arange(8) * unit_registry.m,
    -        }
    +    @pytest.mark.parametrize("func", (method("interp"), method("reindex")), ids=repr)
    +    def test_interp_reindex_indexing(self, func, unit, error, dtype):
    +        array1 = np.linspace(-1, 0, 10).astype(dtype)
    +        array2 = np.linspace(0, 1, 10).astype(dtype)
     
    -        ds = xr.Dataset(
    -            data_vars={
    -                "a": xr.DataArray(data=array1, dims=("x", "y")),
    -                "b": xr.DataArray(data=array2, dims=("x", "z")),
    -            },
    -            coords=coords,
    -        )
    +        x = np.arange(10) * unit_registry.m
    +        new_x = (np.arange(8) + 0.5) * unit
     
    -        other = xr.Dataset(
    -            data_vars={
    -                "c": xr.DataArray(data=np.empty((20, 10)), dims=("x", "y")),
    -                "d": xr.DataArray(data=np.empty((20, 15)), dims=("x", "z")),
    -            },
    -            coords={
    -                "x": (np.arange(20) + 0.3) * unit,
    -                "y": (np.arange(10) - 0.2) * unit,
    -                "z": (np.arange(15) + 0.4) * unit,
    -            },
    -        )
    +        ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)}, coords={"x": x})
    +        units = extract_units(ds)
     
             if error is not None:
                 with pytest.raises(error):
    -                ds.interp_like(other)
    +                func(ds, x=new_x)
     
                 return
     
    -        units = extract_units(ds)
    -        expected = attach_units(
    -            strip_units(ds).interp_like(strip_units(convert_units(other, units))), units
    -        )
    -        actual = ds.interp_like(other)
    +        expected = attach_units(func(strip_units(ds), x=new_x), units)
    +        actual = func(ds, x=new_x)
     
    -        assert_equal_with_units(actual, expected)
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
     
    -    @pytest.mark.xfail(reason="indexes don't support units")
    +    @pytest.mark.parametrize("variant", ("data", "coords"))
         @pytest.mark.parametrize(
    -        "unit,error",
    +        "func",
             (
    -            pytest.param(1, DimensionalityError, id="no_unit"),
                 pytest.param(
    -                unit_registry.dimensionless, DimensionalityError, id="dimensionless"
    +                method("interp_like"), marks=pytest.mark.xfail(reason="uses scipy")
                 ),
    -            pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"),
    -            pytest.param(unit_registry.cm, None, id="compatible_unit"),
    -            pytest.param(unit_registry.m, None, id="identical_unit"),
    +            method("reindex_like"),
             ),
    +        ids=repr,
         )
    -    def test_reindex(self, unit, error, dtype):
    -        array1 = (
    -            np.linspace(1, 2, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK
    -        )
    -        array2 = (
    -            np.linspace(1, 2, 10 * 8).reshape(10, 8).astype(dtype) * unit_registry.Pa
    -        )
    -
    -        coords = {
    -            "x": np.arange(10) * unit_registry.m,
    -            "y": np.arange(5) * unit_registry.m,
    -            "z": np.arange(8) * unit_registry.s,
    +    def test_interp_reindex_like(self, func, variant, dtype):
    +        variants = {
    +            "data": (unit_registry.m, 1),
    +            "coords": (1, unit_registry.m),
             }
    +        data_unit, coord_unit = variants.get(variant)
     
    -        ds = xr.Dataset(
    -            data_vars={
    -                "a": xr.DataArray(data=array1, dims=("x", "y")),
    -                "b": xr.DataArray(data=array2, dims=("x", "z")),
    -            },
    -            coords=coords,
    -        )
    -
    -        new_coords = (np.arange(10) + 0.5) * unit
    +        array1 = np.linspace(-1, 0, 10).astype(dtype) * data_unit
    +        array2 = np.linspace(0, 1, 10).astype(dtype) * data_unit
     
    -        if error is not None:
    -            with pytest.raises(error):
    -                ds.reindex(x=new_coords)
    +        y = np.arange(10) * coord_unit
     
    -            return
    +        x = np.arange(10)
    +        new_x = np.arange(8) + 0.5
     
    -        expected = attach_units(
    -            strip_units(ds).reindex(
    -                x=strip_units(convert_units(new_coords, {None: coords["x"].units}))
    -            ),
    -            extract_units(ds),
    +        ds = xr.Dataset(
    +            {"a": ("x", array1), "b": ("x", array2)}, coords={"x": x, "y": ("x", y)}
             )
    -        actual = ds.reindex(x=new_coords)
    +        units = extract_units(ds)
    +
    +        other = xr.Dataset({"a": ("x", np.empty_like(new_x))}, coords={"x": new_x})
    +
    +        expected = attach_units(func(strip_units(ds), other), units)
    +        actual = func(ds, other)
     
    -        assert_equal_with_units(actual, expected)
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
     
    -    @pytest.mark.xfail(reason="indexes don't support units")
    +    @pytest.mark.skip(reason="indexes don't support units")
         @pytest.mark.parametrize(
             "unit,error",
             (
    @@ -4953,54 +5151,32 @@ def test_reindex(self, unit, error, dtype):
                 pytest.param(unit_registry.m, None, id="identical_unit"),
             ),
         )
    -    def test_reindex_like(self, unit, error, dtype):
    -        array1 = (
    -            np.linspace(0, 10, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK
    -        )
    -        array2 = (
    -            np.linspace(10, 20, 10 * 8).reshape(10, 8).astype(dtype) * unit_registry.Pa
    -        )
    +    @pytest.mark.parametrize(
    +        "func", (method("interp_like"), method("reindex_like")), ids=repr
    +    )
    +    def test_interp_reindex_like_indexing(self, func, unit, error, dtype):
    +        array1 = np.linspace(-1, 0, 10).astype(dtype)
    +        array2 = np.linspace(0, 1, 10).astype(dtype)
     
    -        coords = {
    -            "x": np.arange(10) * unit_registry.m,
    -            "y": np.arange(5) * unit_registry.m,
    -            "z": np.arange(8) * unit_registry.m,
    -        }
    +        x = np.arange(10) * unit_registry.m
    +        new_x = (np.arange(8) + 0.5) * unit
     
    -        ds = xr.Dataset(
    -            data_vars={
    -                "a": xr.DataArray(data=array1, dims=("x", "y")),
    -                "b": xr.DataArray(data=array2, dims=("x", "z")),
    -            },
    -            coords=coords,
    -        )
    +        ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)}, coords={"x": x})
    +        units = extract_units(ds)
     
    -        other = xr.Dataset(
    -            data_vars={
    -                "c": xr.DataArray(data=np.empty((20, 10)), dims=("x", "y")),
    -                "d": xr.DataArray(data=np.empty((20, 15)), dims=("x", "z")),
    -            },
    -            coords={
    -                "x": (np.arange(20) + 0.3) * unit,
    -                "y": (np.arange(10) - 0.2) * unit,
    -                "z": (np.arange(15) + 0.4) * unit,
    -            },
    -        )
    +        other = xr.Dataset({"a": ("x", np.empty_like(new_x))}, coords={"x": new_x})
     
             if error is not None:
                 with pytest.raises(error):
    -                ds.reindex_like(other)
    +                func(ds, other)
     
                 return
     
    -        units = extract_units(ds)
    -        expected = attach_units(
    -            strip_units(ds).reindex_like(strip_units(convert_units(other, units))),
    -            units,
    -        )
    -        actual = ds.reindex_like(other)
    +        expected = attach_units(func(strip_units(ds), other), units)
    +        actual = func(ds, other)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
     
         @pytest.mark.parametrize(
             "func",
    @@ -5008,32 +5184,42 @@ def test_reindex_like(self, unit, error, dtype):
                 method("diff", dim="x"),
                 method("differentiate", coord="x"),
                 method("integrate", coord="x"),
    -            pytest.param(
    -                method("quantile", q=[0.25, 0.75]),
    -                marks=pytest.mark.xfail(reason="nanquantile not implemented"),
    -            ),
    +            method("quantile", q=[0.25, 0.75]),
                 method("reduce", func=np.sum, dim="x"),
                 method("map", np.fabs),
             ),
             ids=repr,
         )
    -    def test_computation(self, func, dtype):
    -        array1 = (
    -            np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK
    -        )
    -        array2 = (
    -            np.linspace(10, 20, 10 * 8).reshape(10, 8).astype(dtype) * unit_registry.Pa
    -        )
    -        x = np.arange(10) * unit_registry.m
    -        y = np.arange(5) * unit_registry.m
    -        z = np.arange(8) * unit_registry.m
    +    @pytest.mark.parametrize(
    +        "variant",
    +        (
    +            "data",
    +            pytest.param(
    +                "dims", marks=pytest.mark.skip(reason="indexes don't support units")
    +            ),
    +            "coords",
    +        ),
    +    )
    +    def test_computation(self, func, variant, dtype):
    +        variants = {
    +            "data": ((unit_registry.degK, unit_registry.Pa), 1, 1),
    +            "dims": ((1, 1), unit_registry.m, 1),
    +            "coords": ((1, 1), 1, unit_registry.m),
    +        }
    +        (unit1, unit2), dim_unit, coord_unit = variants.get(variant)
    +
    +        array1 = np.linspace(-5, 5, 4 * 5).reshape(4, 5).astype(dtype) * unit1
    +        array2 = np.linspace(10, 20, 4 * 3).reshape(4, 3).astype(dtype) * unit2
    +        x = np.arange(4) * dim_unit
    +        y = np.arange(5) * dim_unit
    +        z = np.arange(3) * dim_unit
     
             ds = xr.Dataset(
                 data_vars={
                     "a": xr.DataArray(data=array1, dims=("x", "y")),
                     "b": xr.DataArray(data=array2, dims=("x", "z")),
                 },
    -            coords={"x": x, "y": y, "z": z},
    +            coords={"x": x, "y": y, "z": z, "y2": ("y", np.arange(5) * coord_unit)},
             )
     
             units = extract_units(ds)
    @@ -5041,69 +5227,96 @@ def test_computation(self, func, dtype):
             expected = attach_units(func(strip_units(ds)), units)
             actual = func(ds)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
     
         @pytest.mark.parametrize(
             "func",
             (
                 method("groupby", "x"),
    -            method("groupby_bins", "x", bins=4),
    +            method("groupby_bins", "x", bins=2),
                 method("coarsen", x=2),
                 pytest.param(
                     method("rolling", x=3), marks=pytest.mark.xfail(reason="strips units")
                 ),
                 pytest.param(
                     method("rolling_exp", x=3),
    -                marks=pytest.mark.xfail(reason="uses numbagg which strips units"),
    +                marks=pytest.mark.xfail(
    +                    reason="numbagg functions are not supported by pint"
    +                ),
                 ),
    +            method("weighted", xr.DataArray(data=np.linspace(0, 1, 5), dims="y")),
             ),
             ids=repr,
         )
    -    def test_computation_objects(self, func, dtype):
    -        array1 = (
    -            np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK
    -        )
    -        array2 = (
    -            np.linspace(10, 20, 10 * 5 * 8).reshape(10, 5, 8).astype(dtype)
    -            * unit_registry.Pa
    -        )
    -        x = np.arange(10) * unit_registry.m
    -        y = np.arange(5) * unit_registry.m
    -        z = np.arange(8) * unit_registry.m
    +    @pytest.mark.parametrize(
    +        "variant",
    +        (
    +            "data",
    +            pytest.param(
    +                "dims", marks=pytest.mark.skip(reason="indexes don't support units")
    +            ),
    +            "coords",
    +        ),
    +    )
    +    def test_computation_objects(self, func, variant, dtype):
    +        variants = {
    +            "data": ((unit_registry.degK, unit_registry.Pa), 1, 1),
    +            "dims": ((1, 1), unit_registry.m, 1),
    +            "coords": ((1, 1), 1, unit_registry.m),
    +        }
    +        (unit1, unit2), dim_unit, coord_unit = variants.get(variant)
    +
    +        array1 = np.linspace(-5, 5, 4 * 5).reshape(4, 5).astype(dtype) * unit1
    +        array2 = np.linspace(10, 20, 4 * 3).reshape(4, 3).astype(dtype) * unit2
    +        x = np.arange(4) * dim_unit
    +        y = np.arange(5) * dim_unit
    +        z = np.arange(3) * dim_unit
     
             ds = xr.Dataset(
    -            data_vars={
    -                "a": xr.DataArray(data=array1, dims=("x", "y")),
    -                "b": xr.DataArray(data=array2, dims=("x", "y", "z")),
    -            },
    -            coords={"x": x, "y": y, "z": z},
    +            data_vars={"a": (("x", "y"), array1), "b": (("x", "z"), array2)},
    +            coords={"x": x, "y": y, "z": z, "y2": ("y", np.arange(5) * coord_unit)},
             )
             units = extract_units(ds)
     
             args = [] if func.name != "groupby" else ["y"]
    -        reduce_func = method("mean", *args)
    -        expected = attach_units(reduce_func(func(strip_units(ds))), units)
    -        actual = reduce_func(func(ds))
    +        expected = attach_units(func(strip_units(ds)).mean(*args), units)
    +        actual = func(ds).mean(*args)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_allclose(expected, actual)
    +
    +    @pytest.mark.parametrize(
    +        "variant",
    +        (
    +            "data",
    +            pytest.param(
    +                "dims", marks=pytest.mark.skip(reason="indexes don't support units")
    +            ),
    +            "coords",
    +        ),
    +    )
    +    def test_resample(self, variant, dtype):
    +        # TODO: move this to test_computation_objects
    +        variants = {
    +            "data": ((unit_registry.degK, unit_registry.Pa), 1, 1),
    +            "dims": ((1, 1), unit_registry.m, 1),
    +            "coords": ((1, 1), 1, unit_registry.m),
    +        }
    +        (unit1, unit2), dim_unit, coord_unit = variants.get(variant)
    +
    +        array1 = np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) * unit1
    +        array2 = np.linspace(10, 20, 10 * 8).reshape(10, 8).astype(dtype) * unit2
     
    -    def test_resample(self, dtype):
    -        array1 = (
    -            np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK
    -        )
    -        array2 = (
    -            np.linspace(10, 20, 10 * 8).reshape(10, 8).astype(dtype) * unit_registry.Pa
    -        )
             t = pd.date_range("10-09-2010", periods=array1.shape[0], freq="1y")
    -        y = np.arange(5) * unit_registry.m
    -        z = np.arange(8) * unit_registry.m
    +        y = np.arange(5) * dim_unit
    +        z = np.arange(8) * dim_unit
    +
    +        u = np.linspace(-1, 0, 5) * coord_unit
     
             ds = xr.Dataset(
    -            data_vars={
    -                "a": xr.DataArray(data=array1, dims=("time", "y")),
    -                "b": xr.DataArray(data=array2, dims=("time", "z")),
    -            },
    -            coords={"time": t, "y": y, "z": z},
    +            data_vars={"a": (("time", "y"), array1), "b": (("time", "z"), array2)},
    +            coords={"time": t, "y": y, "z": z, "u": ("y", u)},
             )
             units = extract_units(ds)
     
    @@ -5112,43 +5325,53 @@ def test_resample(self, dtype):
             expected = attach_units(func(strip_units(ds)).mean(), units)
             actual = func(ds).mean()
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
     
         @pytest.mark.parametrize(
             "func",
             (
                 method("assign", c=lambda ds: 10 * ds.b),
    -            method("assign_coords", v=("x", np.arange(10) * unit_registry.s)),
    +            method("assign_coords", v=("x", np.arange(5) * unit_registry.s)),
                 method("first"),
                 method("last"),
    +            method("quantile", q=[0.25, 0.5, 0.75], dim="x"),
    +        ),
    +        ids=repr,
    +    )
    +    @pytest.mark.parametrize(
    +        "variant",
    +        (
    +            "data",
                 pytest.param(
    -                method("quantile", q=[0.25, 0.5, 0.75], dim="x"),
    -                marks=pytest.mark.xfail(reason="nanquantile not implemented"),
    +                "dims", marks=pytest.mark.skip(reason="indexes don't support units")
                 ),
    +            "coords",
             ),
    -        ids=repr,
         )
    -    def test_grouped_operations(self, func, dtype):
    -        array1 = (
    -            np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype) * unit_registry.degK
    -        )
    -        array2 = (
    -            np.linspace(10, 20, 10 * 5 * 8).reshape(10, 5, 8).astype(dtype)
    -            * unit_registry.Pa
    -        )
    -        x = np.arange(10) * unit_registry.m
    -        y = np.arange(5) * unit_registry.m
    -        z = np.arange(8) * unit_registry.m
    +    def test_grouped_operations(self, func, variant, dtype):
    +        variants = {
    +            "data": ((unit_registry.degK, unit_registry.Pa), 1, 1),
    +            "dims": ((1, 1), unit_registry.m, 1),
    +            "coords": ((1, 1), 1, unit_registry.m),
    +        }
    +        (unit1, unit2), dim_unit, coord_unit = variants.get(variant)
    +
    +        array1 = np.linspace(-5, 5, 5 * 4).reshape(5, 4).astype(dtype) * unit1
    +        array2 = np.linspace(10, 20, 5 * 4 * 3).reshape(5, 4, 3).astype(dtype) * unit2
    +        x = np.arange(5) * dim_unit
    +        y = np.arange(4) * dim_unit
    +        z = np.arange(3) * dim_unit
    +
    +        u = np.linspace(-1, 0, 4) * coord_unit
     
             ds = xr.Dataset(
    -            data_vars={
    -                "a": xr.DataArray(data=array1, dims=("x", "y")),
    -                "b": xr.DataArray(data=array2, dims=("x", "y", "z")),
    -            },
    -            coords={"x": x, "y": y, "z": z},
    +            data_vars={"a": (("x", "y"), array1), "b": (("x", "y", "z"), array2)},
    +            coords={"x": x, "y": y, "z": z, "u": ("y", u)},
             )
    -        units = extract_units(ds)
    -        units.update({"c": unit_registry.Pa, "v": unit_registry.s})
    +
    +        assigned_units = {"c": unit2, "v": unit_registry.s}
    +        units = merge_mappings(extract_units(ds), assigned_units)
     
             stripped_kwargs = {
                 name: strip_units(value) for name, value in func.kwargs.items()
    @@ -5158,20 +5381,26 @@ def test_grouped_operations(self, func, dtype):
             )
             actual = func(ds.groupby("y"))
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
     
         @pytest.mark.parametrize(
             "func",
             (
                 method("pipe", lambda ds: ds * 10),
                 method("assign", d=lambda ds: ds.b * 10),
    -            method("assign_coords", y2=("y", np.arange(5) * unit_registry.mm)),
    +            method("assign_coords", y2=("y", np.arange(4) * unit_registry.mm)),
                 method("assign_attrs", attr1="value"),
                 method("rename", x2="x_mm"),
                 method("rename_vars", c="temperature"),
                 method("rename_dims", x="offset_x"),
    -            method("swap_dims", {"x": "x2"}),
    -            method("expand_dims", v=np.linspace(10, 20, 12) * unit_registry.s, axis=1),
    +            method("swap_dims", {"x": "u"}),
    +            pytest.param(
    +                method(
    +                    "expand_dims", v=np.linspace(10, 20, 12) * unit_registry.s, axis=1
    +                ),
    +                marks=pytest.mark.skip(reason="indexes don't support units"),
    +            ),
                 method("drop_vars", "x"),
                 method("drop_dims", "z"),
                 method("set_coords", names="c"),
    @@ -5180,40 +5409,55 @@ def test_grouped_operations(self, func, dtype):
             ),
             ids=repr,
         )
    -    def test_content_manipulation(self, func, dtype):
    -        array1 = (
    -            np.linspace(-5, 5, 10 * 5).reshape(10, 5).astype(dtype)
    -            * unit_registry.m ** 3
    -        )
    -        array2 = (
    -            np.linspace(10, 20, 10 * 5 * 8).reshape(10, 5, 8).astype(dtype)
    -            * unit_registry.Pa
    -        )
    -        array3 = np.linspace(0, 10, 10).astype(dtype) * unit_registry.degK
    +    @pytest.mark.parametrize(
    +        "variant",
    +        (
    +            "data",
    +            pytest.param(
    +                "dims", marks=pytest.mark.skip(reason="indexes don't support units")
    +            ),
    +            "coords",
    +        ),
    +    )
    +    def test_content_manipulation(self, func, variant, dtype):
    +        variants = {
    +            "data": (
    +                (unit_registry.m ** 3, unit_registry.Pa, unit_registry.degK),
    +                1,
    +                1,
    +            ),
    +            "dims": ((1, 1, 1), unit_registry.m, 1),
    +            "coords": ((1, 1, 1), 1, unit_registry.m),
    +        }
    +        (unit1, unit2, unit3), dim_unit, coord_unit = variants.get(variant)
     
    -        x = np.arange(10) * unit_registry.m
    -        x2 = x.to(unit_registry.mm)
    -        y = np.arange(5) * unit_registry.m
    -        z = np.arange(8) * unit_registry.m
    +        array1 = np.linspace(-5, 5, 5 * 4).reshape(5, 4).astype(dtype) * unit1
    +        array2 = np.linspace(10, 20, 5 * 4 * 3).reshape(5, 4, 3).astype(dtype) * unit2
    +        array3 = np.linspace(0, 10, 5).astype(dtype) * unit3
    +
    +        x = np.arange(5) * dim_unit
    +        y = np.arange(4) * dim_unit
    +        z = np.arange(3) * dim_unit
    +
    +        x2 = np.linspace(-1, 0, 5) * coord_unit
     
             ds = xr.Dataset(
                 data_vars={
    -                "a": xr.DataArray(data=array1, dims=("x", "y")),
    -                "b": xr.DataArray(data=array2, dims=("x", "y", "z")),
    -                "c": xr.DataArray(data=array3, dims="x"),
    +                "a": (("x", "y"), array1),
    +                "b": (("x", "y", "z"), array2),
    +                "c": ("x", array3),
                 },
                 coords={"x": x, "y": y, "z": z, "x2": ("x", x2)},
             )
    -        units = {
    -            **extract_units(ds),
    -            **{
    -                "y2": unit_registry.mm,
    -                "x_mm": unit_registry.mm,
    -                "offset_x": unit_registry.m,
    -                "d": unit_registry.Pa,
    -                "temperature": unit_registry.degK,
    -            },
    +
    +        new_units = {
    +            "y2": unit_registry.mm,
    +            "x_mm": coord_unit,
    +            "offset_x": unit_registry.m,
    +            "d": unit2,
    +            "temperature": unit3,
             }
    +        units = merge_mappings(extract_units(ds), new_units)
     
             stripped_kwargs = {
                 key: strip_units(value) for key, value in func.kwargs.items()
    @@ -5221,7 +5465,8 @@ def test_content_manipulation(self, func, dtype):
             expected = attach_units(func(strip_units(ds), **stripped_kwargs), units)
             actual = func(ds)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
     
         @pytest.mark.parametrize(
             "unit,error",
    @@ -5240,31 +5485,35 @@ def test_content_manipulation(self, func, dtype):
             (
                 "data",
                 pytest.param(
    -                "dims", marks=pytest.mark.xfail(reason="indexes don't support units")
    +                "dims", marks=pytest.mark.skip(reason="indexes don't support units")
                 ),
                 "coords",
             ),
         )
         def test_merge(self, variant, unit, error, dtype):
    -        original_data_unit = unit_registry.m
    -        original_dim_unit = unit_registry.m
    -        original_coord_unit = unit_registry.m
    +        left_variants = {
    +            "data": (unit_registry.m, 1, 1),
    +            "dims": (1, unit_registry.m, 1),
    +            "coords": (1, 1, unit_registry.m),
    +        }
     
    -        variants = {
    -            "data": (unit, original_dim_unit, original_coord_unit),
    -            "dims": (original_data_unit, unit, original_coord_unit),
    -            "coords": (original_data_unit, original_dim_unit, unit),
    +        left_data_unit, left_dim_unit, left_coord_unit = left_variants.get(variant)
    +
    +        right_variants = {
    +            "data": (unit, 1, 1),
    +            "dims": (1, unit, 1),
    +            "coords": (1, 1, unit),
             }
    -        data_unit, dim_unit, coord_unit = variants.get(variant)
    +        right_data_unit, right_dim_unit, right_coord_unit = right_variants.get(variant)
     
    -        left_array = np.arange(10).astype(dtype) * original_data_unit
    -        right_array = np.arange(-5, 5).astype(dtype) * data_unit
    +        left_array = np.arange(10).astype(dtype) * left_data_unit
    +        right_array = np.arange(-5, 5).astype(dtype) * right_data_unit
     
    -        left_dim = np.arange(10, 20) * original_dim_unit
    -        right_dim = np.arange(5, 15) * dim_unit
    +        left_dim = np.arange(10, 20) * left_dim_unit
    +        right_dim = np.arange(5, 15) * right_dim_unit
     
    -        left_coord = np.arange(-10, 0) * original_coord_unit
    -        right_coord = np.arange(-15, -5) * coord_unit
    +        left_coord = np.arange(-10, 0) * left_coord_unit
    +        right_coord = np.arange(-15, -5) * right_coord_unit
     
             left = xr.Dataset(
                 data_vars={"a": ("x", left_array)},
    @@ -5287,4 +5536,5 @@ def test_merge(self, variant, unit, error, dtype):
             expected = attach_units(strip_units(left).merge(strip_units(converted)), units)
             actual = left.merge(right)
     
    -        assert_equal_with_units(expected, actual)
    +        assert_units_equal(expected, actual)
    +        assert_equal(expected, actual)
    diff --git a/xarray/tests/test_utils.py b/xarray/tests/test_utils.py
    index 5f8b1770bd3..193c45f01cd 100644
    --- a/xarray/tests/test_utils.py
    +++ b/xarray/tests/test_utils.py
    @@ -39,6 +39,33 @@ def test_safe_cast_to_index():
             assert expected.dtype == actual.dtype
     
     
    +@pytest.mark.parametrize(
    +    "a, b, expected", [["a", "b", np.array(["a", "b"])], [1, 2, pd.Index([1, 2])]]
    +)
    +def test_maybe_coerce_to_str(a, b, expected):
    +
    +    a = np.array([a])
    +    b = np.array([b])
    +    index = pd.Index(a).append(pd.Index(b))
    +
    +    actual = utils.maybe_coerce_to_str(index, [a, b])
    +
    +    assert_array_equal(expected, actual)
    +    assert expected.dtype == actual.dtype
    +
    +
    +def test_maybe_coerce_to_str_minimal_str_dtype():
    +
    +    a = np.array(["a", "a_long_string"])
    +    index = pd.Index(["a"])
    +
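+    # the coerced result should keep the minimal string dtype ("<U1"), not the wider dtype of the original array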
    +    actual = utils.maybe_coerce_to_str(index, [a])
    +    expected = np.array("a")
    +
    +    assert_array_equal(expected, actual)
    +    assert expected.dtype == actual.dtype
    +
    +
     @requires_cftime
     def test_safe_cast_to_index_cftimeindex():
         date_types = _all_cftime_date_types()
    diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py
    index 78e3848b8fb..e1ae3e1f258 100644
    --- a/xarray/tests/test_variable.py
    +++ b/xarray/tests/test_variable.py
    @@ -294,6 +294,19 @@ def test_object_conversion(self):
             actual = self.cls("x", data)
             assert actual.dtype == data.dtype
     
    +    def test_datetime64_valid_range(self):
    +        data = np.datetime64("1250-01-01", "us")
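+        # pandas only supports nanosecond-precision datetimes (roughly 1677-2262), so this date is out of bounds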
    +        pderror = pd.errors.OutOfBoundsDatetime
    +        with raises_regex(pderror, "Out of bounds nanosecond"):
    +            self.cls(["t"], [data])
    +
    +    @pytest.mark.xfail(reason="pandas issue 36615")
    +    def test_timedelta64_valid_range(self):
    +        data = np.timedelta64("200000", "D")
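+        # 200000 days exceeds the ~106752-day range representable with nanosecond-precision timedeltas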
    +        pderror = pd.errors.OutOfBoundsTimedelta
    +        with raises_regex(pderror, "Out of bounds nanosecond"):
    +            self.cls(["t"], [data])
    +
         def test_pandas_data(self):
             v = self.cls(["x"], pd.Series([0, 1, 2], index=[3, 2, 1]))
             assert_identical(v, v[[0, 1, 2]])
    @@ -329,7 +342,8 @@ def test_1d_math(self):
             assert_array_equal(y - v, 1 - v)
             # verify attributes are dropped
             v2 = self.cls(["x"], x, {"units": "meters"})
    -        assert_identical(base_v, +v2)
    +        with set_options(keep_attrs=False):
    +            assert_identical(base_v, +v2)
             # binary ops with all variables
             assert_array_equal(v + v, 2 * v)
             w = self.cls(["x"], y, {"foo": "bar"})
    @@ -821,6 +835,9 @@ def test_getitem_error(self):
             ],
         )
         @pytest.mark.parametrize("xr_arg, np_arg", _PAD_XR_NP_ARGS)
    +    @pytest.mark.filterwarnings(
    +        r"ignore:dask.array.pad.+? converts integers to floats."
    +    )
         def test_pad(self, mode, xr_arg, np_arg):
             data = np.arange(4 * 3 * 2).reshape(4, 3, 2)
             v = self.cls(["x", "y", "z"], data)
    @@ -1256,13 +1273,13 @@ def test_isel(self):
             assert_identical(v.isel(time=[]), v[[]])
             with raises_regex(
                 ValueError,
    -            r"dimensions {'not_a_dim'} do not exist. Expected one or more of "
    +            r"Dimensions {'not_a_dim'} do not exist. Expected one or more of "
                 r"\('time', 'x'\)",
             ):
                 v.isel(not_a_dim=0)
             with pytest.warns(
                 UserWarning,
    -            match=r"dimensions {'not_a_dim'} do not exist. Expected one or more of "
    +            match=r"Dimensions {'not_a_dim'} do not exist. Expected one or more of "
                 r"\('time', 'x'\)",
             ):
                 v.isel(not_a_dim=0, missing_dims="warn")
    @@ -1378,7 +1395,7 @@ def test_transpose_0d(self):
             ]:
                 variable = Variable([], value)
                 actual = variable.transpose()
    -            assert actual.identical(variable)
    +            assert_identical(actual, variable)
     
         def test_squeeze(self):
             v = Variable(["x", "y"], [[1]])
    @@ -1431,7 +1448,7 @@ def test_set_dims_object_dtype(self):
             for i in range(3):
                 exp_values[i] = ("a", 1)
             expected = Variable(["x"], exp_values)
    -        assert actual.identical(expected)
    +        assert_identical(actual, expected)
     
         def test_stack(self):
             v = Variable(["x", "y"], [[0, 1], [2, 3]], {"foo": "bar"})
    @@ -1557,10 +1574,6 @@ def test_reduce(self):
     
             with raises_regex(ValueError, "cannot supply both"):
                 v.mean(dim="x", axis=0)
    -        with pytest.warns(DeprecationWarning, match="allow_lazy is deprecated"):
    -            v.mean(dim="x", allow_lazy=True)
    -        with pytest.warns(DeprecationWarning, match="allow_lazy is deprecated"):
    -            v.mean(dim="x", allow_lazy=False)
     
         @pytest.mark.parametrize("skipna", [True, False])
         @pytest.mark.parametrize("q", [0.25, [0.50], [0.25, 0.75]])
    @@ -1588,7 +1601,8 @@ def test_quantile_dask(self, q, axis, dim):
         def test_quantile_chunked_dim_error(self):
             v = Variable(["x", "y"], self.d).chunk({"x": 2})
     
    -        with raises_regex(ValueError, "dimension 'x'"):
+        # dask.array.apply_gufunc raises this ValueError when the reduced dimension spans multiple chunks
    +        with raises_regex(ValueError, "consists of multiple chunks"):
                 v.quantile(0.5, dim="x")
     
         @pytest.mark.parametrize("q", [-0.1, 1.1, [2], [0.25, 2]])
    @@ -1657,7 +1671,7 @@ def test_reduce_funcs(self):
             assert_identical(v.all(dim="x"), Variable([], False))
     
             v = Variable("t", pd.date_range("2000-01-01", periods=3))
    -        assert v.argmax(skipna=True) == 2
    +        assert v.argmax(skipna=True, dim="t") == 2
     
             assert_identical(v.max(), Variable([], pd.Timestamp("2000-01-03")))
     
    @@ -1948,7 +1962,10 @@ def test_coarsen_keep_attrs(self, operation="mean"):
             # Test kept attrs
             with set_options(keep_attrs=True):
                 new = Variable(["coord"], np.linspace(1, 10, 100), attrs=_attrs).coarsen(
    -                windows={"coord": 1}, func=test_func, boundary="exact", side="left"
    +                windows={"coord": 1},
    +                func=test_func,
    +                boundary="exact",
    +                side="left",
                 )
             assert new.attrs == _attrs
     
    @@ -2061,12 +2078,12 @@ def test_concat_periods(self):
             coords = [IndexVariable("t", periods[:5]), IndexVariable("t", periods[5:])]
             expected = IndexVariable("t", periods)
             actual = IndexVariable.concat(coords, dim="t")
    -        assert actual.identical(expected)
    +        assert_identical(actual, expected)
             assert isinstance(actual.to_index(), pd.PeriodIndex)
     
             positions = [list(range(5)), list(range(5, 10))]
             actual = IndexVariable.concat(coords, dim="t", positions=positions)
    -        assert actual.identical(expected)
    +        assert_identical(actual, expected)
             assert isinstance(actual.to_index(), pd.PeriodIndex)
     
         def test_concat_multiindex(self):
    @@ -2074,9 +2091,20 @@ def test_concat_multiindex(self):
             coords = [IndexVariable("x", idx[:2]), IndexVariable("x", idx[2:])]
             expected = IndexVariable("x", idx)
             actual = IndexVariable.concat(coords, dim="x")
    -        assert actual.identical(expected)
    +        assert_identical(actual, expected)
             assert isinstance(actual.to_index(), pd.MultiIndex)
     
    +    @pytest.mark.parametrize("dtype", [str, bytes])
    +    def test_concat_str_dtype(self, dtype):
    +
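+        # concatenation should preserve the str/bytes dtype instead of falling back to object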
    +        a = IndexVariable("x", np.array(["a"], dtype=dtype))
    +        b = IndexVariable("x", np.array(["b"], dtype=dtype))
    +        expected = IndexVariable("x", np.array(["a", "b"], dtype=dtype))
    +
    +        actual = IndexVariable.concat([a, b])
    +        assert actual.identical(expected)
    +        assert np.issubdtype(actual.dtype, dtype)
    +
         def test_coordinate_alias(self):
             with pytest.warns(Warning, match="deprecated"):
                 x = Coordinate("x", [1, 2, 3])
    @@ -2213,6 +2241,10 @@ def test_full_like(self):
             assert expect.dtype == bool
             assert_identical(expect, full_like(orig, True, dtype=bool))
     
    +        # raise error on non-scalar fill_value
    +        with raises_regex(ValueError, "must be scalar"):
    +            full_like(orig, [1.0, 2.0])
    +
         @requires_dask
         def test_full_like_dask(self):
             orig = Variable(
    diff --git a/xarray/tests/test_weighted.py b/xarray/tests/test_weighted.py
    index 24531215dfb..dc79d417b9c 100644
    --- a/xarray/tests/test_weighted.py
    +++ b/xarray/tests/test_weighted.py
    @@ -5,6 +5,8 @@
     from xarray import DataArray
     from xarray.tests import assert_allclose, assert_equal, raises_regex
     
    +from . import raise_if_dask_computes, requires_cftime, requires_dask
    +
     
     @pytest.mark.parametrize("as_dataset", (True, False))
     def test_weighted_non_DataArray_weights(as_dataset):
    @@ -29,6 +31,47 @@ def test_weighted_weights_nan_raises(as_dataset, weights):
             data.weighted(DataArray(weights))
     
     
    +@requires_dask
    +@pytest.mark.parametrize("as_dataset", (True, False))
    +@pytest.mark.parametrize("weights", ([np.nan, 2], [np.nan, np.nan]))
    +def test_weighted_weights_nan_raises_dask(as_dataset, weights):
    +
    +    data = DataArray([1, 2]).chunk({"dim_0": -1})
    +    if as_dataset:
    +        data = data.to_dataset(name="data")
    +
    +    weights = DataArray(weights).chunk({"dim_0": -1})
    +
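+    # constructing the weighted object must not trigger any dask computation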
    +    with raise_if_dask_computes():
    +        weighted = data.weighted(weights)
    +
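+    # the check for missing values in the weights is lazy, so the error only surfaces on compute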
    +    with pytest.raises(ValueError, match="`weights` cannot contain missing values."):
    +        weighted.sum().load()
    +
    +
    +@requires_cftime
    +@requires_dask
    +@pytest.mark.parametrize("time_chunks", (1, 5))
    +@pytest.mark.parametrize("resample_spec", ("1AS", "5AS", "10AS"))
    +def test_weighted_lazy_resample(time_chunks, resample_spec):
    +    # https://github.com/pydata/xarray/issues/4625
    +
    +    # simple customized weighted mean function
    +    def mean_func(ds):
    +        return ds.weighted(ds.weights).mean("time")
    +
    +    # example dataset
    +    t = xr.cftime_range(start="2000", periods=20, freq="1AS")
    +    weights = xr.DataArray(np.random.rand(len(t)), dims=["time"], coords={"time": t})
    +    data = xr.DataArray(
    +        np.random.rand(len(t)), dims=["time"], coords={"time": t, "weights": weights}
    +    )
    +    ds = xr.Dataset({"data": data}).chunk({"time": time_chunks})
    +
    +    with raise_if_dask_computes():
    +        ds.resample(time=resample_spec).map(mean_func)
    +
    +
     @pytest.mark.parametrize(
         ("weights", "expected"),
         (([1, 2], 3), ([2, 0], 2), ([0, 0], np.nan), ([-1, 1], np.nan)),
    @@ -59,6 +102,18 @@ def test_weighted_sum_of_weights_nan(weights, expected):
         assert_equal(expected, result)
     
     
    +def test_weighted_sum_of_weights_bool():
    +    # https://github.com/pydata/xarray/issues/4074
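+    # boolean weights are treated as 0/1, so two True weights sum to 2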
    +
    +    da = DataArray([1, 2])
    +    weights = DataArray([True, True])
    +    result = da.weighted(weights).sum_of_weights()
    +
    +    expected = DataArray(2)
    +
    +    assert_equal(expected, result)
    +
    +
     @pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan], [np.nan, np.nan]))
     @pytest.mark.parametrize("factor", [0, 1, 3.14])
     @pytest.mark.parametrize("skipna", (True, False))
    @@ -107,7 +162,7 @@ def test_weighted_sum_nan(weights, expected, skipna):
         assert_equal(expected, result)
     
     
    -@pytest.mark.filterwarnings("ignore:Mean of empty slice")
    +@pytest.mark.filterwarnings("error")
     @pytest.mark.parametrize("da", ([1.0, 2], [1, np.nan], [np.nan, np.nan]))
     @pytest.mark.parametrize("skipna", (True, False))
     @pytest.mark.parametrize("factor", [1, 2, 3.14])
    @@ -158,6 +213,17 @@ def test_weighted_mean_nan(weights, expected, skipna):
         assert_equal(expected, result)
     
     
    +def test_weighted_mean_bool():
    +    # https://github.com/pydata/xarray/issues/4074
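+    # with all-True (equal) weights the weighted mean equals the unweighted mean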
    +    da = DataArray([1, 1])
    +    weights = DataArray([True, True])
    +    expected = DataArray(1)
    +
    +    result = da.weighted(weights).mean()
    +
    +    assert_equal(expected, result)
    +
    +
     def expected_weighted(da, weights, dim, skipna, operation):
         """
         Generate expected result using ``*`` and ``sum``. This is checked against
    @@ -183,12 +249,28 @@ def expected_weighted(da, weights, dim, skipna, operation):
             return weighted_mean
     
     
    +def check_weighted_operations(data, weights, dim, skipna):
    +
    +    # check sum of weights
    +    result = data.weighted(weights).sum_of_weights(dim)
    +    expected = expected_weighted(data, weights, dim, skipna, "sum_of_weights")
    +    assert_allclose(expected, result)
    +
    +    # check weighted sum
    +    result = data.weighted(weights).sum(dim, skipna=skipna)
    +    expected = expected_weighted(data, weights, dim, skipna, "sum")
    +    assert_allclose(expected, result)
    +
    +    # check weighted mean
    +    result = data.weighted(weights).mean(dim, skipna=skipna)
    +    expected = expected_weighted(data, weights, dim, skipna, "mean")
    +    assert_allclose(expected, result)
    +
    +
     @pytest.mark.parametrize("dim", ("a", "b", "c", ("a", "b"), ("a", "b", "c"), None))
    -@pytest.mark.parametrize("operation", ("sum_of_weights", "sum", "mean"))
     @pytest.mark.parametrize("add_nans", (True, False))
     @pytest.mark.parametrize("skipna", (None, True, False))
    -@pytest.mark.parametrize("as_dataset", (True, False))
    -def test_weighted_operations_3D(dim, operation, add_nans, skipna, as_dataset):
    +def test_weighted_operations_3D(dim, add_nans, skipna):
     
         dims = ("a", "b", "c")
         coords = dict(a=[0, 1, 2, 3], b=[0, 1, 2, 3], c=[0, 1, 2, 3])
    @@ -204,46 +286,29 @@ def test_weighted_operations_3D(dim, operation, add_nans, skipna, as_dataset):
     
         data = DataArray(data, dims=dims, coords=coords)
     
    -    if as_dataset:
    -        data = data.to_dataset(name="data")
    -
    -    if operation == "sum_of_weights":
    -        result = data.weighted(weights).sum_of_weights(dim)
    -    else:
    -        result = getattr(data.weighted(weights), operation)(dim, skipna=skipna)
    +    check_weighted_operations(data, weights, dim, skipna)
     
    -    expected = expected_weighted(data, weights, dim, skipna, operation)
    -
    -    assert_allclose(expected, result)
    +    data = data.to_dataset(name="data")
    +    check_weighted_operations(data, weights, dim, skipna)
     
     
    -@pytest.mark.parametrize("operation", ("sum_of_weights", "sum", "mean"))
    -@pytest.mark.parametrize("as_dataset", (True, False))
    -def test_weighted_operations_nonequal_coords(operation, as_dataset):
    +def test_weighted_operations_nonequal_coords():
     
         weights = DataArray(np.random.randn(4), dims=("a",), coords=dict(a=[0, 1, 2, 3]))
         data = DataArray(np.random.randn(4), dims=("a",), coords=dict(a=[1, 2, 3, 4]))
     
    -    if as_dataset:
    -        data = data.to_dataset(name="data")
    -
    -    expected = expected_weighted(
    -        data, weights, dim="a", skipna=None, operation=operation
    -    )
    -    result = getattr(data.weighted(weights), operation)(dim="a")
    +    check_weighted_operations(data, weights, dim="a", skipna=None)
     
    -    assert_allclose(expected, result)
    +    data = data.to_dataset(name="data")
    +    check_weighted_operations(data, weights, dim="a", skipna=None)
     
     
    -@pytest.mark.parametrize("dim", ("dim_0", None))
     @pytest.mark.parametrize("shape_data", ((4,), (4, 4), (4, 4, 4)))
     @pytest.mark.parametrize("shape_weights", ((4,), (4, 4), (4, 4, 4)))
    -@pytest.mark.parametrize("operation", ("sum_of_weights", "sum", "mean"))
     @pytest.mark.parametrize("add_nans", (True, False))
     @pytest.mark.parametrize("skipna", (None, True, False))
    -@pytest.mark.parametrize("as_dataset", (True, False))
     def test_weighted_operations_different_shapes(
    -    dim, shape_data, shape_weights, operation, add_nans, skipna, as_dataset
    +    shape_data, shape_weights, add_nans, skipna
     ):
     
         weights = DataArray(np.random.randn(*shape_weights))
    @@ -257,17 +322,12 @@ def test_weighted_operations_different_shapes(
     
         data = DataArray(data)
     
    -    if as_dataset:
    -        data = data.to_dataset(name="data")
    -
    -    if operation == "sum_of_weights":
    -        result = getattr(data.weighted(weights), operation)(dim)
    -    else:
    -        result = getattr(data.weighted(weights), operation)(dim, skipna=skipna)
    +    check_weighted_operations(data, weights, "dim_0", skipna)
    +    check_weighted_operations(data, weights, None, skipna)
     
    -    expected = expected_weighted(data, weights, dim, skipna, operation)
    -
    -    assert_allclose(expected, result)
    +    data = data.to_dataset(name="data")
    +    check_weighted_operations(data, weights, "dim_0", skipna)
    +    check_weighted_operations(data, weights, None, skipna)
     
     
     @pytest.mark.parametrize("operation", ("sum_of_weights", "sum", "mean"))
    @@ -297,7 +357,6 @@ def test_weighted_operations_keep_attr(operation, as_dataset, keep_attrs):
         assert not result.attrs
     
     
    -@pytest.mark.xfail(reason="xr.Dataset.map does not copy attrs of DataArrays GH: 3595")
     @pytest.mark.parametrize("operation", ("sum", "mean"))
     def test_weighted_operations_keep_attr_da_in_ds(operation):
         # GH #3595
    diff --git a/xarray/tutorial.py b/xarray/tutorial.py
    index d662f2fcaaf..055be36d80b 100644
    --- a/xarray/tutorial.py
    +++ b/xarray/tutorial.py
    @@ -45,13 +45,13 @@ def open_dataset(
             Name of the file containing the dataset. If no suffix is given, assumed
             to be netCDF ('.nc' is appended)
             e.g. 'air_temperature'
    -    cache_dir : string, optional
    +    cache_dir : str, optional
             The directory in which to search for and write cached data.
    -    cache : boolean, optional
    +    cache : bool, optional
             If True, then cache data locally for use on subsequent calls
    -    github_url : string
    +    github_url : str
             Github repository where the data is stored
    -    branch : string
    +    branch : str
             The git branch to download from
         kws : dict, optional
             Passed to xarray.open_dataset
    @@ -83,7 +83,7 @@ def open_dataset(
             urlretrieve(url, md5file)
     
             localmd5 = file_md5_checksum(localfile)
    -        with open(md5file, "r") as f:
    +        with open(md5file) as f:
                 remotemd5 = f.read()
             if localmd5 != remotemd5:
                 _os.remove(localfile)
    diff --git a/xarray/util/print_versions.py b/xarray/util/print_versions.py
    index 32051bb6843..d643d768093 100755
    --- a/xarray/util/print_versions.py
    +++ b/xarray/util/print_versions.py
    @@ -78,7 +78,7 @@ def netcdf_and_hdf5_versions():
     
     
     def show_versions(file=sys.stdout):
    -    """ print the versions of xarray and its dependencies
    +    """print the versions of xarray and its dependencies
     
         Parameters
         ----------
    @@ -129,7 +129,7 @@ def show_versions(file=sys.stdout):
             ("sphinx", lambda mod: mod.__version__),
         ]
     
    -    deps_blob = list()
    +    deps_blob = []
         for (modname, ver_f) in deps:
             try:
                 if modname in sys.modules: