I am developing a relatively large model with xarray and therefore want to make use of chunking. Most of my operations run much faster when chunked, but there is one that keeps running (a lot) slower than the unchunked version.
Take this DataArray as an example (the real one is much bigger):
import xarray as xr
import numpy as np

da = xr.DataArray(
    np.random.randint(low=-5, high=5, size=(10, 100, 100)),
    coords=[range(10), range(100), range(100)],
    dims=["time", "x", "y"],
).chunk(chunks={"x": 10, "y": 10})  # comment this out to opt out of chunking

my_max = 10

def sum_storage(arr):
    arr_new = xr.zeros_like(arr)
    for idx in range(1, len(arr)):
        # running sum of the previous result and the current time step
        arr_new[dict(time=idx)] = arr_new[dict(time=idx - 1)] + arr[dict(time=idx)]
        # cap the running sum at my_max
        arr_new[dict(time=idx)] = arr_new[dict(time=idx)].where(
            arr_new[dict(time=idx)] <= my_max, my_max
        )
    return arr_new

%time arr_storage = sum_storage(da)
I ran this both unchunked and chunked; the CPU times are below.
Unchunked: CPU times: total: 0 ns
Chunked: CPU times: total: 6.72 s
I have tried .rolling and np.apply_along_axis following other suggestions (e.g. https://github.com/pydata/xarray/discussions/6247), as well as applying a ufunc on the dask array (https://tutorial.xarray.dev/advanced/apply_ufunc/dask_apply_ufunc.html). I also called a function using xr.concat, e.g.:
xr.concat([my_function(da[dict(time=i)]) for i in my_list], dim=da.time)
but all the solutions I have looked at iterate over only one DataArray, whereas my example has two DataArrays that need to be iterated over at the same time, and for one of them both the idx and idx-1 indices have to be accessed (but not for the other). A sketch of my apply_ufunc attempt is below.
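For reference, this is roughly what the apply_ufunc attempt looked like; it is a minimal sketch, and the helper name capped_cumsum is my own. It reimplements the capped running sum from sum_storage as a plain NumPy loop over the time axis:

def capped_cumsum(block, cap):
    # apply_ufunc moves the core dim ("time") to the last axis
    out = np.zeros_like(block)
    for i in range(1, block.shape[-1]):
        # previous running total plus the current step, capped at `cap`
        out[..., i] = np.minimum(out[..., i - 1] + block[..., i], cap)
    return out

result = xr.apply_ufunc(
    capped_cumsum,
    da,
    kwargs={"cap": my_max},
    input_core_dims=[["time"]],
    output_core_dims=[["time"]],
    dask="parallelized",
    output_dtypes=[da.dtype],
)

Since apply_ufunc puts the core dimension last, the result comes back as ("x", "y", "time"), so a .transpose(*da.dims) is needed to restore the original order.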
Any suggestions?