MSE-based metrics

exponax.metrics.MSE

MSE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Optional[Float[Array, "C ... N"]] = None,
    domain_extent: float = 1.0,
)

Compute the mean squared error (MSE) between two fields.

This function assumes that the arrays have one leading channel axis followed by an arbitrary number of spatial axes. For batched inputs, apply jax.vmap to this function or use exponax.metrics.mean_MSE.

Arguments:
- u_pred (array): The first field to be used in the error computation.
- u_ref (array, optional): The second field to be used in the error computation. If None, the error is computed with respect to zero.
- domain_extent (float, optional): The extent of the domain in which the fields are defined. This is used to scale the error to be independent of the domain size. Default is 1.0.

Returns:
- mse (float): The (correctly scaled) mean squared error between the fields.

Source code in exponax/_metrics.py
def MSE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Optional[Float[Array, "C ... N"]] = None,
    domain_extent: float = 1.0,
):
    """
    Compute the mean squared error (MSE) between two fields.

    This function assumes that the arrays have one leading channel axis and an
    arbitrary number of following spatial dimensions! For batched operation use
    `jax.vmap` on this function or use the [`exponax.metrics.mean_MSE`][] function.

    **Arguments**:
        - `u_pred` (array): The first field to be used in the error computation.
        - `u_ref` (array, optional): The second field to be used in the error
            computation. If `None`, the error will be computed with respect to
            zero.
        - `domain_extent` (float, optional): The extent of the domain in which
            the fields are defined. This is used to scale the error to be
            independent of the domain size. Default is 1.0.

    **Returns**:
        - `mse` (float): The (correctly scaled) mean squared error between the
            fields.
    """

    num_spatial_dims = len(u_pred.shape) - 1

    mse = _MSE(u_pred, u_ref, domain_extent, num_spatial_dims=num_spatial_dims)

    return mse
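A minimal, self-contained sketch of how the domain scaling behaves, assuming only `jax` is installed. The helper `mse_sketch` is a hypothetical re-implementation that mirrors the `MSE`/`_MSE` source shown above (one leading channel axis, all remaining axes spatial); it is not the library function itself.

```python
import jax.numpy as jnp

def mse_sketch(u_pred, u_ref=None, domain_extent=1.0):
    # Hypothetical mirror of exponax.metrics.MSE: the first axis is the channel
    # axis and every remaining axis is treated as a spatial dimension.
    num_spatial_dims = u_pred.ndim - 1
    diff = u_pred if u_ref is None else u_pred - u_ref
    scale = 1.0 / domain_extent**num_spatial_dims
    return scale * jnp.mean(jnp.square(diff))

# One channel on a 2D grid of 4 x 4 points: shape (C, N, N) = (1, 4, 4).
u_ref = jnp.zeros((1, 4, 4))
u_pred = jnp.ones((1, 4, 4))

err_unit = mse_sketch(u_pred, u_ref, domain_extent=1.0)  # mean of 1.0 -> 1.0
err_big = mse_sketch(u_pred, u_ref, domain_extent=2.0)   # 1.0 / 2**2 -> 0.25
err_zero = mse_sketch(u_pred)                            # u_ref=None: error w.r.t. zero
```

Because the number of spatial dimensions is taken as `ndim - 1`, the `(1, 4, 4)` array is treated as a 2D field, so doubling the extent divides the error by 4.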

exponax.metrics.nMSE

nMSE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Float[Array, "C ... N"],
) -> float

Compute the normalized mean squared error (nMSE) between two fields.

In contrast to exponax.metrics.MSE, no domain_extent is required: the error term and the normalizing reference term are scaled identically, so the domain scaling cancels in the ratio.

Arguments:
- u_pred (array): The first field to be used in the error computation.
- u_ref (array): The second field to be used in the error computation. It is also used to normalize the error.

Returns:
- nmse (float): The normalized mean squared error between the fields.

Source code in exponax/_metrics.py
def nMSE(
    u_pred: Float[Array, "C ... N"],
    u_ref: Float[Array, "C ... N"],
) -> float:
    """
    Compute the normalized mean squared error (nMSE) between two fields.

    In contrast to [`exponax.metrics.MSE`][], no `domain_extent` is required, because of the
    normalization.

    **Arguments**:
        - `u_pred` (array): The first field to be used in the error computation.
        - `u_ref` (array): The second field to be used in the error computation.
            This is also used to normalize the error.

    **Returns**:
        - `nmse` (float): The normalized mean squared error between the fields
    """

    num_spatial_dims = len(u_pred.shape) - 1

    # Do not have to supply the domain_extent, because we will normalize with
    # the ref_mse
    diff_mse = _MSE(u_pred, u_ref, num_spatial_dims=num_spatial_dims)
    ref_mse = _MSE(u_ref, num_spatial_dims=num_spatial_dims)

    nmse = diff_mse / ref_mse

    return nmse
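The sketch below illustrates why the normalization makes `domain_extent` unnecessary. It assumes only `jax`; `scaled_mse` and `nmse_sketch` are hypothetical helpers mirroring `_MSE` and `nMSE` from the source above, and `nmse_sketch` exposes a `domain_extent` argument (which the real `nMSE` does not have) purely to show that it cancels.

```python
import jax.numpy as jnp

def scaled_mse(u_pred, u_ref=None, domain_extent=1.0):
    # Hypothetical mirror of _MSE with num_spatial_dims = ndim - 1
    # (one leading channel axis).
    diff = u_pred if u_ref is None else u_pred - u_ref
    return jnp.mean(jnp.square(diff)) / domain_extent ** (u_pred.ndim - 1)

def nmse_sketch(u_pred, u_ref, domain_extent=1.0):
    # Hypothetical mirror of nMSE: error MSE divided by the MSE of the reference
    # against zero, both computed with the same scaling.
    return scaled_mse(u_pred, u_ref, domain_extent) / scaled_mse(u_ref, None, domain_extent)

x = jnp.linspace(0.0, 1.0, 64, endpoint=False)
u_ref = jnp.sin(2 * jnp.pi * x)[None, :]  # shape (1, 64): one channel, 1D grid
u_pred = 1.1 * u_ref                      # 10 % amplitude error

nmse_a = nmse_sketch(u_pred, u_ref, domain_extent=1.0)
nmse_b = nmse_sketch(u_pred, u_ref, domain_extent=5.0)
# Both are approximately 0.01 = 0.1**2, independent of domain_extent.
```

Since the denominator is the mean square of `u_ref`, the ratio is undefined when `u_ref` is identically zero.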

exponax.metrics.mean_MSE

mean_MSE(
    u_pred: Float[Array, "B C ... N"],
    u_ref: Float[Array, "B C ... N"],
    domain_extent: float = 1.0,
) -> float

Compute the MSE between two batches of fields, averaged over the leading batch axis. Use this function for arrays with a batch axis; calling MSE directly on such arrays would count the channel axis as an extra spatial dimension and apply the wrong domain scaling.

Arguments:
- u_pred (array): The first field to be used in the error computation.
- u_ref (array): The second field to be used in the error computation.
- domain_extent (float, optional): The extent of the domain in which the fields are defined. This is used to scale the error to be independent of the domain size. Default is 1.0.

Returns:
- mean_mse (float): The mean squared error averaged over the batch axis.

Source code in exponax/_metrics.py
def mean_MSE(
    u_pred: Float[Array, "B C ... N"],
    u_ref: Float[Array, "B C ... N"],
    domain_extent: float = 1.0,
) -> float:
    """
    Compute the mean MSE between two fields. Use this function to correctly
    operate on arrays with a batch axis.

    **Arguments**:
        - `u_pred` (array): The first field to be used in the error computation.
        - `u_ref` (array): The second field to be used in the error computation.
        - `domain_extent` (float, optional): The extent of the domain in which
            the fields are defined. This is used to scale the error to be
            independent of the domain size. Default is 1.0.

    **Returns**:
        - `mean_mse` (float): The mean mean squared error between the fields
    """
    batch_wise_mse = jax.vmap(MSE, in_axes=(0, 0, None))(u_pred, u_ref, domain_extent)
    mean_mse = jnp.mean(batch_wise_mse)
    return mean_mse
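A minimal sketch of the batching pattern used above, assuming only `jax`. The helpers `mse_sketch` and `mean_mse_sketch` are hypothetical mirrors of `MSE` and `mean_MSE`; the example contrasts the `jax.vmap` route with calling the single-sample function on the whole batch.

```python
import jax
import jax.numpy as jnp

def mse_sketch(u_pred, u_ref, domain_extent=1.0):
    # Hypothetical mirror of MSE for one sample of shape (C, *spatial).
    scale = 1.0 / domain_extent ** (u_pred.ndim - 1)
    return scale * jnp.mean(jnp.square(u_pred - u_ref))

def mean_mse_sketch(u_pred, u_ref, domain_extent=1.0):
    # Map over the leading batch axis of both fields; domain_extent is shared.
    per_sample = jax.vmap(mse_sketch, in_axes=(0, 0, None))(u_pred, u_ref, domain_extent)
    return jnp.mean(per_sample)

key_pred, key_ref = jax.random.split(jax.random.PRNGKey(0))
u_pred = jax.random.normal(key_pred, (8, 2, 32))  # (B, C, N): 8 samples, 2 channels, 1D
u_ref = jax.random.normal(key_ref, (8, 2, 32))

batched = mean_mse_sketch(u_pred, u_ref, domain_extent=3.0)

# Calling the single-sample function on the whole batch counts the channel axis
# as an extra spatial dimension, so it divides by 3**2 instead of 3.
wrong = mse_sketch(u_pred, u_ref, domain_extent=3.0)
```

Inside `jax.vmap`, each call sees a single sample of shape `(2, 32)`, so the spatial dimension count is 1, as intended.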

exponax.metrics.mean_nMSE

mean_nMSE(
    u_pred: Float[Array, "B C ... N"],
    u_ref: Float[Array, "B C ... N"],
)

Compute the nMSE for each sample along the leading batch axis and average the results. Use this function for arrays with a batch axis.

Arguments:
- u_pred (array): The first field to be used in the error computation.
- u_ref (array): The second field to be used in the error computation.

Returns:
- mean_nmse (float): The mean normalized mean squared error between the fields.

Source code in exponax/_metrics.py
def mean_nMSE(
    u_pred: Float[Array, "B C ... N"],
    u_ref: Float[Array, "B C ... N"],
):
    """
    Compute the mean nMSE between two fields. Use this function to correctly
    operate on arrays with a batch axis.

    **Arguments**:
        - `u_pred` (array): The first field to be used in the error computation.
        - `u_ref` (array): The second field to be used in the error computation.

    **Returns**:
        - `mean_nmse` (float): The mean normalized mean squared error between
            the fields.
    """
    batch_wise_nmse = jax.vmap(nMSE)(u_pred, u_ref)
    mean_nmse = jnp.mean(batch_wise_nmse)
    return mean_nmse
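A small sketch showing what the per-sample normalization in `mean_nMSE` implies, assuming only `jax`. The helpers `nmse_sketch` and `mean_nmse_sketch` are hypothetical mirrors of `nMSE` and `mean_nMSE`; the domain scaling is omitted because it cancels in the ratio.

```python
import jax
import jax.numpy as jnp

def nmse_sketch(u_pred, u_ref):
    # Hypothetical mirror of nMSE: squared error relative to the squared
    # magnitude of the reference (domain scaling cancels, so it is omitted).
    return jnp.mean(jnp.square(u_pred - u_ref)) / jnp.mean(jnp.square(u_ref))

def mean_nmse_sketch(u_pred, u_ref):
    # Normalize each sample by its own reference, then average over the batch.
    return jnp.mean(jax.vmap(nmse_sketch)(u_pred, u_ref))

base = jnp.ones((1, 16))                              # (C, N)
u_ref = jnp.stack([base, 100.0 * base])               # (B, C, N) = (2, 1, 16)
u_pred = jnp.stack([1.1 * base, 100.0 * base])        # 10 % error on the small sample only

per_sample_avg = mean_nmse_sketch(u_pred, u_ref)      # about (0.01 + 0) / 2 = 0.005
pooled = nmse_sketch(u_pred, u_ref)                   # about 1e-6: large sample dominates
```

With per-sample normalization, every sample contributes on the same relative scale regardless of its magnitude; pooling the whole batch into a single ratio would let large-magnitude samples dominate.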

exponax.metrics._MSE

_MSE(
    u_pred: Float[Array, "... N"],
    u_ref: Optional[Float[Array, "... N"]] = None,
    domain_extent: float = 1.0,
    *,
    num_spatial_dims: Optional[int] = None
) -> float

Low-level function to compute the mean squared error (MSE) correctly scaled for states representing physical fields on uniform Cartesian grids.

MSE = 1/L^D * 1/N * sum_i (u_pred_i - u_ref_i)^2, where L is domain_extent, D is num_spatial_dims, and N is the total number of array entries.
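A hand-checkable instance of the formula, assuming only `jax`: with L = 2, D = 1, N = 4 and pointwise differences (1, -1, 2, 0), the result is (1/2) * (1/4) * (1 + 1 + 4 + 0) = 0.75. The variable names are illustrative.

```python
import jax.numpy as jnp

L = 2.0                                    # domain_extent
u_pred = jnp.array([1.0, -1.0, 2.0, 0.0])
u_ref = jnp.zeros(4)                       # same result as u_ref=None
D = u_pred.ndim                            # 1 spatial dimension, no channel axis
N = u_pred.size                            # 4 grid points

mse = (1.0 / L**D) * (1.0 / N) * jnp.sum(jnp.square(u_pred - u_ref))
mse_via_mean = (1.0 / L**D) * jnp.mean(jnp.square(u_pred - u_ref))  # form used in _MSE
```

Both expressions evaluate to 0.75; the `jnp.mean` form simply folds the 1/N factor into the mean.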

By default (num_spatial_dims=None), the number of spatial dimensions is inferred as the number of axes of the input fields. Set this argument explicitly when the array also contains a channel axis, even a singleton one; otherwise the channel axis is counted as an extra spatial dimension.
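The pitfall can be seen with a singleton channel axis. This sketch assumes only `jax`; `low_level_mse` is a hypothetical mirror of `_MSE` with `u_ref=None`.

```python
import jax.numpy as jnp

def low_level_mse(u_pred, domain_extent=1.0, num_spatial_dims=None):
    # Hypothetical mirror of _MSE with u_ref=None: infer D from ndim if not given.
    if num_spatial_dims is None:
        num_spatial_dims = u_pred.ndim
    return jnp.mean(jnp.square(u_pred)) / domain_extent**num_spatial_dims

u = jnp.ones((1, 8))   # one (singleton) channel, 1D grid with 8 points
L = 2.0

inferred = low_level_mse(u, L)                      # D inferred as 2 -> 1 / 4 = 0.25
explicit = low_level_mse(u, L, num_spatial_dims=1)  # D = 1 -> 1 / 2 = 0.5
```

This is why the public `MSE` passes `num_spatial_dims=len(u_pred.shape) - 1` to `_MSE`.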

Correct scaling information (i.e., domain_extent and num_spatial_dims) is not necessary if the result is only used to compute a normalized error (e.g., nMSE) whose normalization is computed with the same scaling, because the scale factor then cancels.

Arguments:
- u_pred (array): The first field to be used in the error computation.
- u_ref (array, optional): The second field to be used in the error computation. If None, the error is computed with respect to zero.
- domain_extent (float, optional): The extent of the domain in which the fields are defined. This is used to scale the error to be independent of the domain size. Default is 1.0.
- num_spatial_dims (int, optional): The number of spatial dimensions in the field. If None, it is inferred as the number of axes of u_pred. Default is None.

Returns:
- mse (float): The (correctly scaled) mean squared error between the fields.

Source code in exponax/_metrics.py
def _MSE(
    u_pred: Float[Array, "... N"],
    u_ref: Optional[Float[Array, "... N"]] = None,
    domain_extent: float = 1.0,
    *,
    num_spatial_dims: Optional[int] = None,
) -> float:
    """
    Low-level function to compute the mean squared error (MSE) correctly scaled
    for states representing physical fields on uniform Cartesian grids.

    MSE = 1/L^D * 1/N * sum_i (u_pred_i - u_ref_i)^2

    Note that by default (`num_spatial_dims=None`), the number of spatial
    dimensions is inferred from the shape of the input fields. Please adjust
    this argument if you call this function with an array that also contains
    channels (even for arrays with singleton channels).

    Providing correct information regarding the scaling (i.e. providing
    `domain_extent` and `num_spatial_dims`) is not necessary if the result is
    used to compute a normalized error (e.g. nMSE) if the normalization is
    computed similarly.

    **Arguments**:
        - `u_pred` (array): The first field to be used in the error
            computation.
        - `u_ref` (array, optional): The second field to be used in the error
            computation. If `None`, the error will be computed with respect to
            zero.
        - `domain_extent` (float, optional): The extent of the domain in which
            the fields are defined. This is used to scale the error to be
            independent of the domain size. Default is 1.0.
        - `num_spatial_dims` (int, optional): The number of spatial dimensions
            in the field. If `None`, it will be inferred from the shape of the
            input fields and then is the number of axes present. Default is
            `None`.

    **Returns**:
        - `mse` (float): The (correctly scaled) mean squared error between the
          fields.
    """
    if u_ref is None:
        diff = u_pred
    else:
        diff = u_pred - u_ref

    if num_spatial_dims is None:
        # Assuming that we only have spatial dimensions
        num_spatial_dims = len(u_pred.shape)

    scale = 1 / (domain_extent**num_spatial_dims)

    mse = scale * jnp.mean(jnp.square(diff))

    return mse