Losses

trainax.loss.MSELoss

Bases: BaseLoss

Source code in trainax/loss/_mse_loss.py, lines 9-30
class MSELoss(BaseLoss):
    def __init__(
        self,
        *,
        batch_reduction: Callable = jnp.mean,
    ):
        """
        Simple Mean Squared Error loss.
        """

        super().__init__(batch_reduction=batch_reduction)

    def single_batch(
        self,
        prediction: Float[Array, "num_channels ..."],
        target: Optional[Float[Array, "num_channels ..."]] = None,
    ) -> float:
        if target is None:
            diff = prediction
        else:
            diff = prediction - target
        return jnp.mean(jnp.square(diff))
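
A minimal usage sketch (hypothetical data; shapes are illustrative, assuming the class is importable as trainax.loss.MSELoss):

import jax.numpy as jnp

from trainax.loss import MSELoss

loss = MSELoss()

# Batched inputs: a leading batch axis, then a channel axis, then space.
prediction = jnp.ones((8, 1, 64))
target = jnp.zeros((8, 1, 64))

# Per-sample MSE values are reduced over the batch axis with
# `batch_reduction` (default: jnp.mean), yielding a scalar.
print(loss(prediction, target))  # 1.0

# With `target=None`, the prediction itself is treated as the residual.
print(loss(prediction - target))  # 1.0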
__init__
__init__(*, batch_reduction: Callable = jnp.mean)

Simple Mean Squared Error loss.

Source code in trainax/loss/_mse_loss.py, lines 10-19
def __init__(
    self,
    *,
    batch_reduction: Callable = jnp.mean,
):
    """
    Simple Mean Squared Error loss.
    """

    super().__init__(batch_reduction=batch_reduction)
__call__

__call__(
    prediction: Float[Array, "num_batches num_channels ..."],
    target: Optional[Float[Array, "num_batches num_channels ..."]] = None,
) -> float

Evaluate the loss for a batch of samples. Inherited unchanged from BaseLoss; see trainax.loss.BaseLoss below for the full description and source (trainax/loss/_base_loss.py, lines 84-107).

trainax.loss.Normalized_MSELoss

Bases: MSELoss

Source code in trainax/loss/_mse_loss.py, lines 33-56
class Normalized_MSELoss(MSELoss):
    def __init__(
        self,
        *,
        batch_reduction: Callable = jnp.mean,
    ):
        """
        Simple Mean Squared Error loss normalized on the target.
        """

        super().__init__(batch_reduction=batch_reduction)

    def single_batch(
        self,
        prediction: Float[Array, "num_channels ..."],
        target: Float[Array, "num_channels ..."],
    ) -> float:
        if target is None:
            raise ValueError("Target must be provided for Normalized MSE Loss")

        diff_mse = super().single_batch(prediction, target)
        target_mse = super().single_batch(target)

        return diff_mse / target_mse
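
Per sample, this is the MSE of the residual divided by the mean square of the target, i.e. a relative (squared) error. A small sketch of the identity, with made-up data:

import jax.numpy as jnp

from trainax.loss import Normalized_MSELoss

prediction = jnp.array([[1.0, 2.0, 3.0]])  # (num_batches=1, num_channels=3)
target = jnp.array([[2.0, 2.0, 2.0]])

# Equivalent to mean((p - t) ** 2) / mean(t ** 2) for a single sample.
manual = jnp.mean((prediction - target) ** 2) / jnp.mean(target**2)
print(Normalized_MSELoss()(prediction, target), manual)  # both 1/6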
__init__
__init__(*, batch_reduction: Callable = jnp.mean)

Simple Mean Squared Error loss normalized on the target.

Source code in trainax/loss/_mse_loss.py, lines 34-43
def __init__(
    self,
    *,
    batch_reduction: Callable = jnp.mean,
):
    """
    Simple Mean Squared Error loss normalized on the target.
    """

    super().__init__(batch_reduction=batch_reduction)
__call__

__call__(
    prediction: Float[Array, "num_batches num_channels ..."],
    target: Optional[Float[Array, "num_batches num_channels ..."]] = None,
) -> float

Evaluate the loss for a batch of samples. Inherited unchanged from BaseLoss; see trainax.loss.BaseLoss below for the full description and source (trainax/loss/_base_loss.py, lines 84-107).

trainax.loss.MAELoss

Bases: BaseLoss

Source code in trainax/loss/_mae_loss.py, lines 9-30
class MAELoss(BaseLoss):
    def __init__(
        self,
        *,
        batch_reduction: Callable = jnp.mean,
    ):
        """
        Simple Mean Absolute Error loss.
        """

        super().__init__(batch_reduction=batch_reduction)

    def single_batch(
        self,
        prediction: Float[Array, "num_channels ..."],
        target: Optional[Float[Array, "num_channels ..."]] = None,
    ) -> float:
        if target is None:
            diff = prediction
        else:
            diff = prediction - target
        return jnp.mean(jnp.abs(diff))
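
Because residuals enter linearly rather than squared, MAE is less sensitive to outliers than MSE. A quick comparison on made-up data:

import jax.numpy as jnp

from trainax.loss import MAELoss, MSELoss

prediction = jnp.array([[0.0, 0.0, 10.0]])  # one outlier residual
target = jnp.zeros((1, 3))

print(MAELoss()(prediction, target))  # mean(|diff|)   = 10/3
print(MSELoss()(prediction, target))  # mean(diff**2)  = 100/3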
__init__
__init__(*, batch_reduction: Callable = jnp.mean)

Simple Mean Absolute Error loss.

Source code in trainax/loss/_mae_loss.py, lines 10-19
def __init__(
    self,
    *,
    batch_reduction: Callable = jnp.mean,
):
    """
    Simple Mean Absolute Error loss.
    """

    super().__init__(batch_reduction=batch_reduction)
__call__

__call__(
    prediction: Float[Array, "num_batches num_channels ..."],
    target: Optional[Float[Array, "num_batches num_channels ..."]] = None,
) -> float

Evaluate the loss for a batch of samples. Inherited unchanged from BaseLoss; see trainax.loss.BaseLoss below for the full description and source (trainax/loss/_base_loss.py, lines 84-107).

trainax.loss.Normalized_MAELoss

Bases: MAELoss

Source code in trainax/loss/_mae_loss.py, lines 33-56
class Normalized_MAELoss(MAELoss):
    def __init__(
        self,
        *,
        batch_reduction: Callable = jnp.mean,
    ):
        """
        Simple Mean Absolute Error loss normalized on the target.
        """

        super().__init__(batch_reduction=batch_reduction)

    def single_batch(
        self,
        prediction: Float[Array, "num_channels ..."],
        target: Float[Array, "num_channels ..."],
    ) -> float:
        if target is None:
            raise ValueError("Target must be provided for Normalized MAE Loss")

        diff_mae = super().single_batch(prediction, target)
        target_mae = super().single_batch(target)

        return diff_mae / target_mae
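
Per sample this evaluates to mean(|prediction - target|) / mean(|target|), i.e. an L1 error relative to the target's mean magnitude; as with the normalized MSE, an all-zero target produces a division by zero.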
__init__
__init__(*, batch_reduction: Callable = jnp.mean)

Simple Mean Absolute Error loss normalized on the target.

Source code in trainax/loss/_mae_loss.py, lines 34-43
def __init__(
    self,
    *,
    batch_reduction: Callable = jnp.mean,
):
    """
    Simple Mean Absolute Error loss normalized on the target.
    """

    super().__init__(batch_reduction=batch_reduction)
__call__

__call__(
    prediction: Float[Array, "num_batches num_channels ..."],
    target: Optional[Float[Array, "num_batches num_channels ..."]] = None,
) -> float

Evaluate the loss for a batch of samples. Inherited unchanged from BaseLoss; see trainax.loss.BaseLoss below for the full description and source (trainax/loss/_base_loss.py, lines 84-107).

trainax.loss.BaseLoss

Bases: Module, ABC

Source code in trainax/loss/_base_loss.py, lines 10-107
class BaseLoss(eqx.Module, ABC):
    batch_reduction: Callable

    def __init__(self, *, batch_reduction: Callable = jnp.mean):
        """Base class for loss functions."""
        self.batch_reduction = batch_reduction

    @abstractmethod
    def single_batch(
        self,
        prediction: Float[Array, "num_channels ..."],
        target: Optional[Float[Array, "num_channels ..."]] = None,
    ) -> float:
        """
        Evaluate the loss for a single sample.

        Inputs must be PyTrees of identical structure with array leaves having at
        least a channel/feature axis, and optionally one or more subsequent axes
        (e.g., spatial axes). There should be **no batch axis**.

        !!! info

            To operate on a batch of inputs, either use `multi_batch` or use
            `jax.vmap` on this method.

        **Arguments:**

        - `prediction`: The predicted values.
        - `target`: The target values.

        **Returns:**

        - The loss value.
        """
        pass

    def multi_batch(
        self,
        prediction: Float[Array, "num_batches num_channels ..."],
        target: Optional[Float[Array, "num_batches num_channels ..."]] = None,
    ) -> float:
        """
        Evaluate the loss for a batch of samples.

        Inputs must be PyTrees of identical structure with array leaves having a
        leading batch axis, a subsequent channel/feature axis, and optionally one
        or more subsequent axes (e.g., spatial axes).

        Uses the batch aggregator function specified during initialization.

        **Arguments:**

        - `prediction`: The predicted values.
        - `target`: The target values.

        **Returns:**

        - The loss value.
        """
        if target is None:
            return self.batch_reduction(
                jax.vmap(
                    self.single_batch,
                    in_axes=(0, None),
                )(prediction, target)
            )
        else:
            return self.batch_reduction(
                jax.vmap(
                    self.single_batch,
                    in_axes=(0, 0),
                )(prediction, target)
            )

    def __call__(
        self,
        prediction: Float[Array, "num_batches num_channels ..."],
        target: Optional[Float[Array, "num_batches num_channels ..."]] = None,
    ) -> float:
        """
        Evaluate the loss for a batch of samples.

        Inputs must be PyTrees of identical structure with array leaves having a
        leading batch axis, a subsequent channel/feature axis, and optionally one
        or more subsequent axes (e.g., spatial axes).

        Uses the batch aggregator function specified during initialization.

        **Arguments:**

        - `prediction`: The predicted values.
        - `target`: The target values.

        **Returns:**

        - The loss value.
        """
        return self.multi_batch(prediction, target)
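
Concrete losses only need to implement single_batch; batching, reduction, and __call__ are inherited. A minimal subclass sketch (a hypothetical RMSELoss, not part of trainax):

import jax.numpy as jnp

from trainax.loss import BaseLoss

class RMSELoss(BaseLoss):
    """Per-sample root mean squared error (hypothetical example)."""

    def single_batch(self, prediction, target=None):
        diff = prediction if target is None else prediction - target
        return jnp.sqrt(jnp.mean(jnp.square(diff)))

loss = RMSELoss()
print(loss(jnp.ones((4, 1, 8)), jnp.zeros((4, 1, 8))))  # 1.0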
batch_reduction instance-attribute

batch_reduction: Callable
__init__
__init__(*, batch_reduction: Callable = jnp.mean)

Base class for loss functions.

Source code in trainax/loss/_base_loss.py, lines 13-15
def __init__(self, *, batch_reduction: Callable = jnp.mean):
    """Base class for loss functions."""
    self.batch_reduction = batch_reduction
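
Any callable that reduces the vector of per-sample losses to a scalar can be passed as batch_reduction; a sketch swapping the default mean for a sum (made-up shapes):

import jax.numpy as jnp

from trainax.loss import MSELoss

prediction = jnp.ones((4, 1, 8))
target = jnp.zeros((4, 1, 8))

mean_loss = MSELoss()                        # default: jnp.mean
sum_loss = MSELoss(batch_reduction=jnp.sum)  # sums over the batch axis

print(mean_loss(prediction, target))  # 1.0
print(sum_loss(prediction, target))   # 4.0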
single_batch abstractmethod
single_batch(
    prediction: Float[Array, "num_channels ..."],
    target: Optional[
        Float[Array, "num_channels ..."]
    ] = None,
) -> float

Evaluate the loss for a single sample.

Inputs must be PyTrees of identical structure with array leaves having at least a channel/feature axis, and optionally one or more subsequent axes (e.g., spatial axes). There should be no batch axis.

Info

To operate on a batch of inputs, either use multi_batch or use jax.vmap on this method.

Arguments:

  • prediction: The predicted values.
  • target: The target values.

Returns:

  • The loss value.
Source code in trainax/loss/_base_loss.py, lines 17-44
@abstractmethod
def single_batch(
    self,
    prediction: Float[Array, "num_channels ..."],
    target: Optional[Float[Array, "num_channels ..."]] = None,
) -> float:
    """
    Evaluate the loss for a single sample.

    Inputs must be PyTrees of identical structure with array leaves having at
    least a channel/feature axis, and optionally one or more subsequent axes
    (e.g., spatial axes). There should be **no batch axis**.

    !!! info

        To operate on a batch of inputs, either use `multi_batch` or use
        `jax.vmap` on this method.

    **Arguments:**

    - `prediction`: The predicted values.
    - `target`: The target values.

    **Returns:**

    - The loss value.
    """
    pass
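
As the note above says, single_batch can also be mapped over a batch manually; a sketch of what multi_batch does internally (made-up shapes):

import jax
import jax.numpy as jnp

from trainax.loss import MSELoss

loss = MSELoss()

prediction = jnp.ones((8, 1, 64))  # leading batch axis
target = jnp.zeros((8, 1, 64))

per_sample = jax.vmap(loss.single_batch)(prediction, target)  # shape (8,)
print(jnp.mean(per_sample))  # identical to loss(prediction, target)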
multi_batch
multi_batch(
    prediction: Float[
        Array, "num_batches num_channels ..."
    ],
    target: Optional[
        Float[Array, "num_batches num_channels ..."]
    ] = None,
) -> float

Evaluate the loss for a batch of samples.

Inputs must be PyTrees of identical structure with array leaves having a leading batch axis, a subsequent channel/feature axis, and optionally one or more subsequent axes (e.g., spatial axes).

Uses the batch aggregator function specified during initialization.

Arguments:

  • prediction: The predicted values.
  • target: The target values.

Returns:

  • The loss value.
Source code in trainax/loss/_base_loss.py, lines 46-82
def multi_batch(
    self,
    prediction: Float[Array, "num_batches num_channels ..."],
    target: Optional[Float[Array, "num_batches num_channels ..."]] = None,
) -> float:
    """
    Evaluate the loss for a batch of samples.

    Inputs must be PyTrees of identical structure with array leaves having a
    leading batch axis, a subsequent channel/feature axis, and optionally one
    or more subsequent axes (e.g., spatial axes).

    Uses the batch aggregator function specified during initialization.

    **Arguments:**

    - `prediction`: The predicted values.
    - `target`: The target values.

    **Returns:**

    - The loss value.
    """
    if target is None:
        return self.batch_reduction(
            jax.vmap(
                self.single_batch,
                in_axes=(0, None),
            )(prediction, target)
        )
    else:
        return self.batch_reduction(
            jax.vmap(
                self.single_batch,
                in_axes=(0, 0),
            )(prediction, target)
        )
__call__
__call__(
    prediction: Float[
        Array, "num_batches num_channels ..."
    ],
    target: Optional[
        Float[Array, "num_batches num_channels ..."]
    ] = None,
) -> float

Evaluate the loss for a batch of samples.

Inputs must be PyTrees of identical structure with array leaves having a leading batch axis, a subsequent channel/feature axis, and optionally one or more subsequent axes (e.g., spatial axes).

Uses the batch aggregator function specified during initialization.

Arguments:

  • prediction: The predicted values.
  • target: The target values.

Returns:

  • The loss value.
Source code in trainax/loss/_base_loss.py, lines 84-107
def __call__(
    self,
    prediction: Float[Array, "num_batches num_channels ..."],
    target: Optional[Float[Array, "num_batches num_channels ..."]] = None,
) -> float:
    """
    Evaluate the loss for a batch of samples.

    Inputs must be PyTrees of identical structure with array leaves having a
    leading batch axis, a subsequent channel/feature axis, and optionally one
    or more subsequent axes (e.g., spatial axes).

    Uses the batch aggregator function specified during initialization.

    **Arguments:**

    - `prediction`: The predicted values.
    - `target`: The target values.

    **Returns:**

    - The loss value.
    """
    return self.multi_batch(prediction, target)