Correlation
exponax.metrics.correlation
correlation(
u_pred: Float[Array, "C ... N"],
u_ref: Float[Array, "C ... N"],
) -> float
Compute the correlation between two fields, averaged over all channels.
This function assumes that the arrays have one leading channel axis and an arbitrary number of following spatial axes.
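As a rough illustration of the shape convention, the sketch below computes a per-channel correlation and averages it over channels. It assumes that "correlation" means the normalized inner product (cosine similarity) of the flattened fields; this page does not state the formula, so the library's implementation may differ, for example by subtracting the mean first. The names `channel_correlation` and `correlation_sketch` are hypothetical and not part of exponax.

```python
import jax
import jax.numpy as jnp


def channel_correlation(u_pred, u_ref):
    # Assumption: correlation = normalized inner product of one channel,
    # with all spatial axes flattened into a single vector.
    a = u_pred.reshape(-1)
    b = u_ref.reshape(-1)
    return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))


def correlation_sketch(u_pred, u_ref):
    # Map over the leading channel axis, then average the per-channel values.
    return jnp.mean(jax.vmap(channel_correlation)(u_pred, u_ref))


# Two-channel 1D field of shape (C, N) = (2, 64).
x = jnp.linspace(0.0, 2.0 * jnp.pi, 64, endpoint=False)
u = jnp.stack([jnp.sin(x), jnp.cos(2.0 * x)])

same = correlation_sketch(u, u)      # identical fields: close to 1
flipped = correlation_sketch(u, -u)  # sign-flipped field: close to -1
```

Under this assumption the value lies in [-1, 1] and does not depend on the amplitude of either field, only on how well their shapes align.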
Tip
To apply this function to a state tensor with a leading batch axis, use jax.vmap. Then the batch axis can be reduced, e.g., by jnp.mean. As a helper for this, exponax.metrics.mean_metric is provided.
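The batching pattern from the tip can be sketched as follows. To stay self-contained, the example uses a hypothetical stand-in metric, `correlation_sketch`, that assumes a normalized inner product per channel; with exponax installed, `exponax.metrics.correlation` would take its place. Only `jax.vmap` and `jnp.mean` are used for the batching, since the exact signature of `mean_metric` is not shown on this page.

```python
import jax
import jax.numpy as jnp


def correlation_sketch(u_pred, u_ref):
    # Hypothetical stand-in for the library metric (assumed formula):
    # normalized inner product per channel, averaged over channels.
    def per_channel(a, b):
        a, b = a.reshape(-1), b.reshape(-1)
        return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))

    return jnp.mean(jax.vmap(per_channel)(u_pred, u_ref))


# Batched states of shape (batch, C, N) = (5, 2, 32).
key_pred, key_ref = jax.random.split(jax.random.PRNGKey(0))
u_pred_batch = jax.random.normal(key_pred, (5, 2, 32))
u_ref_batch = jax.random.normal(key_ref, (5, 2, 32))

# vmap maps the metric over the leading batch axis: one value per sample.
per_sample = jax.vmap(correlation_sketch)(u_pred_batch, u_ref_batch)

# Reduce the batch axis to a single scalar.
batch_mean = jnp.mean(per_sample)
```

Keeping the per-sample values around before reducing can be useful when the spread across the batch matters, not just its mean.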
Arguments:
u_pred: The first field to be used in the correlation computation.
u_ref: The second field to be used in the correlation computation.
Returns:
correlation: The correlation between the fields, averaged over all channels.
Source code in exponax/metrics/_correlation.py