
Reduce xarray.Dataset by custom function

I want to use xarray functionality to reduce a Dataset with a custom/external function along a named dimension.

Create a dataset to demonstrate the problem:

import xarray as xr 
import numpy as np
import pandas as pd 

time = pd.date_range("2000-01-01", "2001-01-01", freq="D")
sids = np.arange(4)
obs = np.random.random(size=(len(time), len(sids)))
sim = np.random.random(size=(len(time), len(sids)))

original = xr.Dataset(
    {
        "obs": (("time", "station_id"), obs),
        "sim": (("time", "station_id"), sim),
    },
    coords={"time": time, "station_id": sids},
)

I want to calculate the mean squared error between the two variables in original (obs and sim), collapsing the "time" dimension so that one value is computed per station. This should return an xr.Dataset like the following:

<xarray.Dataset>
Dimensions:             (station_id: 4)
Coordinates:
  * station_id          (station_id) int64 0 1 2 3
Data variables:
    mean_squared_error  (station_id) float64 0.4411 0.183 0.06754 0.9662
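For reference, the target metric itself can be sketched with plain xarray arithmetic and no external function: the MSE per station is the mean over "time" of the squared difference between obs and sim. This minimal sketch rebuilds the example data with a seeded generator (an assumption added for reproducibility), so its values will not match the numbers shown above.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Rebuild the example dataset (seeded so the run is reproducible)
time = pd.date_range("2000-01-01", "2001-01-01", freq="D")
sids = np.arange(4)
rng = np.random.default_rng(0)
original = xr.Dataset(
    {
        "obs": (("time", "station_id"), rng.random((len(time), len(sids)))),
        "sim": (("time", "station_id"), rng.random((len(time), len(sids)))),
    },
    coords={"time": time, "station_id": sids},
)

# MSE = mean over "time" of the squared difference, one value per station
mse = ((original["obs"] - original["sim"]) ** 2).mean(dim="time")

# Wrap the DataArray in a Dataset using the variable name from the desired output
result = mse.to_dataset(name="mean_squared_error")
print(result)
```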

I have tried using the Dataset.reduce method with scikit-learn's mean_squared_error:

from sklearn.metrics import mean_squared_error

original.reduce(mean_squared_error, dim="time")
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-243-51111f05437b> in <module>
----> 1 original.reduce(mean_squared_error, dim="time")

~/miniconda3/envs/ml/lib/python3.8/site-packages/xarray/core/dataset.py in reduce(self, func, dim, keep_attrs, keepdims, numeric_only, **kwargs)
   4915                         # the former is often more efficient
   4916                         reduce_dims = None  # type: ignore[assignment]
-> 4917                     variables[name] = var.reduce(
   4918                         func,
   4919                         dim=reduce_dims,

~/miniconda3/envs/ml/lib/python3.8/site-packages/xarray/core/variable.py in reduce(self, func, dim, axis, keep_attrs, keepdims, **kwargs)
   1721             )
   1722             if axis is not None:
-> 1723                 data = func(self.data, axis=axis, **kwargs)
   1724             else:
   1725                 data = func(self.data, **kwargs)

~/miniconda3/envs/ml/lib/python3.8/site-packages/sklearn/utils/validation.py in inner_f(*args, **kwargs)
     70                           FutureWarning)
     71         kwargs.update({k: arg for k, arg in zip(sig.parameters, args)})
---> 72         return f(**kwargs)
     73     return inner_f
     74 

TypeError: mean_squared_error() got an unexpected keyword argument 'axis'
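As the traceback shows, Dataset.reduce calls func on each data variable separately, as func(data, axis=...), whereas mean_squared_error expects two arrays (y_true, y_pred) and has no axis parameter. One possible workaround (not taken from the answer below) is xr.apply_ufunc with "time" declared as a core dimension and vectorize=True, so the metric is called once per station with two 1-D arrays. A sketch, assuming a hypothetical custom_mse that follows the same call convention as scikit-learn's function:

```python
import numpy as np
import pandas as pd
import xarray as xr


def custom_mse(y_true, y_pred):
    # Hypothetical stand-in for any metric that takes two 1-D arrays and
    # returns a scalar (same call convention as sklearn's mean_squared_error)
    return np.mean((y_true - y_pred) ** 2)


# Rebuild the example dataset (seeded so the run is reproducible)
time = pd.date_range("2000-01-01", "2001-01-01", freq="D")
sids = np.arange(4)
rng = np.random.default_rng(0)
original = xr.Dataset(
    {
        "obs": (("time", "station_id"), rng.random((len(time), len(sids)))),
        "sim": (("time", "station_id"), rng.random((len(time), len(sids)))),
    },
    coords={"time": time, "station_id": sids},
)

# "time" is the core dimension of both inputs; vectorize=True loops over the
# remaining dimension (station_id), calling custom_mse with two 1-D arrays
mse = xr.apply_ufunc(
    custom_mse,
    original["obs"],
    original["sim"],
    input_core_dims=[["time"], ["time"]],
    vectorize=True,
)
result = mse.to_dataset(name="mean_squared_error")
print(result)
```

To use scikit-learn's metric, pass mean_squared_error (from sklearn.metrics) in place of custom_mse. Note that vectorize=True is a Python-level loop over stations, so it favors convenience over speed.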
asked Jan 21 '26 11:01 by Tommy Lees


1 Answer

There is a package called xskillscore with a function, mse, that computes the mean squared error between two xarray objects along a named dimension.

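pip install xskillscore

import xskillscore

# MSE over "time" for each station, wrapped in a Dataset to match the desired output
xskillscore.mse(original.obs, original.sim, dim="time").to_dataset(name="mean_squared_error")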
answered Jan 24 '26 00:01 by Saverio Guzzo