group save_datasets result by file #2281

Merged: gerritholl merged 7 commits into pytroll:main on Feb 2, 2023

Conversation

@gerritholl (Collaborator) commented Nov 16, 2022

Group the results of Scene.save_datasets(..., compute=False) by output file, when multiple files are to be written and one or more files have multiple RIODataset or RIOTag objects. This is helpful when we want a wrapper per file and therefore a single da.store call for each file, which is in turn needed for pytroll/trollflow2#168.

  • Closes #xxxx
  • Tests added
  • Fully documented

Group the results of save_datasets by output file, when multiple files
are being written and each file has multiple RIODataset or RIOTag
objects.  This is helpful when we want a wrapper per file and therefore
a single ``da.store`` call for each file.
@djhoese (Member) commented Nov 16, 2022

This is an interesting idea. So in your trollflow callback PR you would do the grouping then pass each group to the callback? Biggest potential issue I see with that is if you have a dask array going to more than one file and you compute the groups separately, you'll likely end up computing each input image dask array multiple times. If you pass them all to da.store at the same time then this shouldn't matter. Or if you pass them to da.store separately but compute the Delayed results at the same time then that should be fine too.

@gerritholl (Collaborator, Author) commented Nov 16, 2022

The ones that go to the same file go to da.store at the same time. I still send everything to da.compute at the same time. The difference would be:

Now:

obj = da.store(sources_for_all_files, targets_for_all_files, compute=False)
da.compute(obj)

Later, optionally:

obj1 = da.store(sources_for_file_1, targets_for_file_1, compute=False)
obj2 = da.store(sources_for_file_2, targets_for_file_2, compute=False)
obj3 = da.store(sources_for_file_3, targets_for_file_3, compute=False)
da.compute([obj1, obj2, obj3])

I don't know what that would mean for performance, but I need approach 2 so that I can make a wrapper that does something as soon as an individual file is completed. At least I can't think of another way to achieve that. The wrapper would encompass obj1, obj2, and obj3 in the second example. If I wrap obj in the first example, it is still called only when all files are completed, which means I have gained nothing compared to the status quo.
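
Roughly, the per-file wrapper I have in mind would look like this (a standalone sketch with numpy stand-ins for the real writer targets; on_file_done is a placeholder for the real callback):

import dask
import dask.array as da
import numpy as np
from dask import delayed

def on_file_done(store_result, name):
    """Placeholder callback: runs as soon as one file's store tasks are done."""
    print("finished", name)
    return store_result

# Stand-ins for per-file sources/targets (numpy arrays instead of RIODataset).
files = {f"file{i}.tif": (da.random.random((4,), chunks=2), np.zeros(4))
         for i in range(3)}

objs = []
for name, (src, targ) in files.items():
    store_obj = da.store(src, targ, compute=False)   # one da.store per file
    objs.append(delayed(on_file_done)(store_obj, name))

# Everything still goes to dask.compute in a single call.
dask.compute(objs)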

@djhoese (Member) commented Nov 16, 2022

Approach 2 should be fine. As long as dask is given all of the dask graphs at the same time it will be able to optimize things as necessary (ex. "this file output uses 'C01' and so does this file, I'll only compute 'C01' once").
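
For illustration, a generic sketch (not this PR's code) of that sharing: two results built from the same input and computed together, so the shared chunks are only computed once.

import dask
import dask.array as da

# Shared input, analogous to a channel like "C01" feeding two output files.
a = da.random.random((1000,), chunks=100)
x = (a + 1).sum()
y = (a * 2).sum()

# Handing both graphs to dask in one call lets it deduplicate the chunks of `a`.
dask.compute(x, y)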

Add an implementation to split sources/targets by file.  This helps if
we want a wrapper function to execute immediately when each file is
completed once computation happens.
@gerritholl marked this pull request as ready for review on November 16, 2022 17:20
codecov bot commented Nov 16, 2022

Codecov Report

Merging #2281 (1fe056e) into main (bc7dea3) will increase coverage by 0.00%.
The diff coverage is 100.00%.

@@           Coverage Diff           @@
##             main    #2281   +/-   ##
=======================================
  Coverage   94.58%   94.58%           
=======================================
  Files         314      314           
  Lines       47511    47538   +27     
=======================================
+ Hits        44936    44963   +27     
  Misses       2575     2575           
Flag Coverage Δ
behaviourtests 4.50% <3.70%> (-0.01%) ⬇️
unittests 95.21% <100.00%> (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
satpy/tests/test_writers.py 99.04% <100.00%> (+0.04%) ⬆️
satpy/writers/__init__.py 90.59% <100.00%> (+0.18%) ⬆️


Adapt call to zip for compatibility with older versions of Python
coveralls commented Nov 17, 2022

Coverage Status

Coverage increased (+0.003%) to 95.163% when pulling 1fe056e on gerritholl:split-by-file into bc7dea3 on pytroll:main.

gerritholl added a commit to gerritholl/trollflow2 that referenced this pull request Nov 17, 2022
Apply the callbacks once per set of targets sharing a file.
This requires pytroll/satpy#2281
@gerritholl (Collaborator, Author) commented:
pre-commit did not complain when I did a commit locally...

@djhoese (Member) left a comment:

An example usage in the docstring might be helpful. I think you can do an Examples: section, but getting Sphinx to render the formatting correctly is always confusing.

The reason I ask for an example is that I'm curious whether this should be used after compute_writer_results, before it, or whether they are completely separate.
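
For example, something along these lines (a rough sketch of what the Examples section could show, not necessarily the final docstring; it uses a small fake scene and the geotiff writer):

import datetime
import dask.array as da
import numpy as np
import xarray as xr
from pyresample import create_area_def
from satpy.tests.utils import make_fake_scene
from satpy.writers import group_results_by_output_file

area = create_area_def("test", 4326, resolution=1, width=10, height=10, center=(0, 0))
scn = make_fake_scene(
    {name: xr.DataArray(dims=("y", "x"),
                        data=np.linspace(200, 300, 100).reshape((10, 10)))
     for name in ("band_a", "band_b")},
    daskify=True, area=area,
    common_attrs={"start_time": datetime.datetime(2022, 11, 16, 12)})

(sources, targets) = scn.save_datasets(
    writer="geotiff", filename="test-{name}.tif", compute=False, fill_value=0)

delayeds = []
for (srcs, targs) in group_results_by_output_file(sources, targets):
    # One da.store per output file, but still a single dask.compute at the end.
    delayeds.append(da.store(srcs, targs, compute=False))
da.compute(delayeds)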

@gerritholl (Collaborator, Author) commented:
I did some performance tests comparing computing da.store(src, targ, compute=False) against [da.store(s, t, compute=False) for (s, t) in zip(src, targ)], because in the context of the linked trollflow2 issue I see a severe performance impact. In this test, however, the penalty of having more da.store delayed objects to compute is small.

import os
import dask.array as da
import dask
from dask.diagnostics import Profiler, ResourceProfiler, visualize
from sattools.io import plotdir
from glob import glob
from satpy import Scene
from dask import delayed

seviri_files = glob("/media/nas/x21308/scratch/SEVIRI/202103300900/H-000*")
sc = Scene(filenames={"seviri_l1b_hrit": seviri_files})
names = sc.available_dataset_names()
sc.load(names)

(src, targ) = sc.save_datasets(
    writer="geotiff",
    filename=os.fspath(plotdir() / "test-{name}.tif"),
    fill_value=0,
    compute=False)
#delayeds = da.store(src, targ, compute=False)
delayeds = [da.store(s, t, compute=False) for (s, t) in zip(src, targ)]
with Profiler() as prof, ResourceProfiler(dt=0.05) as rprof:
    dask.compute(delayeds)
visualize([prof, rprof], show=False, save=True, filename=os.fspath(plotdir() /
    f"dask-profile-many-store.html"))

Using just da.store(src, targ, compute=False):

image

Using [da.store(s, t, compute=False) for (s, t) in zip(src, targ)], RAM increases from 1.38 GB to 1.45 GB, with wall clock time increasing from 0:13.21 to 0:13.53.

image

@djhoese (Member) commented Nov 17, 2022

A couple things come to mind:

  1. I think you should find a larger test case. One that takes at least 30 seconds if not closer to multiple minutes. One that only takes a few seconds feels like it would be heavily influenced by disk/OS caches.
  2. There was (or still exists) a bug in dask.array.store where some cases aren't as optimized as they should be. I believe I filed a bug with dask or contributed to an existing one, but I don't remember if it was ever resolved. It may be one where I was expected to fix it but I didn't understand the problem enough to fix it.

@gerritholl (Collaborator, Author) commented:
@djhoese This one? dask/dask#8380

@gerritholl (Collaborator, Author) commented Nov 18, 2022

Loading FCI with 4 composites and 3 channels, then resampling to an equirectangular 2km-grid containing the full FCI field of view. Calling da.store(..., compute=False) once for each output file leads to an increase in RAM by 19%, but no increase in runtime. The test machine has 8 CPUs and 128 GB RAM. There is no additional increase of resources caused by adding a wrapper function.

With a single da.store(..., compute=False) call passed to dask.compute:

3:16.47, 33.6 GB

image

dask-profile-no-wrap-single-store.html.gz

With multiple da.store(..., compute=False) calls passed to dask.compute, no wrapper:

3:18.61, 40.1 GB

dask-profile-no-wrap-multi-store.html.gz

image

With multiple da.store(..., compute=False) calls passed to dask.compute, and a noop wrapper:

3:17.38, 39.8 GB

image

dask-profile-wrap-multi-store.html.gz

Timings performed with:

import hdf5plugin
import os
import dask.array as da
import dask
from dask.diagnostics import Profiler, ResourceProfiler, visualize
from glob import glob
from satpy import Scene
from dask import delayed

fci_files = glob("/media/x21308/MTG_test_data/2022_05_MTG_Testdata/RC0042/*BODY*.nc")
sc = Scene(filenames={"fci_l1c_nc": fci_files})
names = ["natural_color", "true_color", "airmass", "cimss_cloud_type",
        "ir_105", "vis_04", "vis_06"]
sc.load(names)

def noop(obj):
    """Do nothing."""
    return obj

ls = sc.resample("nq0002km")
(src, targ) = ls.save_datasets(
    writer="geotiff",
    filename="/media/x21308/scratch/dask-tests/{start_time:%Y%m%d%H%M}-{platform_name}-{sensor}-{area.area_id}-{name}-dask-test.tif",
    fill_value=0,
    compute=False)
#delayeds = da.store(src, targ, compute=False)
delayeds = [da.store(s, t, compute=False) for (s, t) in zip(src, targ)]
#delayeds = [delayed(noop)(da.store(s, t, compute=False)) for (s, t) in zip(src, targ)]
with Profiler() as prof, ResourceProfiler(dt=0.05) as rprof:
    dask.compute(delayeds)
visualize([prof, rprof], show=False, save=True, filename="/tmp/dask-profile-no-wrap-multi-store.html")

and then called with

\time -v python dask-delay-wrapper-resource-problem-large.py

@gerritholl (Collaborator, Author) commented:
When I increase this to 7 composites and 6 channels:

  • single da.store call: 4:11.41, 38.1 GB
  • multiple da.store calls: 4:21.06, 43.2 GB

so the differences are relatively smaller.

@gerritholl (Collaborator, Author) commented:
When storing 7 composites and 6 channels with the ninjogeotiff writer, which also produces some dask-based tags (RIOTag), we get:

  • single call: 3:26.81, 35.2 GB
  • one call per target: 5:30.63, 37.4 GB
  • one call per output file (using the splitter): 3:36.22, 38.7 GB

I'm getting inconsistent results and remain confused.

@gerritholl (Collaborator, Author) commented:
If I wrap the da.store(..., compute=False) calls in a with dask.config.set({"optimization.fuse.active": False}) block, as suggested by @djhoese in dask/dask#8380 (comment), RAM goes up even more. The one-call-per-target variant is faster at 4:08.82, but uses 47.6 GB RAM.

But maybe that's not at all surprising, considering I have no idea what I'm doing ;)

@djhoese (Member) commented Nov 18, 2022

That is the dask issue I was thinking of. The idea is that when dask computes one or more graphs, it is able to say "this task and this task are the same, let's compute them once and share the result with the future tasks that need it". In the da.store(compute=False); da.store(compute=False); da.compute case, the store functions return optimized dask graphs. This means some of the tasks may be "fused". So if you had a task that did b = a + 1 followed by c = b * 2, dask might be smart enough to say "there's no reason to keep these as separate tasks, that just adds overhead, let's combine them and run them as one task (function) that does both operations".

The problem with this type of optimization in this store/store/compute case is that the compute can no longer tell that one fused task is the same as another fused task (I think the fused tasks get new unique names) or the graph is optimized differently.
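
As a generic illustration of fusion (not this PR's code), comparing the number of tasks with fusion on and off:

import dask
import dask.array as da

a = da.random.random((100,), chunks=50)
c = (a + 1) * 2   # the add and the multiply are candidates for being fused

print(len(dict(c.__dask_graph__())))           # tasks in the raw graph

(opt,) = dask.optimize(c)                      # default optimization, fusion on
print(len(dict(opt.__dask_graph__())))         # typically fewer, renamed tasks

with dask.config.set({"optimization.fuse.active": False}):
    (opt_nofuse,) = dask.optimize(c)           # same optimization, fusion off
print(len(dict(opt_nofuse.__dask_graph__())))  # per-chunk tasks keep their keys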

I'm not sure I've kept track of which things are working for you and which things aren't between this PR and the trollflow one. I wonder if you/we could make an example with only Delayed functions that include print statements in them so it was very clear when they are being run multiple times. I may need to generate some dask visualize SVGs myself.

@gerritholl (Collaborator, Author) commented Nov 18, 2022

I'm not sure I've kept track of which things are working for you and which things aren't between this PR and the trollflow one.

It's functional, but I think I'm seeing an increase in RAM, runtime, and dask-graph complexity when I split the sources/targets across multiple da.store(compute=False) calls. I'm writing mainly because I find this difficult to debug: the performance differences are not obvious in short examples, and trial and error with long-running examples is slow, especially since I'm not sure what I'm doing, so I'm more or less trying things at random and then waiting 3–5 minutes for each result.

I wonder if you/we could make an example with only Delayed functions that include print statements in them so it was very clear when they are being run multiple times.

I will try. Thanks for your help. Worst that can reasonably happen is that I learn to understand dask better ☺

@djhoese (Member) commented Nov 18, 2022

If you can post the dask graphs for some smaller examples, that may help nail down why they differ between the split and unsplit cases, and with or without a callback function.

@gerritholl (Collaborator, Author) commented Nov 18, 2022

I tried to make a small example based on a fake scene with three 10×10 datasets, but now I don't see much difference between the dask graphs. For single, there is one empty rectangle next to the bottom box that has (0, 0, 0). For multi, there are three empty rectangles.

import xarray as xr
import dask.array as da
from pyresample import create_area_def
from satpy.tests.utils import make_fake_scene
from satpy.writers import group_results_by_output_file
import numpy as np
import dask
import datetime
import os
from sattools.io import plotdir

mode = "multi"

x = 10
fake_area = create_area_def("sargasso", 4326, resolution=1, width=x, height=x, center=(0, 0))
fake_scene = make_fake_scene(
    {k: xr.DataArray(
        dims=("y", "x"),
        data=np.linspace(200, 300, x*x).reshape((x, x)))
        for k in ("dragon_top_height", "penguin_bottom_height", "kraken_depth")},
    daskify=True,
    area=fake_area,
    common_attrs={"start_time": datetime.datetime(2022, 11, 18, 18)})

objs = []
fn = os.fspath(plotdir() / "test-{name}.tif")
(srcs, targs) = fake_scene.save_datasets(
        writer="ninjogeotiff", filename=fn, compute=False, fill_value=0,
        ChannelID="x", DataType="x", PhysicUnit="K", PhysicValue="Temperature",
        SatelliteNameID="x")
if mode == "single":
    objs = da.store(srcs, targs, compute=False)
else:
    for (src, targ) in group_results_by_output_file(srcs, targs):
        objs.append(da.store(src, targ, compute=False))
#da.compute(objs)
dask.visualize(objs, filename=os.fspath(plotdir() / f"dask-graph-smallish-{mode:s}-store.svg"))

Result for mode=="single":

dask-graph-smallish-single-store

Result for mode=="multi":

dask-graph-smallish-multi-store

@gerritholl (Collaborator, Author) commented:
When I replace objs.append(da.store(src, targ, compute=False)) with objs.append(delayed(noop)(da.store(src, targ, compute=False))), the noop wrapper is displayed disconnected from the rest of the graph:

dask-graph-smallish-callback-store

@gerritholl (Collaborator, Author) commented:
I made some versions with enhance=False, to make the dask graphs simpler.

Single store:

dask-graph-smallish-single-store

Multiple store:

dask-graph-smallish-multi-store

Multiple store with callback:

dask-graph-smallish-callback-store

@gerritholl (Collaborator, Author) commented:
I don't know why noop is disconnected. In my experiment with trollflow2, it's at the top of the graph, where I would expect it.

@gerritholl (Collaborator, Author) commented:
In my trollflow2 case I do get a big difference when I add a noop wrapper, see pytroll/trollflow2#168 (comment)

@gerritholl (Collaborator, Author) commented Nov 18, 2022

New clue: when the callback gets passed the source and the target, the dask graph starts to look very different.

import xarray as xr
import dask.array as da
from pyresample import create_area_def
from satpy.tests.utils import make_fake_scene
from satpy.writers import group_results_by_output_file
import numpy as np
import dask
import datetime
import os
from sattools.io import plotdir
from dask import delayed
from dask.graph_manipulation import bind

mode = "multi"

x = 10
fake_area = create_area_def("sargasso", 4326, resolution=1, width=x, height=x, center=(0, 0))
fake_scene = make_fake_scene(
    {k: xr.DataArray(
        dims=("y", "x"),
        data=np.linspace(200, 300, x*x).reshape((x, x)))
        for k in ("dragon_top_height", "penguin_bottom_height", "kraken_depth")},
    daskify=True,
    area=fake_area,
    common_attrs={"start_time": datetime.datetime(2022, 11, 18, 18)})

def noop(obj, src, targ):
    print(obj)
    return obj

objs = []
fn = os.fspath(plotdir() / "test-{name}.tif")
(srcs, targs) = fake_scene.save_datasets(
        writer="ninjogeotiff", filename=fn, compute=False, fill_value=0,
        ChannelID="x", DataType="x", PhysicUnit="K", PhysicValue="Temperature",
        SatelliteNameID="x", enhance=False)
if mode == "single":
    objs = da.store(srcs, targs, compute=False)
else:
    for (src, targ) in group_results_by_output_file(srcs, targs):
        if mode == "callback":
            #objs.append(bind(delayed(noop), [da.store(src, targ, compute=False)]))
            objs.append(delayed(noop)(da.store(src, targ, compute=False), src, targ))
        else:
            objs.append(da.store(src, targ, compute=False))
da.compute(objs)
dask.visualize(objs, filename=os.fspath(plotdir() / f"dask-graph-smallish-{mode:s}-store.svg"))

Multiple calls to da.store, no callback:

dask-graph-smallish-multi-store

Multiple calls to da.store, callback that also gets passed src and targ:

dask-graph-smallish-callback-store

The culprit appears to be src. I can pass targ without it messing up the task graph, but if I pass src to the wrapper then things change (probably for the worse).

@djhoese (Member) commented Nov 18, 2022

Some guesses and other comments:

  1. Be careful comparing the graphs of things returned from store with the separate return values from store that you would have passed to compute. compute is going to optimize the graph between the 3 for you (or at least it is supposed to). This will also be where the bad things will happen related to that dask issue I filed.
  2. In your most recent comment with the simplified (enhance=False) graph and passing src, the one with the callback looks like the scaling XRImage does to go from in-memory scaled data to file-format-friendly scaled data (like converting 32-bit floats to uint8 for an 8-bit geotiff). For some reason dask decided not to "fuse" them into one task.
  3. In your most recent comment's code you actually call da.compute. I wonder if this is modifying one or more of the graphs in place. I would be very surprised by that, but I'm surprised at how different the graphs are too.
  4. I'm trying to look at the SVG to see if the disconnected nodes are just a shortcut by the browser rendering the SVG, saying "yes, I know it should be connected at the top, but this is prettier".

Edit: Ok so those nodes are really just not connected.

@djhoese (Member) commented Nov 18, 2022

Ok I played around with some stuff, but I'm not sure I learned much. My code looks something like this:

import dask.array as da
import dask
import numpy as np
from dask import delayed

a = da.random.random((5,))
b = da.random.random((5,))
a2 = a + 5
b2 = b + 5

dst_a = np.zeros((5,))
dst_b = np.zeros((5,))

a2_store = da.store(a2, dst_a, compute=False)
b2_store = da.store(b2, dst_b, compute=False)

a2b2_store = da.store([a2, b2], [dst_a, dst_b], compute=False)

pdelayed = delayed(print)
a2_cb_res = pdelayed(a2_store)
b2_cb_res = pdelayed(b2_store)
a2b2_cb_res = pdelayed(a2b2_store)

dask.visualize(a2_store, filename="a2_store.svg")
dask.visualize(b2_store, filename="b2_store.svg")
dask.visualize(a2b2_store, filename="a2b2_store.svg")
dask.visualize(a2_cb_res, filename="a2_cb_res.svg")
dask.visualize(b2_cb_res, filename="b2_cb_res.svg")
dask.visualize(a2b2_cb_res, filename="a2b2_cb_res.svg")

with dask.config.set({"optimization.fuse.active": False}):
    a2_store_nofuse = da.store(a2, dst_a, compute=False)
dask.visualize(a2_store_nofuse, filename="a2_store_nofuse.svg")

And some of the graphs:

a2_store.svg

a2_store

a2_store_nofuse.svg

a2_store_nofuse

a2b2_store.svg

a2b2_store

a2_cb_res.svg

a2_cb_res

And after looking at the source code for da.store again, I think I have some ideas of what we're seeing in the graphs.

I think the separate blocks from the rest of the graph are the "targets" of the store operation. In my example code these are numpy arrays, but they still need to be turned into a "task" to be able to get written.

Note you can use my_dask_arr.__dask_graph__() to get a HighLevelGraph object which has a __str__/__repr__ that shows the keys/nodes in the graph. You can also look at .dependencies to see a dictionary mapping node to the nodes that it depends on. From what I see dask wraps all of the source logic into a store-sources node which then goes to store-map which I think combines the sources and the targets, then for compute=False it goes to a final store-XXX task which I think is just a way to represent all of the store operations as a single Delayed object.
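
For instance, a generic sketch (with plain numpy targets, not the satpy case) of poking at such a graph:

import dask.array as da
import numpy as np

src = da.random.random((4,), chunks=2) + 1
dst = np.zeros((4,))
stored = da.store(src, dst, compute=False)

hlg = stored.__dask_graph__()           # HighLevelGraph behind the Delayed object
print(hlg)                              # lists the layers/keys (including the store-* ones)
for layer, deps in hlg.dependencies.items():
    print(layer, "depends on", deps)    # which layers feed which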

@gerritholl (Collaborator, Author) commented:
When I use the synchronous scheduler, adding a callback that closes each file as soon as it's finished makes a test script run 14 seconds or around 20% faster. Even adding a callback that does nothing makes it run 2 seconds or around 3% faster.

import hdf5plugin
import os
import dask
from satpy import Scene
import pathlib
from dask import delayed
from satpy.writers import group_results_by_output_file
import dask.array as da

mode = "direct"

def noop(obj, targ):
    """Do nothing."""
    return obj

def close(obj, targs):
    """Close target."""
    for targ in targs:
        targ.close()
    return obj

names = ['vis_04', 'vis_05', 'vis_06', 'vis_08', 'vis_09', 'nir_13', 'nir_16',
        'nir_22', 'ir_38', 'wv_63', 'wv_73', 'ir_87', 'ir_97', 'ir_105', 'ir_123',
        'ir_133']

fci_dir = pathlib.Path("/media/nas/x21308/MTG_test_data/2022_05_MTG_Testdata/RC0099/")

fci_files = [fci_dir / x for x in
    [
     'W_XX-EUMETSAT-Darmstadt,IMG+SAT,MTI1+FCI-1C-RRAD-FDHSI-FD--CHK-BODY---NC4E_C_EUMT_20170920163256_GTT_DEV_20170920162711_20170920162756_N_JLS_T_0099_0031.nc',
     'W_XX-EUMETSAT-Darmstadt,IMG+SAT,MTI1+FCI-1C-RRAD-FDHSI-FD--CHK-BODY---NC4E_C_EUMT_20170920163310_GTT_DEV_20170920162728_20170920162810_N_JLS_T_0099_0032.nc',
     'W_XX-EUMETSAT-Darmstadt,IMG+SAT,MTI1+FCI-1C-RRAD-FDHSI-FD--CHK-BODY---NC4E_C_EUMT_20170920163325_GTT_DEV_20170920162743_20170920162825_N_JLS_T_0099_0033.nc',
     'W_XX-EUMETSAT-Darmstadt,IMG+SAT,MTI1+FCI-1C-RRAD-FDHSI-FD--CHK-BODY---NC4E_C_EUMT_20170920163340_GTT_DEV_20170920162800_20170920162840_N_JLS_T_0099_0034.nc',
     'W_XX-EUMETSAT-Darmstadt,IMG+SAT,MTI1+FCI-1C-RRAD-FDHSI-FD--CHK-BODY---NC4E_C_EUMT_20170920163344_GTT_DEV_20170920162803_20170920162844_N_JLS_T_0099_0035.nc',
     'W_XX-EUMETSAT-Darmstadt,IMG+SAT,MTI1+FCI-1C-RRAD-FDHSI-FD--CHK-BODY---NC4E_C_EUMT_20170920163357_GTT_DEV_20170920162819_20170920162857_N_JLS_T_0099_0036.nc',
     'W_XX-EUMETSAT-Darmstadt,IMG+SAT,MTI1+FCI-1C-RRAD-FDHSI-FD--CHK-BODY---NC4E_C_EUMT_20170920163407_GTT_DEV_20170920162835_20170920162907_N_JLS_T_0099_0037.nc',
     'W_XX-EUMETSAT-Darmstadt,IMG+SAT,MTI1+FCI-1C-RRAD-FDHSI-FD--CHK-BODY---NC4E_C_EUMT_20170920163411_GTT_DEV_20170920162849_20170920162911_N_JLS_T_0099_0038.nc',
     'W_XX-EUMETSAT-Darmstadt,IMG+SAT,MTI1+FCI-1C-RRAD-FDHSI-FD--CHK-BODY---NC4E_C_EUMT_20170920163421_GTT_DEV_20170920162900_20170920162921_N_JLS_T_0099_0039.nc',
     ]]

def main():
    sc = Scene(filenames={"fci_l1c_nc": [os.fspath(f) for f in fci_files]})
    sc.load(names)
    (srcs, targs) = sc.save_datasets(
            writer="geotiff",
            enhance=False,
            compute=False)
    if mode == "direct":
        delayeds = [da.store(s, t, compute=False) for (s, t) in group_results_by_output_file(srcs, targs)]
    elif mode == "noop":
        delayeds = [delayed(noop)(da.store(s, t, compute=False), t) for (s, t) in group_results_by_output_file(srcs, targs)]
    elif mode == "close":
        delayeds = [delayed(close)(da.store(s, t, compute=False), t) for (s, t) in group_results_by_output_file(srcs, targs)]

    dask.compute(delayeds)

if __name__ == "__main__":
    with dask.config.set(scheduler="synchronous"):
        main()

With mode="direct": 1:09, 2.80 GB RAM
With mode="close": 0:56, 2.65 GB RAM
With mode="noop": 1:07, 2.80 GB RAM

@gerritholl (Collaborator, Author) commented Nov 23, 2022

Comparison:

| Enhancing | Resampling | DWD channels | Composites | scheduler | writer | time baseline | RAM baseline | time callback | RAM callback |
|---|---|---|---|---|---|---|---|---|---|
| True | eqc 1km/2km FD | True | 7 | threads | ninjogeotiff | 08:56.4 | 83.7 GB | 11:32.9 | 64.6 GB |
| True | eqc 3km FD | True | 3 | threads | geotiff | 01:48.5 | 15.4 GB | 02:00.0 | 15.4 GB |
| True | False | True | 3 | threads | geotiff | 02:04.4 | 6.51 GB | 01:53.9 | 8.21 GB |
| True | eqc 3km FD | False | 3 | threads | geotiff | 01:26.6 | 15.4 GB | 01:39.2 | 15.4 GB |
| True | False | False | 3 | threads | geotiff | 01:26.3 | 8.74 GB | 01:16.1 | 7.67 GB |
| True | False | False | False | threads | geotiff | 01:10.4 | 6.19 GB | 01:07.3 | 5.20 GB |
| True | native | True | 3 | threads | geotiff | 02:38.9 | 13.0 GB | 02:38.5 | 11.8 GB |
| True | eqc 2km Europe | True | 3 | threads | geotiff | 01:09.3 | 8.0 GB | 1:05.0 | 6.9 GB |
| True | eqc 3km Central Europe | True | 3 | threads | geotiff | 0:17.6 | 3.3 GB | 0:17.6 | 3.1 GB |
| True | eqc 3km Central Europe | True | 3 | threads | ninjogeotiff | 0:17.9 | 3.3 GB | 0:18.0 | 3.1 GB |
| False | eqc 3 km FD | False | False | threads | geotiff | 0:55.7 | 15.3 GB | 0:54.4 | 15.3 GB |
| False | False | False | False | threads | geotiff | 01:25.0 | 6.47 GB | 01:17.2 | 6.33 GB |
| False | eqc 3km Central Europe | False | False | threads | geotiff | 0:12.2 | 1.73 GB | 0:12.1 | 1.69 GB |
| True | eqc 3 km Central Europe | False | False | threads | geotiff | 0:12.9 | 1.85 GB | 0:12.7 | 1.83 GB |
| True | eqc 3 km Central Europe | False | False | threads | ninjogeotiff | 0:12.8 | 1.66 GB | 0:13.0 | 1.66 GB |
| False | False | False | False | synchronous | geotiff | 01:52.0 | 5.88 GB | 01:25.8 | 5.43 GB |
| True | africa | False | True | threads | geotiff | 02:45.3 | 20.0 GB | 02:26.5 | 15.7 GB |

This reads full disc FCI test data and writes at least all channels in various configurations.

  • Callback: a single callback that closes the targets, nothing else
  • DWD channels: if true, replace the single channels (with default stretching) by single-channel composites whose enhancements apply a crude stretch and invert the IR imagery
  • Composites: in addition to all FCI channels, also do airmass, dust, and ash RGBs

The first line represents the experimental processing with all products.

So, it seems that adding the callback makes things slower only when RAM usage is high, unless we're resampling to africa. Without resampling, adding the callback may even make things faster.

@gerritholl (Collaborator, Author) commented:
When I replace

        sc = Scene(filenames={"fci_l1c_nc": [os.fspath(f) for f in fci_files]})
        sc.load(names, upper_right_corner="NE", pad_data=False)

by artificial data

        ar = satpy.resample.get_area_def("mtg_fci_fdss_2km")
        sc = make_fake_scene(
            {f"arr{x:d}": xr.DataArray(
                dims=("y", "x"),
                data=da.linspace(180+x, 210-x, ar.size).reshape(ar.shape) +
                     da.random.random(ar.shape))
                for x in range(15)},
            daskify=True,
            area=ar,
            common_attrs={
                "start_time": datetime.datetime(2022, 11, 18, 18),
                "units": "K"})

I cannot reproduce the slowdown due to the callback. Will test with other readers...

Add an example on how to use the utility function
group_results_by_output_file.  Also add a warning that for large
calculations, this appears to cause a slowdown.
@gerritholl (Collaborator, Author) commented:
I tried with ABI, and there I only find a decrease in runtime:

| Enhancing | Resampling | DWD channels | Composites | scheduler | writer | time baseline | RAM baseline | time callback | RAM callback |
|---|---|---|---|---|---|---|---|---|---|
| True | eqc 2km FD | False | 3 | threads | geotiff | 05:51.9 | 73.1 GB | 05:24.2 | 63.3 GB |
| True | eqc 3km FD | False | 3 | threads | geotiff | 03:49.9 | 43.6 GB | 03:36.6 | 43.4 GB |

@gerritholl (Collaborator, Author) commented:
For the failing unstable test, see #2297 for a fix.

@gerritholl (Collaborator, Author) commented:
@djhoese I'm giving up on digging into this one any deeper. When the same action makes processing ABI faster but processing FCI slower, but only if also resampling, I am beaten. The PR works and I've added a warning that it might make processing slower or faster or neither.

@djhoese added the labels "enhancement" (code enhancements, features, improvements) and "component:writers" on Nov 24, 2022
@djhoese (Member) commented Nov 24, 2022

Since this isn't being used automatically by anything and is only going to be used in trollflow for now (which I don't use), I'm ok with it 😉

"""
ofs = {}
for (src, targ) in zip(sources, targets):
fn = targ.rfile.path
@djhoese (Member) commented on the diff:
So this rfile is an XRImage thing only, right? Should we maybe add something to this object in trollimage so you can do .path on the target or str(targ) to get the path?

@gerritholl (Collaborator, Author):
Would fspath be appropriate?

@djhoese (Member):
🤷‍♂️ I don't use it often, but maybe? I guess it depends if other popular output writing libraries have support for it (netcdf4-python, rasterio, PIL, etc).

@gerritholl (Collaborator, Author):
On the other hand, fspath is rather for objects that represent a filesystem path, not for an object that represents an open file. I don't know if there exists a standard way to get the path corresponding to an open file.
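
For illustration, a generic sketch (not trollimage's actual API) of what supporting os.fspath via the __fspath__ protocol would look like, next to a plain .path attribute:

import os

class FakeTarget:
    """Hypothetical file-like target, for illustration only."""

    def __init__(self, path):
        self.path = path          # plain attribute, like targ.rfile.path

    def __fspath__(self):
        return self.path          # lets os.fspath(target) return the path

targ = FakeTarget("/tmp/out.tif")
print(os.fspath(targ))            # /tmp/out.tif
print(targ.path)                  # same path via the attribute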

@gerritholl (Collaborator, Author) commented:
I made a test using dask/dask#9732 (via the trollflow2 interface), with or without callbacks.

The good news: with the branch bugfix-store-nocompute-optimize by @djhoese, adding a callback barely increases runtime compared to baseline, and the callback version is faster than for dask 2022.12.0

The bad news: with bugfix-store-nocompute-optimize, the baseline is slower than with dask 2022.12.0.

| dask version | time baseline | RAM baseline | time callback | RAM callback |
|---|---|---|---|---|
| 2022.12.0 | 09:18.68 | 83.8 GB | 11:55.37 | 68.2 GB |
| bugfix-store-nocompute-optimize | 10:14.39 | 84.3 GB | 10:30.81 | 65.8 GB |

All of the above: with enhancing, resampling FCI test data to full disc equirectangular 1 km / 2 km (depending on channel and composite) using custom DWD channels and 7 composites, threaded scheduler, writing ninjogeotiff.

@djhoese (Member) commented Dec 16, 2022

@gerritholl which code example are you considering your "baseline"? This one:

#2281 (comment)

@gerritholl (Collaborator, Author) commented Dec 16, 2022

By now I've been trying a lot of different combinations, so my minimal script is not at all minimal anymore :P

To produce the tables in the previous comments I have used the following script, where the baseline has mode=="direct"

import hdf5plugin
import os
import dask
from satpy import Scene
import pathlib
from dask import delayed
from satpy.writers import group_results_by_output_file
import dask.array as da
import satpy.resample
from satpy.tests.utils import make_fake_scene
import dask.array as da
import xarray as xr
import datetime

#mode = "direct"
mode = "close"
scheduler = "threads"
enhance = True
#resampling = "northamerica"
#resampling = "nqeuro3km"
resampling = "nq0003km"
dwdchans = False
comps = True
writer = "geotiff"
datamode = "real"
#sensor = "abi_l1b"
sensor = "fci_l1c_nc"

fci_dir = pathlib.Path("/media/nas/x21308/MTG_test_data/2022_05_MTG_Testdata/RC0099/")
abi_dir = pathlib.Path("/media/nas/x21308/abi/G17/F/")

args_ninjogeotiff = dict(ChannelID="x", DataType="x", PhysicUnit="K",
        PhysicValue="Temperature", SatelliteNameID="x")

def noop(obj, targ):
    """Do nothing."""
    return obj

def close(obj, targs):
    """Close target."""
    for targ in targs:
        targ.close()
    return obj

if sensor == "fci_l1c_nc":
    if dwdchans:
        names = ['dwd_vis04', 'dwd_vis05', 'dwd_vis06', 'dwd_nir08', 'dwd_nir09',
                'dwd_nir13', 'dwd_nir16', 'dwd_nir22', 'dwd_ir38',
                'dwd_wv63', 'dwd_wv73', 'dwd_ir87', 'dwd_ir97', 'dwd_ir105',
                'dwd_ir123', 'dwd_ir133']
    else:
        names = ['vis_04', 'vis_05', 'vis_06', 'vis_08', 'vis_09', 'nir_13', 'nir_16',
                'nir_22', 'ir_38', 'wv_63', 'wv_73', 'ir_87', 'ir_97', 'ir_105', 'ir_123',
                'ir_133']
    files = sorted(fci_dir.glob("*BODY*.nc"))

elif sensor == "abi_l1b":
    if dwdchans:
        raise NotImplementedError()
    else:
        names = [f"C{c:>02d}" for c in range(1, 17)]
    files = sorted(abi_dir.glob("OR_ABI-L1b-RadF-M6C*_G17_s202110000503*_e*_c*.nc"))
if comps:
    names.extend(["airmass", "dust", "ash"])

def get_scene(mode="real"):
    if mode == "real":
        sc = Scene(filenames={sensor: [os.fspath(f) for f in files]})
        sc.load(names, upper_right_corner="NE", pad_data=False)
    elif mode == "fake":
        ar = satpy.resample.get_area_def("mtg_fci_fdss_2km")
        sc = make_fake_scene(
            {f"arr{x:d}": xr.DataArray(
                dims=("y", "x"),
                data=da.linspace(180+x, 210-x, ar.size).reshape(ar.shape) +
                     da.random.random(ar.shape))
                for x in range(15)},
            daskify=True,
            area=ar,
            common_attrs={
                "start_time": datetime.datetime(2022, 11, 18, 18),
                "units": "K"})
    return sc

def main():
    sc = get_scene(datamode)
    if resampling:
        if resampling == "native":
            ls = sc.resample(resampler="native")
        else:
            ls = sc.resample(resampling)
    else:
        ls = sc
    (srcs, targs) = ls.save_datasets(writer=writer, enhance=enhance, compute=False, **args_ninjogeotiff)
    if mode == "direct":
        delayeds = [da.store(s, t, compute=False) for (s, t) in group_results_by_output_file(srcs, targs)]
    elif mode == "noop":
        delayeds = [delayed(noop)(da.store(s, t, compute=False), t) for (s, t) in group_results_by_output_file(srcs, targs)]
    elif mode == "close":
        delayeds = [delayed(close)(da.store(s, t, compute=False), t) for (s, t) in group_results_by_output_file(srcs, targs)]

    print("computing with mode", mode)
    print("enhance", enhance)
    print("resampling", resampling)
    print("dwdchans", dwdchans)
    print("comps", comps)
    print("data source", datamode)
    print("sensor", sensor)
    print("dask", dask.__version__)
    dask.compute(delayeds)

if __name__ == "__main__":
    print("using", scheduler)
    with dask.config.set(scheduler=scheduler):
        main()

Let me try to shorten that one :)

@gerritholl (Collaborator, Author) commented:
A somewhat shorter MCVE to reproduce the problem:

import hdf5plugin
import os
import dask
from satpy import Scene
import pathlib
from dask import delayed
from satpy.writers import group_results_by_output_file
import dask.array as da
from pyresample import create_area_def

#mode = "direct"
mode = "close"
area = create_area_def("test", 4087, area_extent=(-9_000_000, -9_000_000, 9_000_000, 9_000_000), resolution=3000)
sensor = "fci_l1c_nc"

fci_dir = pathlib.Path("/media/nas/x21308/MTG_test_data/2022_05_MTG_Testdata/RC0099/")

def close(obj, targs):
    """Close targets."""
    for targ in targs:
        targ.close()
    return obj

names = ['vis_04', 'vis_05', 'vis_06', 'vis_08', 'vis_09', 'nir_13', 'nir_16',
        'nir_22', 'ir_38', 'wv_63', 'wv_73', 'ir_87', 'ir_97', 'ir_105', 'ir_123',
        'ir_133', "airmass", "dust", "ash"]
files = sorted(fci_dir.glob("*BODY*.nc"))

def get_scene():
    sc = Scene(filenames={sensor: [os.fspath(f) for f in files]})
    sc.load(names, upper_right_corner="NE")
    return sc

def main():
    sc = get_scene()
    ls = sc.resample(area)
    (srcs, targs) = ls.save_datasets(writer="geotiff", enhance=True, compute=False)
    if mode == "direct":
        delayeds = [da.store(s, t, compute=False) for (s, t) in group_results_by_output_file(srcs, targs)]
    elif mode == "close":
        delayeds = [delayed(close)(da.store(s, t, compute=False), t) for (s, t) in group_results_by_output_file(srcs, targs)]

    print("computing with mode", mode)
    print("dask", dask.__version__)
    dask.compute(delayeds)

if __name__ == "__main__":
    main()

Resources measured with /usr/bin/time -v:

| dask version | mode | time | RAM |
|---|---|---|---|
| 2022.12.0 | direct | | |
| 2022.12.0 | close | | |
| bugfix-store-nocompute-optimize | direct | | |
| bugfix-store-nocompute-optimize | close | | |

The shorter MCVE does not reproduce the problem despite the only difference being conditional :-/

@gerritholl (Collaborator, Author) commented:
I will fill the earlier table when I can reproduce my earlier results…
NB: I should check and try pure=True for delayed.
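
(For reference, a generic illustration, unrelated to the PR code, of what pure=True does: identical delayed calls get the same key, so dask can deduplicate them.)

import dask
from dask import delayed

def inc(x):
    return x + 1

pure_inc = delayed(inc, pure=True)   # identical calls share a deterministic key
a = pure_inc(1)
b = pure_inc(1)
print(a.key == b.key)                # True, so dask computes this only once
dask.compute(a, b)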

@djhoese (Member) commented Dec 16, 2022

Ok so "good news" is that I get similar results to you when using ABI data and going to an eqc area that's a little larger than CONUS. My processing hovers between 39-45s, but with my PR I can't get it faster than ~50s. Looking at the dask diagnostic plots I can see that it is very clearly not executing tasks in the same order. Diving into the code I think because we have a Delayed object dask is completely ignoring all optimizations it could do to the graph related to array logic. If I force it to use array logic then I get some closer numbers, but the graph doesn't seem like what I expect still.

Script looks like this now:

#import hdf5plugin
import os
import dask
from datetime import datetime
from satpy import Scene
import pathlib
from dask import delayed
from satpy.writers import group_results_by_output_file
import dask.array as da
from pyresample import create_area_def

mode = "direct"
#mode = "close"
# mode = 'allinone'
sensor = "abi"


def close(obj, targs):
    """Close targets."""
    for targ in targs:
        targ.close()
    return obj


def get_fci_scene():
    sensor = "fci_l1c_nc"

    fci_dir = pathlib.Path("/media/nas/x21308/MTG_test_data/2022_05_MTG_Testdata/RC0099/")
    names = ['vis_04', 'vis_05', 'vis_06', 'vis_08', 'vis_09', 'nir_13', 'nir_16',
            'nir_22', 'ir_38', 'wv_63', 'wv_73', 'ir_87', 'ir_97', 'ir_105', 'ir_123',
            'ir_133', "airmass", "dust", "ash"]
    files = sorted(fci_dir.glob("*BODY*.nc"))
    area = create_area_def("test", 4087, area_extent=(-9_000_000, -9_000_000, 9_000_000, 9_000_000), resolution=3000)

    sc = Scene(filenames={sensor: [os.fspath(f) for f in files]})
    sc.load(names, upper_right_corner="NE")
    ls = sc.resample(area)
    return ls


def get_abi_scene():
    files = pathlib.Path("/data/satellite/abi/2018253").glob("*RadF*.nc")
    names = [f"C{x:02d}" for x in range(1, 17)] + ["airmass", "ash", "dust"]
    area = create_area_def("test", 4087, area_extent=(-10_000_000, 1_000_000, -2_000_000, 6_000_000), resolution=3000)

    sc = Scene(reader='abi_l1b', filenames=[os.fspath(f) for f in files])
    sc.load(names)
    ls = sc.resample(area)
    return ls


def main():
    if sensor == "abi":
        ls = get_abi_scene()
    elif sensor == "fci":
        ls = get_fci_scene()
    (srcs, targs) = ls.save_datasets(writer="geotiff", enhance=True, compute=False)
    if mode == "direct":
        delayeds = [da.store(s, t, compute=False) for (s, t) in group_results_by_output_file(srcs, targs)]
    elif mode == "close":
        delayeds = [delayed(close)(da.store(s, t, compute=False), t) for (s, t) in group_results_by_output_file(srcs, targs)]
    elif mode == "allinone":
        delayeds = [da.store(srcs, targs, compute=False)]

    print(f"{sensor=}")
    print(f"{mode=}")
    print(f"{dask.__version__=}")
    with dask.config.set(delayed_optimize=da.optimization.optimize):
        dask.compute(delayeds)


if __name__ == "__main__":
    from dask.diagnostics import Profiler, ResourceProfiler, CacheProfiler, visualize
    with Profiler() as prof, ResourceProfiler() as rprof, CacheProfiler() as cprof:
        init_task = dask.delayed(lambda x: x)(1).compute()
        main()
    filename = f"profile_store_{sensor}_{mode}_{dask.__version__}_{datetime.utcnow():%Y%m%d_%H%M%S}.html"
    visualize([prof, rprof, cprof], show=False, filename=filename)
    cwd = os.getcwd()
    print(f"file://{cwd}/{filename}")

Note in the above code I use a "starter task" to make the resource profile plots line up with the other graphs. This is what I fixed in my other dask PR.

Also note the "with dask.config.set" line. That's not used in the first two images below, but is used in the last two. This forced delayed computations to be optimized like dask arrays.

Here's what 2022.12.0 looks like:

image

Here's what my PR looks like:

image

See the similar tasks being stacked on the left? Those are all open_dataset tasks loading all the data from ABI.

And here's 2022.12.0 when I force delayed graphs to be optimized like arrays:

image

And my PR with delayed optimizations like dask arrays:

image

So the task graphs still look pretty similar, but at least it computed faster.

@gerritholl merged commit 864983c into pytroll:main on Feb 2, 2023
@gerritholl deleted the split-by-file branch on February 2, 2023 07:41
@mraspaud added this to the v0.40.0 milestone on Feb 2, 2023