
Correlation

exponax.metrics.correlation

correlation(
    u_pred: Float[Array, "C ... N"],
    u_ref: Float[Array, "C ... N"],
) -> float

Compute the correlation between two fields, averaged over all channels.

This function assumes that the arrays have one leading channel axis followed by an arbitrary number of spatial axes. For batched arrays, use mean_correlation.

Arguments:

  • u_pred (array): The first field to be used in the error computation.
  • u_ref (array): The second field to be used in the error computation.

Returns:

  • correlation (float): The correlation between the fields, averaged over all channels.
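A minimal sketch of how the function behaves, written against a self-contained re-implementation that mirrors the source listing below so that it runs without exponax installed (in practice you would call exponax.metrics.correlation directly). The two-channel 2D test field and the grid size are arbitrary illustrative choices. Because every channel is divided by its own L2 norm, identical fields give 1, a sign flip gives -1, and rescaling the prediction leaves the value unchanged.

```python
import jax
import jax.numpy as jnp


def _correlation(u_pred, u_ref):
    # Normalize each field by its L2 norm over all entries, then take the
    # inner product of the flattened arrays (mirrors exponax's _correlation)
    u_pred_normalized = u_pred / jnp.linalg.norm(u_pred)
    u_ref_normalized = u_ref / jnp.linalg.norm(u_ref)
    return jnp.dot(u_pred_normalized.flatten(), u_ref_normalized.flatten())


def correlation(u_pred, u_ref):
    # Map over the leading channel axis, then average the per-channel values
    return jnp.mean(jax.vmap(_correlation)(u_pred, u_ref))


# Two-channel field on a 2D grid with N = 32 points per axis: shape (C, N, N)
N = 32
x = jnp.linspace(0.0, 2.0 * jnp.pi, N, endpoint=False)
X, Y = jnp.meshgrid(x, x, indexing="ij")
u_ref = jnp.stack([jnp.sin(X) * jnp.cos(Y), jnp.cos(2.0 * X)])

same = correlation(u_ref, u_ref)          # ≈ 1.0
scaled = correlation(3.0 * u_ref, u_ref)  # ≈ 1.0, scale-invariant
flipped = correlation(-u_ref, u_ref)      # ≈ -1.0
print(u_ref.shape, same, scaled, flipped)
```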
Source code in exponax/_metrics.py
def correlation(
    u_pred: Float[Array, "C ... N"],
    u_ref: Float[Array, "C ... N"],
) -> float:
    """
    Compute the correlation between two fields, averaged over all channels.

    This function assumes that the arrays have one leading channel axis
    followed by an arbitrary number of spatial axes. For batched arrays, use
    `mean_correlation`.

    **Arguments**:

    - `u_pred` (array): The first field to be used in the error computation.
    - `u_ref` (array): The second field to be used in the error computation.

    **Returns**:

    - `correlation` (float): The correlation between the fields, averaged over
        all channels.
    """
    channel_wise_correlation = jax.vmap(_correlation)(u_pred, u_ref)
    correlation = jnp.mean(channel_wise_correlation)
    return correlation
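One consequence of the per-channel normalization in the source above: if a channel is identically zero in either field, its norm is zero, the division produces 0/0 = NaN, and that NaN propagates through the channel mean. The sketch below reproduces this with a self-contained re-implementation (not the exponax import). The guarded `_correlation_safe` helper, its `eps` parameter, and the `channel_mean` helper are hypothetical names for one possible workaround, not part of the library; the guard assigns such channels a correlation of 0.

```python
import jax
import jax.numpy as jnp


def _correlation(u_pred, u_ref):
    # Same normalization as the source listing: divide by the global L2 norm
    u_pred_normalized = u_pred / jnp.linalg.norm(u_pred)
    u_ref_normalized = u_ref / jnp.linalg.norm(u_ref)
    return jnp.dot(u_pred_normalized.flatten(), u_ref_normalized.flatten())


def _correlation_safe(u_pred, u_ref, eps=1e-12):
    # Hypothetical variant (not in exponax): clamp the norms away from zero so
    # an all-zero field yields 0 instead of NaN
    u_pred_normalized = u_pred / jnp.maximum(jnp.linalg.norm(u_pred), eps)
    u_ref_normalized = u_ref / jnp.maximum(jnp.linalg.norm(u_ref), eps)
    return jnp.dot(u_pred_normalized.flatten(), u_ref_normalized.flatten())


def channel_mean(fn, u_pred, u_ref):
    # Hypothetical helper: apply a per-channel function and average over channels
    return jnp.mean(jax.vmap(fn)(u_pred, u_ref))


u_ref = jnp.ones((2, 16))
# The second channel of the prediction is identically zero
u_pred = jnp.stack([jnp.ones(16), jnp.zeros(16)])

plain = channel_mean(_correlation, u_pred, u_ref)         # nan
guarded = channel_mean(_correlation_safe, u_pred, u_ref)  # (1 + 0) / 2 = 0.5
print(plain, guarded)
```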

exponax.metrics.mean_correlation

mean_correlation(
    u_pred: Float[Array, "B C ... N"],
    u_ref: Float[Array, "B C ... N"],
) -> float

Compute the mean correlation between multiple samples of two fields.

This function assumes that the arrays have one leading batch axis, followed by a channel axis and an arbitrary number of following spatial axes.

To apply this function to two trajectories of fields, transform it with jax.vmap(mean_correlation, in_axes=I), where I is the index of the time axis (e.g., I=0 if the time axis comes first in the array, or I=1 if it comes second, depending on the convention).
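As a sketch of the trajectory case described above, the code below applies jax.vmap(mean_correlation, in_axes=I) to random data with a time axis. It uses a self-contained re-implementation mirroring the source listings on this page instead of importing exponax; the shapes, the random seed, and the time-growing noise level are illustrative assumptions. Mapping over the time axis yields one mean correlation per time step, and both layouts, I=0 for (T, B, C, N) and I=1 for (B, T, C, N), give the same values.

```python
import jax
import jax.numpy as jnp


def _correlation(u_pred, u_ref):
    # Normalize by the global L2 norm and take the flattened inner product
    u_pred_normalized = u_pred / jnp.linalg.norm(u_pred)
    u_ref_normalized = u_ref / jnp.linalg.norm(u_ref)
    return jnp.dot(u_pred_normalized.flatten(), u_ref_normalized.flatten())


def correlation(u_pred, u_ref):
    # Average over the leading channel axis
    return jnp.mean(jax.vmap(_correlation)(u_pred, u_ref))


def mean_correlation(u_pred, u_ref):
    # Average over the leading batch axis
    return jnp.mean(jax.vmap(correlation)(u_pred, u_ref))


T, B, C, N = 5, 4, 2, 64  # time steps, batch size, channels, grid points
key_ref, key_noise = jax.random.split(jax.random.PRNGKey(0))
u_ref = jax.random.normal(key_ref, (T, B, C, N))
# Prediction whose noise level grows linearly in time (zero at t = 0)
noise_scale = jnp.linspace(0.0, 2.0, T)[:, None, None, None]
u_pred = u_ref + noise_scale * jax.random.normal(key_noise, (T, B, C, N))

# Layout (T, B, C, N): the time axis is first, so I = 0
corr_t_first = jax.vmap(mean_correlation, in_axes=0)(u_pred, u_ref)

# Layout (B, T, C, N): the time axis is second, so I = 1
u_pred_bt = jnp.swapaxes(u_pred, 0, 1)
u_ref_bt = jnp.swapaxes(u_ref, 0, 1)
corr_t_second = jax.vmap(mean_correlation, in_axes=1)(u_pred_bt, u_ref_bt)

print(corr_t_first.shape)  # (5,)
print(corr_t_first)
```

At t = 0 the prediction equals the reference, so the first entry is 1 up to floating-point error, and the later entries tend to drop as the noise grows.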

Arguments:

  • u_pred (array): The first tensor of fields to be used in the error computation.
  • u_ref (array): The second tensor of fields to be used in the error computation.

Returns:

  • mean_correlation (float): The mean correlation between the fields, averaged over the batch and all channels.
Source code in exponax/_metrics.py
def mean_correlation(
    u_pred: Float[Array, "B C ... N"],
    u_ref: Float[Array, "B C ... N"],
) -> float:
    """
    Compute the mean correlation between multiple samples of two fields.

    This function assumes that the arrays have one leading batch axis, followed
    by a channel axis and an arbitrary number of following spatial axes.

    To apply this function to two trajectories of fields, transform it with
    `jax.vmap(mean_correlation, in_axes=I)`, where `I` is the index of the time
    axis (e.g., `I=0` if the time axis comes first in the array, or `I=1` if it
    comes second, depending on the convention).

    **Arguments**:

    - `u_pred` (array): The first tensor of fields to be used in the error
        computation.
    - `u_ref` (array): The second tensor of fields to be used in the error
        computation.

    **Returns**:

    - `mean_correlation` (float): The mean correlation between the fields,
        averaged over the batch and all channels.
    """
    batch_wise_correlation = jax.vmap(correlation)(u_pred, u_ref)
    mean_correlation = jnp.mean(batch_wise_correlation)
    return mean_correlation

exponax.metrics._correlation

_correlation(
    u_pred: Float[Array, "... N"],
    u_ref: Float[Array, "... N"],
) -> float

Low-level function to compute the correlation between two fields, defined as the inner product of the flattened fields after each has been divided by its L2 norm.

This function assumes fields without a channel axis. If the fields have a channel axis, even a singleton one, use correlation instead.

Arguments:

  • u_pred (array): The first field to be used in the error computation.
  • u_ref (array): The second field to be used in the error computation.

Returns:

  • correlation (float): The correlation between the fields.
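The caution about channel axes matters because `_correlation` normalizes by the L2 norm over all entries at once. Applied directly to a multi-channel array, it weights each channel by its energy, whereas `correlation` normalizes every channel separately and weights the channels equally. In the illustrative sketch below (a self-contained re-implementation with arbitrary test fields), a high-amplitude channel that matches perfectly and a low-amplitude channel that is exactly anti-correlated give a channel-averaged value of 0, but a single global `_correlation` of about 0.98.

```python
import jax
import jax.numpy as jnp


def _correlation(u_pred, u_ref):
    # Normalizes by the L2 norm over *all* entries, channels included
    u_pred_normalized = u_pred / jnp.linalg.norm(u_pred)
    u_ref_normalized = u_ref / jnp.linalg.norm(u_ref)
    return jnp.dot(u_pred_normalized.flatten(), u_ref_normalized.flatten())


def correlation(u_pred, u_ref):
    # Normalizes each channel separately, then averages with equal weights
    return jnp.mean(jax.vmap(_correlation)(u_pred, u_ref))


x = jnp.linspace(0.0, 2.0 * jnp.pi, 16, endpoint=False)
a, b = jnp.sin(x), jnp.cos(x)  # equal L2 norms on this periodic grid

# Channel 0: large amplitude, identical; channel 1: small amplitude, sign-flipped
u_ref = jnp.stack([10.0 * a, b])
u_pred = jnp.stack([10.0 * a, -b])

per_channel = correlation(u_pred, u_ref)   # (1 + (-1)) / 2 ≈ 0.0
global_norm = _correlation(u_pred, u_ref)  # (100 - 1) / (100 + 1) ≈ 0.980
print(per_channel, global_norm)
```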
Source code in exponax/_metrics.py
def _correlation(
    u_pred: Float[Array, "... N"],
    u_ref: Float[Array, "... N"],
) -> float:
    """
    Low-level function to compute the correlation between two fields.

    This function assumes fields without a channel axis. If the fields have a
    channel axis, even a singleton one, use `correlation` instead.

    **Arguments**:

    - `u_pred` (array): The first field to be used in the error computation.
    - `u_ref` (array): The second field to be used in the error computation.

    **Returns**:

    - `correlation` (float): The correlation between the fields.
    """
    u_pred_normalized = u_pred / jnp.linalg.norm(u_pred)
    u_ref_normalized = u_ref / jnp.linalg.norm(u_ref)

    correlation = jnp.dot(u_pred_normalized.flatten(), u_ref_normalized.flatten())

    return correlation
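Because `_correlation` divides by the plain L2 norm and never subtracts the mean, it computes the cosine similarity of the flattened fields rather than the Pearson correlation coefficient. The two coincide for zero-mean fields but can differ substantially otherwise. The sketch below (a copy of the function above plus an arbitrary test signal) adds a constant offset of 1 to a sine wave: Pearson's coefficient, computed here with jnp.corrcoef, stays at 1, while `_correlation` drops to 1/sqrt(3) ≈ 0.577. Centering both fields first recovers the Pearson value.

```python
import jax.numpy as jnp


def _correlation(u_pred, u_ref):
    # Copy of the function above: cosine similarity of the flattened fields
    u_pred_normalized = u_pred / jnp.linalg.norm(u_pred)
    u_ref_normalized = u_ref / jnp.linalg.norm(u_ref)
    return jnp.dot(u_pred_normalized.flatten(), u_ref_normalized.flatten())


x = jnp.linspace(0.0, 2.0 * jnp.pi, 64, endpoint=False)
u_ref = jnp.sin(x)         # zero mean over the periodic grid
u_pred = jnp.sin(x) + 1.0  # same signal shifted by a constant offset

cosine = _correlation(u_pred, u_ref)         # ≈ 1 / sqrt(3) ≈ 0.577
pearson = jnp.corrcoef(u_pred, u_ref)[0, 1]  # ≈ 1.0, offset-invariant
# Subtracting the means first turns the cosine similarity into Pearson's r
centered = _correlation(u_pred - u_pred.mean(), u_ref - u_ref.mean())  # ≈ 1.0
print(cosine, pearson, centered)
```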