Utilities

pdequinox.dataloader

dataloader(
    data: Union[PyTree, Array],
    *,
    batch_size: int,
    key: PRNGKeyArray
)

A generator for looping over the data in batches.

The data is shuffled before looping. The length is the number of minibatches needed to loop over the data once, i.e., ceil(n_samples / batch_size). In deep learning terminology, one full pass over the dataloader represents one epoch.

For a supervised learning problem use

dataloader(
    (inputs, targets), batch_size=batch_size, key=key,
)

Arguments:

  • data: Union[PyTree, Array]. The data to be looped over. This must be a JAX-compatible PyTree; in the simplest case, an array. Each leaf in the PyTree must be array-like with a leading batch axis. These leading batch axes must be identical across all leaves.
  • batch_size: int. The size of the minibatches. (keyword-only argument)
  • key: JAX PRNGKey. The key used for shuffling the data; required for reproducible randomness. (keyword-only argument)
Source code in pdequinox/_utils.py
def dataloader(
    data: Union[PyTree, Array],
    *,
    batch_size: int,
    key: PRNGKeyArray,
):
    """
    A generator for looping over the data in batches.

    The data is shuffled before looping. The length is the number of
    minibatches needed to loop over the data once, i.e.,
    `ceil(n_samples / batch_size)`. In deep learning terminology, one full
    pass over the dataloader represents one epoch.

    For a supervised learning problem use

    ```python
    dataloader(
        (inputs, targets), batch_size=batch_size, key=key,
    )
    ```

    **Arguments:**

    - `data`: Union[PyTree, Array]. The data to be looped over. This must be a
        JAX-compatible PyTree; in the simplest case, an array. Each leaf in the
        PyTree must be array-like with a leading batch axis. These leading
        batch axes must be identical across all leaves.
    - `batch_size`: int. The size of the minibatches. (keyword-only argument)
    - `key`: JAX PRNGKey. The key used for shuffling the data; required for
        reproducible randomness. (keyword-only argument)
    """

    n_samples_list = [a.shape[0] for a in jtu.tree_leaves(data)]

    if not all(n == n_samples_list[0] for n in n_samples_list):
        raise ValueError(
            "All arrays / PyTree leaves must have the same number of samples. (Leading array axis)"
        )

    n_samples = n_samples_list[0]

    n_batches = int(jnp.ceil(n_samples / batch_size))

    permutation = jax.random.permutation(key, n_samples)

    for batch_id in range(n_batches):
        start = batch_id * batch_size
        end = min((batch_id + 1) * batch_size, n_samples)

        batch_indices = permutation[start:end]

        sub_data = jtu.tree_map(lambda a: a[batch_indices], data)

        yield sub_data
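
Below is a minimal usage sketch. The dummy data shapes (10 samples of a 1-channel field on 64 points) are illustrative assumptions, not requirements of the API.

```python
import jax
import jax.numpy as jnp

from pdequinox import dataloader

# Illustrative dummy data: 10 samples of a 1-channel field on 64 points
inputs = jnp.zeros((10, 1, 64))
targets = jnp.ones((10, 1, 64))

# One full pass over the generator is one epoch; with 10 samples and
# batch_size=4 this yields ceil(10 / 4) = 3 batches of sizes 4, 4, and 2
for batch_inputs, batch_targets in dataloader(
    (inputs, targets), batch_size=4, key=jax.random.PRNGKey(0)
):
    print(batch_inputs.shape, batch_targets.shape)
```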

pdequinox.cycling_dataloader

cycling_dataloader(
    data: Union[PyTree, Array],
    *,
    batch_size: int,
    num_steps: int,
    key: PRNGKeyArray,
    return_info: bool = False
)

A generator for looping over the data in batches for a fixed number of steps.

It performs as many epochs (one full iteration over the data) as needed to produce num_steps batches. Note that a batch will never contain data from two epochs. Internally, this generator uses the dataloader generator. Hence, if batch_size is larger than the length of the batch axis in the leaf arrays of data, each batch will simply contain the full data.

For a supervised learning problem use

cycling_dataloader(
    (inputs, targets), batch_size=batch_size, num_steps=num_steps, key=key,
)

Arguments:

  • data: Union[PyTree, Array]. The data to be looped over. This must be a JAX-compatible PyTree; in the simplest case, an array. Each leaf in the PyTree must be array-like with a leading batch axis. These leading batch axes must be identical across all leaves.
  • batch_size: int. The size of the minibatches. (keyword-only argument)
  • num_steps: int. The number of steps (batches) to loop over the data. (keyword-only argument)
  • key: JAX PRNGKey. The key used for shuffling the data; required for reproducible randomness. (keyword-only argument)
  • return_info: bool. Whether to return the epoch and batch indices in addition to the data. (keyword-only argument)
Source code in pdequinox/_utils.py
def cycling_dataloader(
    data: Union[PyTree, Array],
    *,
    batch_size: int,
    num_steps: int,
    key: PRNGKeyArray,
    return_info: bool = False,
):
    """
    A generator for looping over the data in batches for a fixed number of
    steps.

    It performs as many epochs (one full iteration over the data) as needed to
    produce `num_steps` batches. Note that a batch will never contain data from
    two epochs. Internally, this generator uses the `dataloader` generator.
    Hence, if `batch_size` is larger than the length of the batch axis in the
    leaf arrays of `data`, each batch will simply contain the full data.

    For a supervised learning problem use

    ```python
    cycling_dataloader(
        (inputs, targets), batch_size=batch_size, num_steps=num_steps, key=key,
    )
    ```

    **Arguments:**

    - `data`: Union[PyTree, Array]. The data to be looped over. This must be a
        JAX-compatible PyTree; in the simplest case, an array. Each leaf in the
        PyTree must be array-like with a leading batch axis. These leading
        batch axes must be identical across all leaves.
    - `batch_size`: int. The size of the minibatches. (keyword-only argument)
    - `num_steps`: int. The number of steps (batches) to loop over the data.
        (keyword-only argument)
    - `key`: JAX PRNGKey. The key used for shuffling the data; required for
        reproducible randomness. (keyword-only argument)
    - `return_info`: bool. Whether to return the epoch and batch indices in
        addition to the data. (keyword-only argument)
    """
    epoch_id = 0
    total_step_id = 0

    while True:
        key, subkey = jax.random.split(key)

        for batch_id, sub_data in enumerate(
            dataloader(data, batch_size=batch_size, key=subkey)
        ):
            if total_step_id == num_steps:
                return

            if return_info:
                yield sub_data, epoch_id, batch_id
            else:
                yield sub_data

            total_step_id += 1

        epoch_id += 1
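
A usage sketch with the same illustrative dummy data as above: 10 samples at batch_size=4 give 3 batches per epoch, so 7 steps cross two epoch boundaries, which return_info makes visible.

```python
import jax
import jax.numpy as jnp

from pdequinox import cycling_dataloader

# Illustrative dummy data, as in the dataloader example
inputs = jnp.zeros((10, 1, 64))
targets = jnp.ones((10, 1, 64))

# 3 batches per epoch, hence 7 steps span epochs 0 and 1 fully
# and the first batch of epoch 2
for (batch_inputs, batch_targets), epoch_id, batch_id in cycling_dataloader(
    (inputs, targets),
    batch_size=4,
    num_steps=7,
    key=jax.random.PRNGKey(0),
    return_info=True,
):
    print(epoch_id, batch_id, batch_inputs.shape)
```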

pdequinox.extract_from_ensemble

extract_from_ensemble(ensemble: eqx.Module, i: int)

Given an ensemble of equinox Modules, extract its i-th element.

If you create an ensemble, e.g., with

```python
import jax
import equinox as eqx

ensemble = eqx.filter_vmap(
    lambda k: eqx.nn.Conv1d(1, 1, 3, key=k)
)(jax.random.split(jax.random.PRNGKey(0), 5))
```

its weight arrays carry an additional leading batch/ensemble axis. Such an ensemble cannot be applied natively to its corresponding data. This function extracts the i-th element of the ensemble.

Arguments:

  • ensemble: eqx.Module. The ensemble of networks.
  • i: int. The index of the network to be extracted. This can also be a slice!
Source code in pdequinox/_utils.py
def extract_from_ensemble(ensemble: eqx.Module, i: int):
    """
    Given an ensemble of equinox Modules, extract its i-th element.

    If you create an ensemble, e.g., with

    ```python
    import jax
    import equinox as eqx

    ensemble = eqx.filter_vmap(
        lambda k: eqx.nn.Conv1d(1, 1, 3, key=k)
    )(jax.random.split(jax.random.PRNGKey(0), 5))
    ```

    its weight arrays carry an additional leading batch/ensemble axis. Such an
    ensemble cannot be applied natively to its corresponding data. This
    function extracts the i-th element of the ensemble.

    **Arguments:**

    - `ensemble`: eqx.Module. The ensemble of networks.
    - `i`: int. The index of the network to be extracted. This can also be a
        slice!
    """
    params, static = eqx.partition(ensemble, eqx.is_array)
    params_extracted = jtu.tree_map(lambda x: x[i], params)
    network_extracted = eqx.combine(params_extracted, static)
    return network_extracted
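
A usage sketch continuing the ensemble example from the docstring; the input shape is an illustrative assumption.

```python
import jax
import jax.numpy as jnp
import equinox as eqx

from pdequinox import extract_from_ensemble

# Ensemble of five independently initialized Conv1d layers
ensemble = eqx.filter_vmap(
    lambda k: eqx.nn.Conv1d(1, 1, 3, key=k)
)(jax.random.split(jax.random.PRNGKey(0), 5))

# Extract the third member; it is a plain Conv1d again and can be called
member = extract_from_ensemble(ensemble, 2)
y = member(jnp.zeros((1, 64)))
print(y.shape)  # (1, 62): kernel size 3 without padding
```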

pdequinox.combine_to_ensemble

combine_to_ensemble(
    networks: list[eqx.Module],
) -> eqx.Module

Given a list of equinox Modules of the same PyTree structure, combine them into an ensemble (the weight arrays are stacked along an additional leading batch/ensemble axis).

Arguments:

  • networks: list[eqx.Module]. The networks to be combined.

Returns:

  • ensemble: eqx.Module. The ensemble of networks.
Source code in pdequinox/_utils.py
def combine_to_ensemble(networks: list[eqx.Module]) -> eqx.Module:
    """
    Given a list of equinox Modules of the same PyTree structure, combine them
    into an ensemble (the weight arrays are stacked along an additional leading
    batch/ensemble axis).

    **Arguments:**

    - `networks`: list[eqx.Module]. The networks to be combined.

    **Returns:**

    - `ensemble`: eqx.Module. The ensemble of networks.
    """
    _, static = eqx.partition(networks[0], eqx.is_array)
    params = [eqx.filter(network, eqx.is_array) for network in networks]
    params_combined = jtu.tree_map(lambda *x: jnp.stack(x), *params)
    ensemble = eqx.combine(params_combined, static)
    return ensemble
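
A usage sketch; the kernel size and channel counts are illustrative. Combining stacks the weight arrays along a new leading axis, and extract_from_ensemble inverts the operation.

```python
import jax
import equinox as eqx

from pdequinox import combine_to_ensemble, extract_from_ensemble

keys = jax.random.split(jax.random.PRNGKey(0), 3)
networks = [eqx.nn.Conv1d(1, 1, 3, key=k) for k in keys]

ensemble = combine_to_ensemble(networks)
# The weight arrays gained a leading ensemble axis
print(ensemble.weight.shape)  # (3, 1, 1, 3)

# Round trip: recover the second network
member = extract_from_ensemble(ensemble, 1)
```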