Correlation
exponax.metrics.correlation
correlation(
u_pred: Float[Array, "C ... N"],
u_ref: Float[Array, "C ... N"],
) -> float
Compute the correlation between two fields, averaged over all channels.
This function assumes that the arrays have one leading channel axis followed by
an arbitrary number of spatial axes. For batched arrays, use mean_correlation.
Arguments:
u_pred (array): The first field to be used in the correlation computation.
u_ref (array): The second field to be used in the correlation computation.
Returns:
correlation (float): The correlation between the fields, averaged over all channels.
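A minimal sketch of the channel averaging described above, written in JAX. It assumes that "correlation" means the normalized inner product of the two fields (no mean subtraction), which this page does not state. The names sketch_field_correlation and sketch_correlation are hypothetical and are not part of exponax.

```python
import jax
import jax.numpy as jnp


def sketch_field_correlation(u_pred, u_ref):
    # Hypothetical stand-in: normalized inner product of two fields
    # that carry no channel axis. Assumed definition, not confirmed here.
    a = u_pred.flatten()
    b = u_ref.flatten()
    return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))


def sketch_correlation(u_pred, u_ref):
    # Map over the leading channel axis, then average the per-channel values.
    per_channel = jax.vmap(sketch_field_correlation)(u_pred, u_ref)
    return jnp.mean(per_channel)


x = jnp.linspace(0.0, 2.0 * jnp.pi, 64, endpoint=False)
u_ref = jnp.stack([jnp.sin(x), jnp.cos(x)])    # shape (C=2, N=64)
u_pred = jnp.stack([jnp.sin(x), -jnp.cos(x)])  # second channel sign-flipped

# Channel 0 correlates perfectly and channel 1 anti-correlates,
# so the channel average is close to zero.
value = sketch_correlation(u_pred, u_ref)
```

Averaging per-channel values, rather than flattening all channels into one vector, keeps a channel with large amplitude from dominating the result.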
Source code in exponax/_metrics.py
exponax.metrics.mean_correlation
mean_correlation(
u_pred: Float[Array, "B C ... N"],
u_ref: Float[Array, "B C ... N"],
) -> float
Compute the mean correlation between multiple samples of two fields.
This function assumes that the arrays have one leading batch axis, followed by a channel axis and an arbitrary number of following spatial axes.
To apply this function to two trajectories of fields, transform it with
jax.vmap: use jax.vmap(mean_correlation, in_axes=I), where I is the index of
the time axis (e.g. I=0 if the time axis comes first, or I=1 if it is in the
second position, depending on the convention).
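The jax.vmap pattern above could look like the following sketch. To stay self-contained it uses a hypothetical stand-in, sketch_mean_correlation, instead of the library function, and it assumes the correlation is a normalized inner product. The array shapes and random data are illustrative only.

```python
import jax
import jax.numpy as jnp


def sketch_mean_correlation(u_pred, u_ref):
    # Hypothetical stand-in following the "B C ... N" shape convention.
    def per_field(a, b):
        a, b = a.flatten(), b.flatten()
        return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))

    # Inner vmap runs over channels, outer vmap over the batch.
    per_sample_channel = jax.vmap(jax.vmap(per_field))(u_pred, u_ref)
    return jnp.mean(per_sample_channel)


key_ref, key_noise = jax.random.split(jax.random.PRNGKey(0))
# Time-first trajectories with shape (T, B, C, N).
traj_ref = jax.random.normal(key_ref, (5, 3, 2, 16))
traj_pred = traj_ref + 0.1 * jax.random.normal(key_noise, (5, 3, 2, 16))

# Time axis at position 0: one value per time step, shape (5,).
per_step = jax.vmap(sketch_mean_correlation, in_axes=0)(traj_pred, traj_ref)

# Batch-first layout (B, T, C, N): the time axis is at position 1.
per_step_alt = jax.vmap(sketch_mean_correlation, in_axes=1)(
    jnp.swapaxes(traj_pred, 0, 1), jnp.swapaxes(traj_ref, 0, 1)
)
```

Both calls return one value per time step; only the position of the time axis in the inputs differs.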
Arguments:
u_pred (array): The first tensor of fields to be used in the correlation computation.
u_ref (array): The second tensor of fields to be used in the correlation computation.
Returns:
mean_correlation (float): The mean correlation between the fields.
Source code in exponax/_metrics.py
exponax.metrics._correlation
_correlation(
u_pred: Float[Array, "... N"],
u_ref: Float[Array, "... N"],
) -> float
Low-level function to compute the correlation between two fields.
This function assumes fields without a channel axis. Even for a singleton
channel axis, use correlation for correct operation.
Arguments:
u_pred (array): The first field to be used in the correlation computation.
u_ref (array): The second field to be used in the correlation computation.
Returns:
correlation (float): The correlation between the fields.
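As an illustration of a correlation over a field with no channel axis, the sketch below treats a whole two-dimensional array as a single field. It assumes the correlation is the normalized inner product of the flattened arrays; that definition and the name sketch_field_correlation are assumptions, not taken from exponax.

```python
import jax.numpy as jnp


def sketch_field_correlation(u_pred, u_ref):
    # Hypothetical stand-in: every entry of the array belongs to one field,
    # so all axes are flattened before taking the normalized inner product.
    a = u_pred.flatten()
    b = u_ref.flatten()
    return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))


x = jnp.linspace(0.0, 2.0 * jnp.pi, 32, endpoint=False)
X, Y = jnp.meshgrid(x, x, indexing="ij")
u = jnp.sin(X) * jnp.cos(Y)  # shape (32, 32), no channel axis

scaled = sketch_field_correlation(3.0 * u, u)  # amplitude does not matter
flipped = sketch_field_correlation(-u, u)      # sign does
```

Under this assumed definition the value depends only on the shape of the fields, not on their amplitude, and lies between -1 and 1.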
Source code in exponax/_metrics.py