Correlation

exponax.metrics.correlation

correlation(
    u_pred: Float[Array, "C ... N"],
    u_ref: Float[Array, "C ... N"],
) -> float

Compute the correlation between two fields, averaged over all channels.

This function assumes that the arrays have one leading channel axis and an arbitrary number of following spatial axes.
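The shape convention can be illustrated with a minimal, self-contained sketch. It assumes that the per-channel correlation is the cosine similarity (normalized inner product) of the flattened spatial values. The private helper `_correlation` used in the source below is not shown on this page, so its exact definition (for example, whether it subtracts the mean as a Pearson coefficient would) may differ. `cosine_correlation` and `correlation_sketch` are hypothetical names, not part of `exponax`.

```python
import jax
import jax.numpy as jnp


def cosine_correlation(a: jax.Array, b: jax.Array) -> jax.Array:
    # Hypothetical per-channel helper (assumption): flatten every spatial
    # axis and return the cosine of the angle between the two fields.
    a, b = a.ravel(), b.ravel()
    return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))


def correlation_sketch(u_pred: jax.Array, u_ref: jax.Array) -> jax.Array:
    # vmap maps over axis 0 (the channel axis) only, so any number of
    # trailing spatial axes is accepted; the channel values are then averaged.
    return jnp.mean(jax.vmap(cosine_correlation)(u_pred, u_ref))


key = jax.random.PRNGKey(0)
shapes = [(1, 64), (3, 32, 32), (2, 16, 16, 16)]  # 1D, 2D and 3D states
results = []
for shape in shapes:
    u = jax.random.normal(key, shape)
    results.append(correlation_sketch(u, u))
    print(shape, "->", results[-1].shape)  # scalar output: ()
```

Whatever the number of spatial axes, the result is a single scalar, and a field compared with itself yields a correlation of (approximately) one.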

Tip

To apply this function to a state tensor with a leading batch axis, use jax.vmap. Then the batch axis can be reduced, e.g., by jnp.mean. As a helper for this, exponax.metrics.mean_metric is provided.
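The batching pattern from the tip can be sketched as follows. To keep the example self-contained, `correlation_stand_in` is a hypothetical stand-in with the same `"C ... N"` contract (it assumes a cosine-similarity definition per channel); with `exponax` installed, `exponax.metrics.correlation` would take its place. The signature of `exponax.metrics.mean_metric` is not documented on this page, so the sketch spells out the `jax.vmap` and `jnp.mean` steps explicitly.

```python
import jax
import jax.numpy as jnp


def correlation_stand_in(u_pred: jax.Array, u_ref: jax.Array) -> jax.Array:
    # Hypothetical stand-in for exponax.metrics.correlation (assumption):
    # cosine similarity per channel, then the mean over channels.
    def per_channel(a, b):
        a, b = a.ravel(), b.ravel()
        return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))

    return jnp.mean(jax.vmap(per_channel)(u_pred, u_ref))


# A batch of B = 4 states, each with C = 2 channels on a 1D grid of N = 100 points.
key_pred, key_ref = jax.random.split(jax.random.PRNGKey(42))
u_pred_batch = jax.random.normal(key_pred, (4, 2, 100))
u_ref_batch = jax.random.normal(key_ref, (4, 2, 100))

# vmap over the leading batch axis gives one correlation per sample ...
per_sample = jax.vmap(correlation_stand_in)(u_pred_batch, u_ref_batch)
# ... which is then reduced over the batch axis.
batch_mean = jnp.mean(per_sample)
print(per_sample.shape, batch_mean.shape)  # (4,) ()
```

Applying `jax.vmap` to the whole metric, instead of looping in Python, keeps the computation traceable and JIT-compilable.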

Arguments:

  • u_pred: The first field (e.g., the prediction) used in the correlation computation.
  • u_ref: The second field (e.g., the reference) used in the correlation computation.

Returns:

  • correlation: The correlation between the fields, averaged over all channels.
Source code in exponax/metrics/_correlation.py
```python
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float


def correlation(
    u_pred: Float[Array, "C ... N"],
    u_ref: Float[Array, "C ... N"],
) -> float:
    """
    Compute the correlation between two fields, averaged over all channels.

    This function assumes that the arrays have one leading channel axis and an
    arbitrary number of following spatial axes.

    !!! tip
        To apply this function to a state tensor with a leading batch axis, use
        `jax.vmap`. Then the batch axis can be reduced, e.g., by `jnp.mean`. As
        a helper for this, [`exponax.metrics.mean_metric`][] is provided.

    **Arguments**:

    - `u_pred`: The first field (e.g., the prediction) used in the
        correlation computation.
    - `u_ref`: The second field (e.g., the reference) used in the
        correlation computation.

    **Returns**:

    - `correlation`: The correlation between the fields, averaged over
        all channels.
    """
    # Map the private per-channel helper `_correlation` (not shown in this
    # excerpt) over the leading channel axis: one value per channel.
    channel_wise_correlation = jax.vmap(_correlation)(u_pred, u_ref)
    # Reduce to a scalar by averaging the channel values.
    correlation = jnp.mean(channel_wise_correlation)
    return correlation
```
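The two steps in the source (`jax.vmap` over the channel axis, then `jnp.mean`) give every channel the same weight in the final value. The sketch below illustrates this with `per_channel`, a hypothetical stand-in for the private `_correlation` helper. It assumes a cosine-similarity definition, which, like a Pearson coefficient, ignores the amplitude of each field. Under that assumption, a perfectly correlated channel and a perfectly anticorrelated channel cancel out, no matter how different their magnitudes are.

```python
import jax
import jax.numpy as jnp


def per_channel(a: jax.Array, b: jax.Array) -> jax.Array:
    # Hypothetical stand-in for the private `_correlation` helper, which is
    # not shown on this page (assumption): cosine similarity of the
    # flattened fields.
    a, b = a.ravel(), b.ravel()
    return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))


u_ref = jax.random.normal(jax.random.PRNGKey(1), (2, 50))  # C = 2, N = 50
# Channel 0 is a scaled copy; channel 1 is a scaled, sign-flipped copy.
u_pred = jnp.stack([10.0 * u_ref[0], -0.1 * u_ref[1]])

# Same two steps as the source: vmap over the channel axis, then average.
channel_wise = jax.vmap(per_channel)(u_pred, u_ref)
mean_corr = jnp.mean(channel_wise)
print(channel_wise)  # approximately [ 1. -1.]
print(mean_corr)     # approximately 0.0
```

If channels of very different importance should not count equally, the per-channel values could be combined with a weighted average instead of `jnp.mean`. That would be a user-side variation, not something the function above provides.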