
How to do a left-join on a data variable's values using xarray?


Imagine having a dataset of shipments where every shipID can contain multiple products:

In [2]: shipDS
Out[2]:
<xarray.Dataset>
Dimensions:     (shipID: 6)
Coordinates:
  * shipID      (shipID) int64 1 1 2 3 4 4
Data variables:
    prodID      (shipID) int64 90 91 92 90 90 91
    qtyShipped  (shipID) int64 1 1 1 1 1 1

In a separate dataset, each product has some stored attributes (name, weight, etc.):

In [3]: prodDS
Out[3]:
<xarray.Dataset>
Dimensions:     (prodID: 4)
Coordinates:
  * prodID      (prodID) int64 90 91 92 93
Data variables:
    prodName    (prodID) <U8 'Almonds' 'Berries' 'Candy' 'Dog Food'
    prodWeight  (prodID) float64 0.5 1.5 1.0 12.0

I love xarray and want to use it as much as possible without reverting to Pandas. However, I cannot seem to answer a very basic question in the scenario above, namely "how much did each shipment weigh?"

Here is the code to replicate the two datasets:

import xarray as xr

## construct dataset of shipments
## each shipment might have multiple products
shipIDs = [1,1,2,3,4,4]
productIDs = [90,91,92,90,90,91]
quantity = [1,1,1,1,1,1]

shipDS = xr.Dataset(data_vars = 
    {
        "prodID": ("shipID", productIDs),
        "qtyShipped": ("shipID", quantity)
    },
    coords = {
        "shipID": shipIDs
    }
)

## construct product info dataset
prodID = [90,91,92,93]
prodName = ["Almonds","Berries","Candy","Dog Food"]
prodWeight = [0.5,1.5,1.0,12]

prodDS = xr.Dataset(data_vars = 
    {
        "prodName": ("prodID", prodName),
        "prodWeight": ("prodID", prodWeight)
    },
    coords = {
        "prodID": prodID
    }
)

## MOTIVATING QUESTION: HOW MUCH WEIGHT WAS IN EACH SHIPMENT?
## PART I AM STUCK ON IS JUST JOINING THE DATASETS.  HELP APPRECIATED!

In one's head, it is easy to imagine how to find the total weight for each shipID.

Using merge like

shipDS.set_coords("prodID").merge(prodDS).to_dataframe()

seems to require making prodID a coordinate, which expands the result far beyond what is needed: shipID should be the only coordinate/dimension, and "Dog Food" should not appear in the new dataset because it was never shipped.

Here is the code to get an answer using pandas:

In [5]: ## ANSWER USING PANDAS
   ...: (
   ...:     shipDS   ## start with shipment dataset
   ...:     .to_dataframe()  ## convert to pandas df
   ...:     .join(prodDS.to_dataframe(), on="prodID")   #left join with product info df
   ...:     .reset_index()  # reset index
   ...:     .groupby("shipID")  # group by shipID
   ...:     .agg(totalWt = ("prodWeight", "sum"))  # get total weight
   ...: )
Out[5]:
        totalWt
shipID
1           2.0
2           1.0
3           0.5
4           2.0

How can this be done without reverting to pandas?


Solution

  • You can use xarray's Advanced Indexing rules to reindex prodDS to conform to the dimensions of shipDS:

    In [4]: products_shipped = prodDS.sel(prodID=shipDS.prodID)
       ...: products_shipped
    Out[4]:
    <xarray.Dataset>
    Dimensions:     (shipID: 6)
    Coordinates:
        prodID      (shipID) int64 90 91 92 90 90 91
      * shipID      (shipID) int64 1 1 2 3 4 4
    Data variables:
        prodName    (shipID) <U8 'Almonds' 'Berries' 'Candy' ... 'Almonds' 'Berries'
        prodWeight  (shipID) float64 0.5 1.5 1.0 0.5 0.5 1.5
    

    This works because the indexer, shipDS.prodID, is a DataArray (a data variable pulled out of shipDS). Under the Advanced Indexing rules, if you select using a DataArray, the values of the indexing DataArray are used to pick labels along the selected dataset's dimension (prodID), but the dimensions and coordinates of the indexing DataArray (shipID) become the dimensions and coordinates of the result. So now we have product names and weights by shipID! This behaves like a left join in the pandas sense: every row of shipDS gets the matching product attributes from prodDS, and products that were never shipped (Dog Food) do not appear.
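One difference from a pandas left join: if shipDS references a prodID that is not in prodDS, .sel raises a KeyError instead of filling NaN. Below is a minimal sketch of a NaN-filling variant, assuming a hypothetical shipment list that contains an unknown product 99. It reindexes prodDS onto the shipped product IDs first (unknown IDs become NaN) and then applies the same vectorized .sel; the reindex step is an assumption of this sketch, not part of the approach above.

```python
import numpy as np
import xarray as xr

# Product catalogue from the question
prodDS = xr.Dataset(
    {
        "prodName": ("prodID", ["Almonds", "Berries", "Candy", "Dog Food"]),
        "prodWeight": ("prodID", [0.5, 1.5, 1.0, 12.0]),
    },
    coords={"prodID": [90, 91, 92, 93]},
)

# Hypothetical shipments: product 99 does NOT exist in the catalogue
shipDS = xr.Dataset(
    {
        "prodID": ("shipID", [90, 91, 99]),
        "qtyShipped": ("shipID", [1, 1, 1]),
    },
    coords={"shipID": [1, 1, 2]},
)

# Plain vectorized .sel fails because label 99 is missing from prodDS.prodID
sel_failed = False
try:
    prodDS.sel(prodID=shipDS.prodID)
except KeyError as err:
    sel_failed = True
    print("sel failed:", err)

# Left-join sketch: reindex the catalogue onto the product IDs that were
# actually shipped (missing IDs are filled with NaN), then select as before.
shipped_ids = np.unique(shipDS.prodID.values)
left = prodDS.reindex(prodID=shipped_ids).sel(prodID=shipDS.prodID)
print(left.prodWeight.values)  # [0.5 1.5 nan]
```

Because the reindex only targets IDs that appear in shipDS, Dog Food is still excluded, just as in the original .sel approach.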

    Now, using this reindexed result, the weight of each line item (quantity times unit weight) is a simple element-wise product:

    In [6]: shipDS.qtyShipped * products_shipped.prodWeight
    Out[6]:
    <xarray.DataArray (shipID: 6)>
    array([0.5, 1.5, 1. , 0.5, 0.5, 1.5])
    Coordinates:
      * shipID   (shipID) int64 1 1 2 3 4 4
        prodID   (shipID) int64 90 91 92 90 90 91
    

    And voila! Now you can love xarray even more ;)
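Note that Out[6] still has one value per line item (shipments 1 and 4 appear twice), while the question asks for the total weight per shipment. A minimal sketch of that last step, assuming a groupby-sum over shipID is acceptable; it rebuilds the question's datasets so it runs on its own, and the names unit_weight, line_weight and total_weight are illustrative:

```python
import xarray as xr

# Datasets from the question
shipDS = xr.Dataset(
    {
        "prodID": ("shipID", [90, 91, 92, 90, 90, 91]),
        "qtyShipped": ("shipID", [1, 1, 1, 1, 1, 1]),
    },
    coords={"shipID": [1, 1, 2, 3, 4, 4]},
)
prodDS = xr.Dataset(
    {
        "prodName": ("prodID", ["Almonds", "Berries", "Candy", "Dog Food"]),
        "prodWeight": ("prodID", [0.5, 1.5, 1.0, 12.0]),
    },
    coords={"prodID": [90, 91, 92, 93]},
)

# Look up each line item's unit weight via vectorized label selection;
# the result is indexed by shipID, like the indexer shipDS.prodID
unit_weight = prodDS.prodWeight.sel(prodID=shipDS.prodID)

# Weight per line item, then sum the line items that share a shipID
line_weight = shipDS.qtyShipped * unit_weight
total_weight = line_weight.groupby("shipID").sum()
print(total_weight.to_series())
```

The totals come out as 2.0, 1.0, 0.5 and 2.0 for shipIDs 1 to 4, matching the pandas result in the question. Unlike that pandas version, this one also multiplies by qtyShipped, so it stays correct when quantities are not all 1.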