RMSE-based metrics

exponax.metrics.RMSE

RMSE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Optional[Float[Array, "C ... N"]] = None,
    domain_extent: float = 1.0,
) -> float

Compute the root mean squared error (RMSE) between two fields.

This function assumes that the arrays have one leading channel axis and an arbitrary number of following spatial dimensions! For batched operation use jax.vmap on this function or use the exponax.metrics.mean_RMSE function.

Arguments:

- u_pred (array): The first field to be used in the error computation.
- u_ref (array, optional): The second field to be used in the error computation. If None, the error will be computed with respect to zero.
- domain_extent (float, optional): The extent of the domain in which the fields are defined. This is used to scale the error to be independent of the domain size. Default is 1.0.

Returns:

- rmse (float): The (correctly scaled) root mean squared error between the fields.
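To make the scaling concrete, here is a minimal NumPy sketch of the same arithmetic as the source listed below. `rmse_sketch` is a hypothetical stand-in written for illustration, not part of exponax, which operates on JAX arrays.

```python
import numpy as np


def rmse_sketch(u_pred, u_ref=None, domain_extent=1.0):
    """Hypothetical NumPy stand-in mirroring the RMSE source in this section."""
    diff = u_pred if u_ref is None else u_pred - u_ref
    # One leading channel axis; every remaining axis is spatial.
    num_spatial_dims = u_pred.ndim - 1
    scale = 1.0 / domain_extent**num_spatial_dims
    return np.sqrt(scale * np.mean(np.square(diff)))


# One channel, four grid points in 1D.
u_pred = np.array([[1.0, 2.0, 3.0, 4.0]])
u_ref = np.array([[1.0, 2.0, 3.0, 2.0]])

rmse_unit = rmse_sketch(u_pred, u_ref)                       # sqrt(mean([0, 0, 0, 4])) = 1.0
rmse_scaled = rmse_sketch(u_pred, u_ref, domain_extent=4.0)  # sqrt(1/4 * 1) = 0.5
rmse_vs_zero = rmse_sketch(u_pred)                           # sqrt(mean([1, 4, 9, 16])) = sqrt(7.5)
```

With `domain_extent=4.0` in 1D the squared error is multiplied by 1/4 before the square root, so the result halves; omitting `u_ref` measures the field against zero.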

Source code in exponax/_metrics.py
def RMSE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Optional[Float[Array, "C ... N"]] = None,
    domain_extent: float = 1.0,
) -> float:
    """
    Compute the root mean squared error (RMSE) between two fields.

    This function assumes that the arrays have one leading channel axis and an
    arbitrary number of following spatial dimensions! For batched operation use
    `jax.vmap` on this function or use the [`exponax.metrics.mean_RMSE`][] function.

    **Arguments**:
        - `u_pred` (array): The first field to be used in the error computation.
        - `u_ref` (array, optional): The second field to be used in the error
            computation. If `None`, the error will be computed with respect to
            zero.
        - `domain_extent` (float, optional): The extent of the domain in which
            the fields are defined. This is used to scale the error to be
            independent of the domain size. Default is 1.0.

    **Returns**:
        - `rmse` (float): The (correctly scaled) root mean squared error between
            the fields.
    """

    num_spatial_dims = len(u_pred.shape) - 1

    rmse = _RMSE(u_pred, u_ref, domain_extent, num_spatial_dims=num_spatial_dims)

    return rmse

exponax.metrics.nRMSE

nRMSE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Float[Array, "C ... N"],
) -> float

Compute the normalized root mean squared error (nRMSE) between two fields.

In contrast to exponax.metrics.RMSE, no domain_extent is required, because of the normalization.

Arguments:

- u_pred (array): The first field to be used in the error computation.
- u_ref (array): The second field to be used in the error computation.

Returns:

- nrmse (float): The normalized root mean squared error between the fields.
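The reason no domain extent is needed can be checked numerically: the same scale factor appears in the numerator and the denominator and cancels. The sketch below uses NumPy and hypothetical helper names (`rmse_sketch`, `nrmse_sketch`); the extra `domain_extent` parameter on `nrmse_sketch` exists only to demonstrate the cancellation and is not part of the real `nRMSE` signature.

```python
import numpy as np


def rmse_sketch(u_pred, u_ref=None, domain_extent=1.0):
    """Hypothetical NumPy stand-in for the RMSE of this section."""
    diff = u_pred if u_ref is None else u_pred - u_ref
    num_spatial_dims = u_pred.ndim - 1  # leading axis is the channel axis
    scale = 1.0 / domain_extent**num_spatial_dims
    return np.sqrt(scale * np.mean(np.square(diff)))


def nrmse_sketch(u_pred, u_ref, domain_extent=1.0):
    """RMSE of the difference divided by the RMSE of the reference."""
    diff_rmse = rmse_sketch(u_pred, u_ref, domain_extent)
    ref_rmse = rmse_sketch(u_ref, None, domain_extent)
    return diff_rmse / ref_rmse


u_ref = np.array([[3.0, 4.0]])
u_pred = np.array([[3.0, 1.0]])

# diff = [0, -3]: sqrt(4.5) / sqrt(12.5) = sqrt(0.36) = 0.6
nrmse_unit = nrmse_sketch(u_pred, u_ref)
# The factor 1 / L^D appears in both RMSEs and cancels in the ratio.
nrmse_other_extent = nrmse_sketch(u_pred, u_ref, domain_extent=5.0)
```

Both calls give 0.6, which is why the library function can leave `domain_extent` at its default internally.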

Source code in exponax/_metrics.py
def nRMSE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Float[Array, "C ... N"],
) -> float:
    """
    Compute the normalized root mean squared error (nRMSE) between two fields.

    In contrast to [`exponax.metrics.RMSE`][], no `domain_extent` is required, because of
    the normalization.

    **Arguments**:
        - `u_pred` (array): The first field to be used in the error computation.
        - `u_ref` (array): The second field to be used in the error computation.

    **Returns**:
        - `nrmse` (float): The normalized root mean squared error between the
            fields.
    """

    num_spatial_dims = len(u_pred.shape) - 1

    # Do not have to supply the domain_extent, because we will normalize with
    # the ref_rmse
    diff_rmse = _RMSE(u_pred, u_ref, num_spatial_dims=num_spatial_dims)
    ref_rmse = _RMSE(u_ref, num_spatial_dims=num_spatial_dims)

    nrmse = diff_rmse / ref_rmse

    return nrmse

exponax.metrics.mean_RMSE

mean_RMSE(
    u_pred: Float[Array, "B C ... N"],
    u_ref: Float[Array, "B C ... N"],
    domain_extent: float = 1.0,
) -> float

Compute the mean RMSE between two fields. Use this function to correctly operate on arrays with a batch axis.

Arguments:

- u_pred (array): The first field to be used in the error computation.
- u_ref (array): The second field to be used in the error computation.
- domain_extent (float, optional): The extent of the domain in which the fields are defined. This is used to scale the error to be independent of the domain size. Default is 1.0.

Returns:

- mean_rmse (float): The mean root mean squared error between the fields.
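A small sketch of why the batch axis needs separate handling: the mean of per-sample RMSEs is generally not the RMSE pooled over the whole batch. The code uses NumPy and a plain Python loop where the library uses `jax.vmap`; `rmse_sketch` and `mean_rmse_sketch` are hypothetical names introduced for illustration.

```python
import numpy as np


def rmse_sketch(u_pred, u_ref=None, domain_extent=1.0):
    """Hypothetical NumPy stand-in for the per-sample RMSE."""
    diff = u_pred if u_ref is None else u_pred - u_ref
    num_spatial_dims = u_pred.ndim - 1  # leading axis is the channel axis
    scale = 1.0 / domain_extent**num_spatial_dims
    return np.sqrt(scale * np.mean(np.square(diff)))


def mean_rmse_sketch(u_pred, u_ref, domain_extent=1.0):
    """Average the per-sample RMSE over the leading batch axis."""
    per_sample = np.array(
        [rmse_sketch(p, r, domain_extent) for p, r in zip(u_pred, u_ref)]
    )
    return per_sample.mean()


# Batch of 2 samples, 1 channel, 2 grid points each.
u_ref = np.zeros((2, 1, 2))
u_pred = np.array([[[1.0, 1.0]], [[3.0, 3.0]]])

mean_rmse = mean_rmse_sketch(u_pred, u_ref)             # (1.0 + 3.0) / 2 = 2.0
pooled = np.sqrt(np.mean(np.square(u_pred - u_ref)))   # sqrt(mean([1, 1, 9, 9])) = sqrt(5)
```

Pooling weights large-error samples more heavily (about 2.236 here versus 2.0), so the two quantities should not be used interchangeably.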

Source code in exponax/_metrics.py
def mean_RMSE(
    u_pred: Float[Array, "B C ... N"],
    u_ref: Float[Array, "B C ... N"],
    domain_extent: float = 1.0,
) -> float:
    """
    Compute the mean RMSE between two fields. Use this function to correctly
    operate on arrays with a batch axis.

    **Arguments**:
        - `u_pred` (array): The first field to be used in the error computation.
        - `u_ref` (array): The second field to be used in the error computation.
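        - `domain_extent` (float, optional): The extent of the domain in which
            the fields are defined. This is used to scale the error to be
            independent of the domain size. Default is 1.0.

    **Returns**:
        - `mean_rmse` (float): The mean root mean squared error between the
            fields.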
    """
    batch_wise_rmse = jax.vmap(RMSE, in_axes=(0, 0, None))(u_pred, u_ref, domain_extent)
    mean_rmse = jnp.mean(batch_wise_rmse)
    return mean_rmse

exponax.metrics.mean_nRMSE

mean_nRMSE(
    u_pred: Float[Array, "B C ... N"],
    u_ref: Float[Array, "B C ... N"],
)

Compute the mean nRMSE between two fields. Use this function to correctly operate on arrays with a batch axis.

Arguments:

- u_pred (array): The first field to be used in the error computation.
- u_ref (array): The second field to be used in the error computation.

Returns:

- mean_nrmse (float): The mean normalized root mean squared error.
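As an illustration of what per-sample normalization buys, the NumPy sketch below compares two samples with the same 10% relative error but very different amplitudes. All helper names are hypothetical stand-ins for the JAX functions in this section, and the channel axis is treated like any other axis because no domain scaling is involved.

```python
import numpy as np


def rms(u):
    """Root mean square over all entries (unit domain extent assumed)."""
    return np.sqrt(np.mean(np.square(u)))


def mean_nrmse_sketch(u_pred, u_ref):
    """Normalize each sample by its own reference, then average over the batch."""
    per_sample = np.array([rms(p - r) / rms(r) for p, r in zip(u_pred, u_ref)])
    return per_sample.mean()


def mean_rmse_sketch(u_pred, u_ref):
    """Unnormalized counterpart, for comparison."""
    return np.mean([rms(p - r) for p, r in zip(u_pred, u_ref)])


# Two samples (batch axis first), one channel, two grid points.
u_ref = np.array([[[1.0, 1.0]], [[100.0, 100.0]]])
u_pred = 1.1 * u_ref  # every sample is off by 10%

mean_nrmse = mean_nrmse_sketch(u_pred, u_ref)  # (0.1 + 0.1) / 2 = 0.1
mean_rmse = mean_rmse_sketch(u_pred, u_ref)    # (0.1 + 10.0) / 2 = 5.05
```

The normalized mean treats both samples equally, while the unnormalized mean is dominated by the high-amplitude sample.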

Source code in exponax/_metrics.py
def mean_nRMSE(
    u_pred: Float[Array, "B C ... N"],
    u_ref: Float[Array, "B C ... N"],
):
    """
    Compute the mean nRMSE between two fields. Use this function to correctly
    operate on arrays with a batch axis.

    **Arguments**:
        - `u_pred` (array): The first field to be used in the error computation.
        - `u_ref` (array): The second field to be used in the error computation.

    **Returns**:
        - `mean_nrmse` (float): The mean normalized root mean squared error.
    """
    batch_wise_nrmse = jax.vmap(nRMSE)(u_pred, u_ref)
    mean_nrmse = jnp.mean(batch_wise_nrmse)
    return mean_nrmse

exponax.metrics._RMSE

_RMSE(
    u_pred: Float[Array, "... N"],
    u_ref: Optional[Float[Array, "... N"]] = None,
    domain_extent: float = 1.0,
    *,
    num_spatial_dims: Optional[int] = None
) -> float

Low-level function to compute the root mean squared error (RMSE) correctly scaled for states representing physical fields on uniform Cartesian grids.

RMSE = sqrt(1/L^D * 1/N * sum_i (u_pred_i - u_ref_i)^2)

Note that by default (num_spatial_dims=None), the number of spatial dimensions is inferred from the shape of the input fields. Please adjust this argument if you call this function with an array that also contains channels (even for arrays with singleton channels!).

Providing correct scaling information (i.e. domain_extent and num_spatial_dims) is not necessary when the result is used to compute a normalized error (e.g. nRMSE), provided the normalization is computed the same way.

Arguments:

- u_pred (array): The first field to be used in the error computation.
- u_ref (array, optional): The second field to be used in the error computation. If None, the error will be computed with respect to zero.
- domain_extent (float, optional): The extent of the domain in which the fields are defined. This is used to scale the error to be independent of the domain size. Default is 1.0.
- num_spatial_dims (int, optional): The number of spatial dimensions in the field. If None, it is inferred from the shape of the input fields as the total number of axes. Default is None.

Returns:

- rmse (float): The (correctly scaled) root mean squared error between the fields.
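The warning about channel axes can be illustrated with a small NumPy sketch that follows the same steps as the source listed below. `low_level_rmse_sketch` is a hypothetical re-implementation for illustration, not the library function.

```python
import numpy as np


def low_level_rmse_sketch(
    u_pred, u_ref=None, domain_extent=1.0, *, num_spatial_dims=None
):
    """Hypothetical NumPy stand-in mirroring the _RMSE source in this section."""
    diff = u_pred if u_ref is None else u_pred - u_ref
    if num_spatial_dims is None:
        # Default: every axis is counted as spatial, including a channel axis.
        num_spatial_dims = u_pred.ndim
    scale = 1.0 / domain_extent**num_spatial_dims
    return np.sqrt(scale * np.mean(np.square(diff)))


# Singleton channel axis plus 8 grid points in 1D, on a domain of extent 2.
u = np.ones((1, 8))

inferred = low_level_rmse_sketch(u, domain_extent=2.0)                      # D taken as 2: sqrt(1/4) = 0.5
explicit = low_level_rmse_sketch(u, domain_extent=2.0, num_spatial_dims=1)  # D = 1: sqrt(1/2)
```

With the default inference the singleton channel axis is counted as a second spatial dimension, so the scale becomes 1/L^2 instead of 1/L. The two results only agree when `domain_extent` is 1.0.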

Source code in exponax/_metrics.py
def _RMSE(
    u_pred: Float[Array, "... N"],
    u_ref: Optional[Float[Array, "... N"]] = None,
    domain_extent: float = 1.0,
    *,
    num_spatial_dims: Optional[int] = None,
) -> float:
    """
    Low-level function to compute the root mean squared error (RMSE) correctly
    scaled for states representing physical fields on uniform Cartesian grids.

    RMSE = sqrt(1/L^D * 1/N * sum_i (u_pred_i - u_ref_i)^2)

    Note that by default (`num_spatial_dims=None`), the number of spatial
    dimensions is inferred from the shape of the input fields. Please adjust
    this argument if you call this function with an array that also contains
    channels (even for arrays with singleton channels!).

    Providing correct scaling information (i.e. `domain_extent` and
    `num_spatial_dims`) is not necessary when the result is used to compute a
    normalized error (e.g. nRMSE), provided the normalization is computed the
    same way.

    **Arguments**:
        - `u_pred` (array): The first field to be used in the error computation.
        - `u_ref` (array, optional): The second field to be used in the error
            computation. If `None`, the error will be computed with respect to
            zero.
        - `domain_extent` (float, optional): The extent of the domain in which
            the fields are defined. This is used to scale the error to be
            independent of the domain size. Default is 1.0.
        - `num_spatial_dims` (int, optional): The number of spatial dimensions
            in the field. If `None`, it is inferred from the shape of the
            input fields as the total number of axes. Default is `None`.

    **Returns**:
        - `rmse` (float): The (correctly scaled) root mean squared error between
          the fields.
    """
    if u_ref is None:
        diff = u_pred
    else:
        diff = u_pred - u_ref

    if num_spatial_dims is None:
        # Assuming that we only have spatial dimensions
        num_spatial_dims = len(u_pred.shape)

    # Todo: Check if we have to divide by 1/L or by 1/L^D for D dimensions
    scale = 1 / (domain_extent**num_spatial_dims)

    rmse = jnp.sqrt(scale * jnp.mean(jnp.square(diff)))
    return rmse