group save_datasets result by file #2281

Merged 7 commits on Feb 2, 2023
50 changes: 50 additions & 0 deletions satpy/tests/test_writers.py
@@ -18,6 +18,7 @@

from __future__ import annotations

import datetime
import os
import shutil
import unittest
@@ -863,3 +864,52 @@ def test_add_decorate_basic_l(self):
        from satpy.writers import add_decorate
        new_img = add_decorate(self.orig_l_img, **self.decorate)
        self.assertEqual('RGBA', new_img.mode)


def test_group_results_by_output_file(tmp_path):
    """Test grouping results by output file.

    Tests that the results from save_datasets(..., compute=False) can be
    grouped by output file.  This is useful if, for some reason, we want to
    treat each output file as a separate computation (one that can still be
    computed together with the others later).
    """
    from pyresample import create_area_def

    from satpy.writers import group_results_by_output_file

    from .utils import make_fake_scene

    x = 10
    fake_area = create_area_def("sargasso", 4326, resolution=1, width=x, height=x, center=(0, 0))
    fake_scene = make_fake_scene(
        {"dragon_top_height": (dat := xr.DataArray(
            dims=("y", "x"),
            data=da.arange(x*x).reshape((x, x)))),
         "penguin_bottom_height": dat,
         "kraken_depth": dat},
        daskify=True,
        area=fake_area,
        common_attrs={"start_time": datetime.datetime(2022, 11, 16, 13, 27)})
    # NB: even if compute=False, ``save_datasets`` creates (empty) files
    (sources, targets) = fake_scene.save_datasets(
        filename=os.fspath(tmp_path / "test-{name}.tif"),
        writer="ninjogeotiff",
        compress="NONE",
        fill_value=0,
        compute=False,
        ChannelID="x",
        DataType="x",
        PhysicUnit="K",
        PhysicValue="Temperature",
        SatelliteNameID="x")

    grouped = group_results_by_output_file(sources, targets)

    assert len(grouped) == 3
    assert len({x.rfile.path for x in grouped[0][1]}) == 1
    for x in grouped:
        assert len(x[0]) == len(x[1])
    assert sources[:5] == grouped[0][0]
    assert targets[:5] == grouped[0][1]
    assert sources[10:] == grouped[2][0]
    assert targets[10:] == grouped[2][1]
63 changes: 63 additions & 0 deletions satpy/writers/__init__.py
@@ -516,6 +516,69 @@ def flatten(results):
    return sources, targets, delayeds


def group_results_by_output_file(sources, targets):
    """Group results by output file.

    For writers that return sources and targets for ``compute=False``, split
    the results by output file.

    When not only the data but also the GeoTIFF tags are dask arrays,
    ``save_datasets(..., compute=False)`` returns a tuple of flat lists,
    where the second list consists of a mixture of ``RIOTag`` and
    ``RIODataset`` objects (from trollimage).  In some cases, we may want a
    separate delayed object for each file; for example, to add a wrapper that
    does something with the file as soon as it's finished.  This function
    unflattens the flat lists into a list of ``(sources, targets)`` tuples,
    one tuple per output file.

    For example, to close files as soon as computation is completed::

        >>> @dask.delayed
        ... def closer(obj, targs):
        ...     for targ in targs:
        ...         targ.close()
        ...     return obj
        >>> (srcs, targs) = sc.save_datasets(writer="ninjogeotiff", compute=False, **ninjo_tags)
        >>> wrapped = []
        >>> for (src, targ) in group_results_by_output_file(srcs, targs):
        ...     delayed_store = da.store(src, targ, compute=False)
        ...     wrapped_store = closer(delayed_store, targ)
        ...     wrapped.append(wrapped_store)
        >>> compute_writer_results(wrapped)

    In the wrapper you can do other useful tasks, such as writing a log
    message or moving files to a different directory.
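
    Such a wrapper might, for instance, close the targets, log a message, and
    move the finished file (a sketch only; ``finalise`` and ``final_dir`` are
    hypothetical names, and the move relies on all targets in a group sharing
    one file)::

        >>> import logging, shutil
        >>> logger = logging.getLogger(__name__)
        >>> @dask.delayed
        ... def finalise(obj, targs, final_dir):
        ...     for targ in targs:
        ...         targ.close()
        ...     path = targs[0].rfile.path  # one file per group
        ...     logger.info("Finished writing %s", path)
        ...     shutil.move(path, final_dir)
        ...     return obj
        >>> # used in place of closer above: finalise(delayed_store, targ, "/final/destination")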

    .. warning::

        Adding a callback may impact runtime and RAM.  The pattern or cause
        is unclear.  Tests with FCI data show that for resampling with high
        RAM use (from around 15 GB), runtime increases when a callback is
        added.  Tests with ABI data or with low RAM consumption rather show
        a decrease in runtime.  For more information, see `these GitHub
        comments <https://github.com/pytroll/satpy/pull/2281#issuecomment-1324910253>`_.
        Users who find out more are encouraged to contact the Satpy
        developers with clues.

    Args:
        sources: List of sources (typically dask.array) as returned by
            :meth:`Scene.save_datasets`.
        targets: List of targets (should be ``RIODataset`` or ``RIOTag``) as
            returned by :meth:`Scene.save_datasets`.

    Returns:
        List of ``Tuple(List[sources], List[targets])`` with a length equal to
        the number of output files planned to be written by
        :meth:`Scene.save_datasets`.
    """
    ofs = {}
    for (src, targ) in zip(sources, targets):
        fn = targ.rfile.path
Member:
So this rfile is an XRImage thing only, right? Should we maybe add something to this object in trollimage so you can do .path on the target, or str(targ), to get the path?

Collaborator (author):
Would fspath be appropriate?

Member:
🤷‍♂️ I don't use it often, but maybe? I guess it depends on whether other popular output-writing libraries support it (netcdf4-python, rasterio, PIL, etc.).

Collaborator (author):
On the other hand, fspath is rather for objects that represent a filesystem path, not for objects that represent an open file. I don't know whether there is a standard way to get the path corresponding to an open file.
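
For illustration of the fspath idea discussed above: supporting os.fspath() would only require an __fspath__ method on the target object. A minimal sketch, hypothetical and not part of this PR or of trollimage:

    import os

    class PathAwareTarget:
        """Hypothetical wrapper exposing the underlying file path."""

        def __init__(self, rfile):
            self.rfile = rfile

        def __fspath__(self):
            # Makes os.fspath(targ) return the path of the open file.
            return self.rfile.path

    # usage: os.fspath(PathAwareTarget(rfile)) -> e.g. "/tmp/test-foo.tif"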

        if fn not in ofs:
            ofs[fn] = ([], [])
        ofs[fn][0].append(src)
        ofs[fn][1].append(targ)
    return list(ofs.values())
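
As a quick illustration of the return shape: the function only needs targets exposing .rfile.path, so its behaviour can be shown with stand-in objects (a self-contained sketch; the namedtuples are hypothetical test doubles):

    from collections import namedtuple

    from satpy.writers import group_results_by_output_file

    RFile = namedtuple("RFile", ["path"])      # stand-in for a trollimage file handle
    Target = namedtuple("Target", ["rfile"])   # stand-in for RIODataset/RIOTag

    sources = ["s1", "s2", "s3"]
    targets = [Target(RFile("a.tif")), Target(RFile("a.tif")), Target(RFile("b.tif"))]

    grouped = group_results_by_output_file(sources, targets)
    assert len(grouped) == 2                   # two distinct output files
    assert grouped[0] == (["s1", "s2"], targets[:2])
    assert grouped[1] == (["s3"], targets[2:])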


def compute_writer_results(results):
    """Compute all the given dask graphs `results` so that the files are saved.
