Tags: python, optimization, mean, python-xarray, data-cube

Xarray most efficient way to select variable and calculate its mean


I have a 3 GB datacube opened with xarray, with 3 variables I'm interested in (v, vx, vy). Its description is shown below, after the code.

I am interested in only one specific time window, spanning 2009 to 2013, while the entire dataset spans from 1984 to 2018.

What I want to do is truncate the datacube to that time window and then calculate the temporal mean of each of the three variables.

The issue is that it takes so long that after an hour the few lines of code I wrote were still running. What I don't understand is that if I save my "v" values as an array, load them back, and calculate their mean, it takes far less time than doing what I wrote below (see code). I don't know if there is a memory leak, or if it is just a terrible way of doing it. My PC has 16 GB of RAM, of which 60% is available before loading the datacube, so theoretically it should have enough RAM to compute everything.

What would be an efficient way to truncate my datacube to the desired time window, then calculate the temporal mean (over axis 0) of the three variables "v", "vx", and "vy"?

I tried doing it like this:

import numpy as np
import xarray as xr

datacube = xr.open_dataset('datacube.nc')  # Open the datacube
datacube = datacube.reindex(mid_date=sorted(datacube.mid_date.values))  # Sort by ascending time, where "mid_date" is the time dimension

sdate = '2009-01'   # Start date
edate = '2013-12'   # End date

ds = datacube.sel(mid_date=slice(sdate, edate))   # New datacube holding only the values between the start and end dates

vvtot = np.nanmean(ds.v.values, axis=0)   # Temporal mean of the "v" variable of the new datacube
vxtot = np.nanmean(ds.vx.values, axis=0)
vytot = np.nanmean(ds.vy.values, axis=0)

The description of the datacube is the following:

Dimensions:                    (mid_date: 18206, y: 334, x: 333)
Coordinates:
  * mid_date                   (mid_date) datetime64[ns] 1984-06-10T00:00:00....
  * x                          (x) float64 4.868e+05 4.871e+05 ... 5.665e+05
  * y                          (y) float64 6.696e+06 6.696e+06 ... 6.616e+06
Data variables: (12/43)
    UTM_Projection             object ...
    acquisition_img1           (mid_date) datetime64[ns] ...
    acquisition_img2           (mid_date) datetime64[ns] ...
    autoRIFT_software_version  (mid_date) float64 ...
    chip_size_height           (mid_date, y, x) float32 ...
    chip_size_width            (mid_date, y, x) float32 ...
                        ...
    vy                         (mid_date, y, x) float32 ...
    vy_error                   (mid_date) float32 ...
    vy_stable_shift            (mid_date) float64 ...
    vyp                        (mid_date, y, x) float64 ...
    vyp_error                  (mid_date) float64 ...
    vyp_stable_shift           (mid_date) float64 ...
Attributes:
    GDAL_AREA_OR_POINT:         Area
    datacube_software_version:  1.0
    date_created:               30-01-2021 20:49:16
    date_updated:               30-01-2021 20:49:16
    projection:                 32607

Solution

  • Try to avoid calling ".values" on intermediate results, because when you do that you switch from a lazy xr.DataArray to a plain np.array, which forces the whole variable to be loaded into RAM! For scale, one full float32 variable of shape (18206, 334, 333) is roughly 8 GB once loaded, so materializing even a couple of these will exhaust your RAM. Keep everything lazy instead (a complete sketch for your three variables follows after this code):

    import xarray as xr
    from dask.diagnostics import ProgressBar
    
    # Open the dataset using chunks, so it is backed by lazy dask arrays.
    ds = xr.open_dataset(r"/path/to/your/data/test.nc", chunks="auto")
    
    sdate = '2009-01'   # Start date
    edate = '2013-12'   # End date
    
    # Select the period you want the mean for. Note that in your
    # datacube the time dimension is called "mid_date".
    ds = ds.sel(mid_date=slice(sdate, edate))
    
    # Calculate the mean for all the variables in your ds.
    ds = ds.mean(dim="mid_date")
    
    # The code above takes less than a second, because no actual
    # calculation has been done yet (and no data has been loaded into RAM).
    # Only once you call ".values", ".compute()", or ".to_netcdf()" is the
    # work carried out. We can watch its progress like this:
    with ProgressBar():
        ds = ds.compute()
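
  • Since you are only interested in "v", "vx", and "vy", it is also worth selecting just those three variables before taking the mean, so dask does not average all 43 variables in the file. Below is a minimal sketch of the whole workflow under that assumption, reusing the file name and dates from your question; ds.sortby("mid_date") is the idiomatic replacement for your reindex(sorted(...)) call:

    import xarray as xr
    from dask.diagnostics import ProgressBar
    
    ds = xr.open_dataset('datacube.nc', chunks="auto")
    ds = ds.sortby("mid_date")                         # sort by ascending time
    ds = ds[["v", "vx", "vy"]]                         # keep only the variables you need
    ds = ds.sel(mid_date=slice('2009-01', '2013-12'))  # truncate to the time window
    means = ds.mean(dim="mid_date")                    # still lazy at this point
    
    with ProgressBar():
        means = means.compute()                        # triggers the actual reading and averaging
    
    vvtot = means.v.values   # cheap now: each mean is a small (y, x) array
    vxtot = means.vx.values
    vytot = means.vy.values

Like np.nanmean, xarray's mean() skips NaN values by default (skipna defaults to True for float data), so the results match your original code.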