Merge pull request #2281 from gerritholl/split-by-file
group save_datasets result by file
gerritholl committed Feb 2, 2023
2 parents 8a0b65d + 1fe056e commit 864983c
Showing 2 changed files with 113 additions and 0 deletions.
50 changes: 50 additions & 0 deletions satpy/tests/test_writers.py
@@ -18,6 +18,7 @@

from __future__ import annotations

import datetime
import os
import shutil
import unittest
@@ -863,3 +864,52 @@ def test_add_decorate_basic_l(self):
from satpy.writers import add_decorate
new_img = add_decorate(self.orig_l_img, **self.decorate)
self.assertEqual('RGBA', new_img.mode)


def test_group_results_by_output_file(tmp_path):
"""Test grouping results by output file.
Add a test for grouping the results from save_datasets(..., compute=False)
by output file. This is useful if for some reason we want to treat each
output file as a seperate computation (that can still be computed together
later).
"""
from pyresample import create_area_def

from satpy.writers import group_results_by_output_file

from .utils import make_fake_scene
x = 10
fake_area = create_area_def("sargasso", 4326, resolution=1, width=x, height=x, center=(0, 0))
fake_scene = make_fake_scene(
{"dragon_top_height": (dat := xr.DataArray(
dims=("y", "x"),
data=da.arange(x*x).reshape((x, x)))),
"penguin_bottom_height": dat,
"kraken_depth": dat},
daskify=True,
area=fake_area,
common_attrs={"start_time": datetime.datetime(2022, 11, 16, 13, 27)})
# NB: even if compute=False, ``save_datasets`` creates (empty) files
(sources, targets) = fake_scene.save_datasets(
filename=os.fspath(tmp_path / "test-{name}.tif"),
writer="ninjogeotiff",
compress="NONE",
fill_value=0,
compute=False,
ChannelID="x",
DataType="x",
PhysicUnit="K",
PhysicValue="Temperature",
SatelliteNameID="x")

grouped = group_results_by_output_file(sources, targets)

assert len(grouped) == 3
assert len({x.rfile.path for x in grouped[0][1]}) == 1
for x in grouped:
assert len(x[0]) == len(x[1])
assert sources[:5] == grouped[0][0]
assert targets[:5] == grouped[0][1]
assert sources[10:] == grouped[2][0]
assert targets[10:] == grouped[2][1]
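
As a quick illustration of the return shape the test exercises, here is a minimal sketch; the ``SimpleNamespace`` stand-ins are assumptions replacing the real ``RIODataset``/``RIOTag`` targets, which works because the grouping only reads ``target.rfile.path``:

# Minimal sketch of the grouping behaviour; stand-in targets expose only
# the ``rfile.path`` attribute, the one thing the grouping function reads.
from types import SimpleNamespace

from satpy.writers import group_results_by_output_file

targets = [SimpleNamespace(rfile=SimpleNamespace(path=p))
           for p in ("a.tif", "a.tif", "b.tif")]
sources = ["src0", "src1", "src2"]

grouped = group_results_by_output_file(sources, targets)
assert len(grouped) == 2                              # one tuple per output file
assert grouped[0] == (["src0", "src1"], targets[:2])  # everything for a.tif
assert grouped[1] == (["src2"], targets[2:])          # everything for b.tif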
63 changes: 63 additions & 0 deletions satpy/writers/__init__.py
@@ -507,6 +507,69 @@ def flatten(results):
return sources, targets, delayeds


def group_results_by_output_file(sources, targets):
"""Group results by output file.
For writers that return sources and targets for ``compute=False``, split
the results by output file.
When not only the data but also GeoTIFF tags are dask arrays, then
``save_datasets(..., compute=False)``` returns a tuple of flat lists,
where the second list consists of a mixture of ``RIOTag`` and ``RIODataset``
objects (from trollimage). In some cases, we may want to get a seperate
delayed object for each file; for example, if we want to add a wrapper to do
something with the file as soon as it's finished. This function unflattens
the flat lists into a list of (src, target) tuples.
For example, to close files as soon as computation is completed::
>>> @dask.delayed
>>> def closer(obj, targs):
... for targ in targs:
... targ.close()
... return obj
>>> (srcs, targs) = sc.save_datasets(writer="ninjogeotiff", compute=False, **ninjo_tags)
>>> for (src, targ) in group_results_by_output_file(srcs, targs):
... delayed_store = da.store(src, targ, compute=False)
... wrapped_store = closer(delayed_store, targ)
... wrapped.append(wrapped_store)
>>> compute_writer_results(wrapped)
In the wrapper you can do other useful tasks, such as writing a log message
or moving files to a different directory.
.. warning::
Adding a callback may impact runtime and RAM. The pattern or cause is
unclear. Tests with FCI data show that for resampling with high RAM
use (from around 15 GB), runtime increases when a callback is added.
Tests with ABI or low RAM consumption rather show a decrease in runtime.
More information, see `these GitHub comments
<https://github.com/pytroll/satpy/pull/2281#issuecomment-1324910253>`_
Users who find out more are encouraged to contact the Satpy developers
with clues.
Args:
sources: List of sources (typically dask.array) as returned by
:meth:`Scene.save_datasets`.
targets: List of targets (should be ``RIODataset`` or ``RIOTag``) as
returned by :meth:`Scene.save_datasets`.
Returns:
List of ``Tuple(List[sources], List[targets])`` with a length equal to
the number of output files planned to be written by
:meth:`Scene.save_datasets`.
"""
ofs = {}
for (src, targ) in zip(sources, targets):
fn = targ.rfile.path
if fn not in ofs:
ofs[fn] = ([], [])
ofs[fn][0].append(src)
ofs[fn][1].append(targ)
return list(ofs.values())


def compute_writer_results(results):
"""Compute all the given dask graphs `results` so that the files are saved.
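
To tie the pieces together, here is a fuller, non-doctest rendering of the closer pattern from the docstring above. This is a sketch only: ``scn`` and ``ninjo_tags`` are placeholders standing in for an already-loaded Scene and the ninjogeotiff writer keyword arguments, neither of which is defined here.

# Sketch of the per-file closer pattern, assuming ``scn`` is a loaded Scene
# and ``ninjo_tags`` holds the ninjogeotiff writer keyword arguments.
import dask
import dask.array as da

from satpy.writers import compute_writer_results, group_results_by_output_file


@dask.delayed
def closer(obj, targs):
    # Close all targets backing one output file once its store has run.
    for targ in targs:
        targ.close()
    return obj


srcs, targs = scn.save_datasets(writer="ninjogeotiff", compute=False, **ninjo_tags)
wrapped = []
for src, targ in group_results_by_output_file(srcs, targs):
    delayed_store = da.store(src, targ, compute=False)
    # The wrapper could also log a message or move the finished file.
    wrapped.append(closer(delayed_store, targ))
compute_writer_results(wrapped)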
