Optimize ndrolling nanreduce #4325

Open
fujiisoup opened this issue Aug 8, 2020 · 5 comments
@fujiisoup
Member

In #4219 we added ndrolling.
However, nan-reductions such as ds.rolling(x=3, y=2).mean() call np.nanmean, which copies the strided array into a full array.
This is memory-inefficient.

We can implement in-house nanreduce methods that operate on the strided array directly.
For example, our .nansum currently does
make a strided array -> copy the array -> replace nan by 0 -> sum
but we could instead do
replace nan by 0 -> make a strided array -> sum
This is much more memory-efficient, because the nan replacement happens on the original array and the strided view never has to be copied.
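
The two orderings above can be sketched in plain NumPy. This is only an illustration, not xarray's implementation: it assumes NumPy >= 1.20 for sliding_window_view, the function names are hypothetical, and it returns only the fully covered windows (no NaN padding at the edges as xarray's rolling does).

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def rolling_nansum_copying(a, window):
    # strided view -> np.nansum, which copies the whole windowed array
    # internally so that it can replace NaN by 0 before summing.
    windows = sliding_window_view(a, window)
    return np.nansum(windows, axis=(-2, -1))


def rolling_nansum_prefilled(a, window):
    # replace NaN by 0 on the original (small) array first ...
    filled = np.where(np.isnan(a), 0.0, a)
    # ... then build the strided view and reduce it with a plain sum.
    windows = sliding_window_view(filled, window)
    return windows.sum(axis=(-2, -1))


a = np.arange(20, dtype=float).reshape(4, 5)
a[1, 2] = np.nan
a[3, 0] = np.nan
result_copying = rolling_nansum_copying(a, (3, 2))
result_prefilled = rolling_nansum_prefilled(a, (3, 2))
```

Both functions should give the same values; the difference is that the second one only ever allocates an array the size of the input, plus the output.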

@mathause
Collaborator

mathause commented Oct 26, 2020

This is already done for counts, correct? Here:

# We use False as the fill_value instead of np.nan, since boolean
# array is faster to be reduced than object array.
# The use of skipna==False is also faster since it does not need to
# copy the strided array.
counts = (
    self.obj.notnull()
    .rolling(
        center={d: self.center[i] for i, d in enumerate(self.dim)},
        **{d: w for d, w in zip(self.dim, self.window)},
    )
    .construct(rolling_dim, fill_value=False)
    .sum(dim=list(rolling_dim.values()), skipna=False)
)

This should work for most of the reductions (and is a bit similar to what is done in weighted for mean and sum):

  • count: notnull() -> rolling -> sum
  • argmax: fillna(-inf) -> rolling -> argmax
  • argmin: fillna(inf) -> rolling -> argmin
  • max: fillna(-inf) -> rolling -> max (not sure about this one, need to be careful with the dtype)
  • min: fillna(inf) -> rolling -> min (ditto)
  • mean: fillna(0) -> rolling -> sum / count (ensure nan if count == 0)
  • prod: fillna(1) -> rolling -> prod
  • sum: fillna(0) -> rolling -> sum
  • var: fillna(0) -> rolling -> possible (?) but a bit more involved
  • std: sqrt(var)
  • median: probably not possible

I think this should not be too difficult; the thing is that rolling itself is already quite complicated.
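
As a rough illustration of the recipes in the list, here is a 1-D NumPy sketch of the mean and max entries. It is a stand-in for xarray's rolling, not its code: the function names are hypothetical, inputs are cast to float (which sidesteps the dtype concern rather than solving it), and windows with no valid values are set back to NaN.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def rolling_nanmean_1d(a, window):
    a = np.asarray(a, dtype=float)
    valid = ~np.isnan(a)
    filled = np.where(valid, a, 0.0)  # fillna(0)
    sums = sliding_window_view(filled, window).sum(axis=-1)
    counts = sliding_window_view(valid, window).sum(axis=-1)
    # sum / count, leaving NaN wherever the window holds no valid value
    out = np.full(sums.shape, np.nan)
    np.divide(sums, counts, out=out, where=counts > 0)
    return out


def rolling_nanmax_1d(a, window):
    a = np.asarray(a, dtype=float)
    valid = ~np.isnan(a)
    filled = np.where(valid, a, -np.inf)  # fillna(-inf)
    out = sliding_window_view(filled, window).max(axis=-1)
    counts = sliding_window_view(valid, window).sum(axis=-1)
    # an all-NaN window would otherwise report -inf
    out[counts == 0] = np.nan
    return out


data = [1.0, np.nan, 3.0, np.nan, np.nan, 4.0]
means = rolling_nanmean_1d(data, 2)
maxima = rolling_nanmax_1d(data, 2)
```

The count array is needed in both cases to tell "all values missing" apart from a genuine 0 or -inf result.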

@fujiisoup
Member Author

@mathause
Oh, I missed this issue.
Yes, this is implemented only for count.

the thing is that rolling itself is already quite complicated

Agreed.
We need to clean this up.

One possible option would be to drop support for bottleneck.
Bottleneck does not work for nd-rolling, and if we implement the nd-nanreduce, the speed should be comparable to bottleneck's.

@mathause
Collaborator

mathause commented Dec 9, 2020

I just saw that numpy 1.20 introduces stride_tricks.sliding_window_view. I have not looked at this yet. Just leaving this here for reference.

https://numpy.org/devdocs/reference/generated/numpy.lib.stride_tricks.sliding_window_view.html#numpy.lib.stride_tricks.sliding_window_view

https://numpy.org/devdocs/release/1.20.0-notes.html#sliding-window-view-provides-a-sliding-window-view-for-numpy-arrays

numpy/numpy#17394

@dcherian changed the title from "Needs an improvement for ndrolling nanreduce" to "Optimize ndrolling nanreduce" on Apr 13, 2023

@tbloch1

tbloch1 commented Apr 13, 2023

I think I may have found a way to make the variance/standard deviation calculation more memory efficient, but I don't know enough about writing the sort of code that would be needed for a PR.

I basically wrote out the calculation for the variance, trying to use only the functions that have already been optimised. It is derived from:

$$ var = \frac{1}{n} \sum_{i=1}^{n} (x_i - \mu)^2 $$

$$ var = \frac{1}{n} \left( (x_1 - \mu)^2 + (x_2 - \mu)^2 + (x_3 - \mu)^2 + ... \right) $$

$$ var = \frac{1}{n} \left( x_1^2 - 2x_1\mu + \mu^2 + x_2^2 - 2x_2\mu + \mu^2 + x_3^2 - 2x_3\mu + \mu^2 + ... \right) $$

$$ var = \frac{1}{n} \left( \sum_{i=1}^{n} x_i^2 - 2\mu\sum_{i=1}^{n} x_i + n\mu^2 \right)$$

I coded this up, and the example below shows that it uses roughly 10% of the memory of the current .var() implementation:

%load_ext memory_profiler

import numpy as np
import xarray as xr

temp = xr.DataArray(np.random.randint(0, 10, (5000, 500)), dims=("x", "y"))

def new_var(da, x=10, y=20):
    # Defining the re-used parts
    roll = da.rolling(x=x, y=y)
    mean = roll.mean()
    count = roll.count()
    # First term: sum of squared values
    term1 = (da**2).rolling(x=x, y=y).sum()
    # Second term cross term sum
    term2 = -2 * mean * roll.sum()
    # Third term 'sum' of squared means
    term3 = count * mean**2
    # Combining into the variance
    var = (term1 + term2 + term3) / count
    return var

def old_var(da, x=10, y=20):
    roll = da.rolling(x=x, y=y)
    var = roll.var()
    return var

%memit new_var(temp)
%memit old_var(temp)
peak memory: 429.77 MiB, increment: 134.92 MiB
peak memory: 5064.07 MiB, increment: 4768.45 MiB

I wanted to double check that the calculation was working correctly:

# Results of the two implementations defined above
var_n = new_var(temp)
var_o = old_var(temp)
print((var_o.where(~np.isnan(var_o), 0) == var_n.where(~np.isnan(var_n), 0)).all().values)
print(np.allclose(var_o, var_n, equal_nan=True))
False
True

I think the difference here is just due to floating point errors, but maybe someone who knows how to check that in more detail could have a look.

The standard deviation can be trivially implemented from this if the approach works.
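
For reference, the same idea can be written in plain NumPy. Because the sum of the x_i equals n times the mean, the last equation above reduces to var = (1/n) * sum(x_i^2) - mean^2, which is what this sketch uses. It is only an illustration under assumptions: the function name is hypothetical, NumPy >= 1.20 is assumed, only fully covered windows are returned, and the result can differ from a two-pass variance by rounding error (it can even come out slightly negative for near-constant windows).

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def rolling_nanvar(a, window):
    a = np.asarray(a, dtype=float)
    valid = ~np.isnan(a)
    filled = np.where(valid, a, 0.0)
    axes = tuple(range(-len(window), 0))  # the trailing window axes
    n = sliding_window_view(valid, window).sum(axis=axes)
    s1 = sliding_window_view(filled, window).sum(axis=axes)
    s2 = sliding_window_view(filled**2, window).sum(axis=axes)
    out = np.full(n.shape, np.nan)
    has_data = n > 0
    mean = s1[has_data] / n[has_data]
    # var = sum(x^2) / n - mean^2, only where the window has valid values
    out[has_data] = s2[has_data] / n[has_data] - mean**2
    return out


rng = np.random.default_rng(0)
values = rng.normal(size=(12, 9))
values[2, 3] = np.nan
var_sketch = rolling_nanvar(values, (3, 2))
var_reference = np.nanvar(sliding_window_view(values, (3, 2)), axis=(-2, -1))
```

Taking np.sqrt of the result gives the standard deviation; clipping at zero first avoids NaN from tiny negative values.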

@dcherian
Contributor

Over in #7344 (comment), @shoyer wrote:

That said -- we could also switch to smarter NumPy based algorithms to implement most moving window calculations, e.g., using np.nancumsum for moving window means.

After some digging, this would involve using "summed area tables" which have been generalized to nD, and can be used to compute all our built-in reductions (except median). Basically we'd store the summed area table (repeated np.cumsum) and then calculate reductions using binary ops (mostly subtraction) on those tables.

This would be an intermediate level project but we could implement it incrementally (start with sum for example). One downside is the potential for floating point inaccuracies because we're taking differences of potentially large numbers.
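
A minimal 2-D sketch of the summed-area-table idea for a rolling nansum, assuming plain NumPy arrays; the function name is hypothetical and this is not an existing xarray routine. The table is built with repeated cumsum, and each window sum then comes from four table lookups.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def rolling_nansum_sat(a, window):
    wx, wy = window
    filled = np.where(np.isnan(a), 0.0, a)
    # sat[i, j] holds the sum of filled[:i, :j]; the extra leading row and
    # column of zeros make the window formula valid at the array edges.
    sat = np.zeros((a.shape[0] + 1, a.shape[1] + 1))
    sat[1:, 1:] = filled.cumsum(axis=0).cumsum(axis=1)
    # Window starting at (i, j):
    # sat[i+wx, j+wy] - sat[i, j+wy] - sat[i+wx, j] + sat[i, j]
    return sat[wx:, wy:] - sat[:-wx, wy:] - sat[wx:, :-wy] + sat[:-wx, :-wy]


rng = np.random.default_rng(1)
field = rng.integers(0, 10, size=(8, 7)).astype(float)
field[4, 4] = np.nan
sat_sums = rolling_nansum_sat(field, (3, 2))
direct_sums = np.nansum(sliding_window_view(field, (3, 2)), axis=(-2, -1))
```

The cost is independent of the window size, but as noted above the subtractions involve cumulative sums that can be much larger than the window sums, so floating point results may lose precision on big arrays.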

cc @aulemahal
