Utilities
pdequinox.dataloader
dataloader(
data: Union[PyTree, Array],
*,
batch_size: int,
key: PRNGKeyArray
)
A generator for looping over the data in batches.
The data is shuffled before looping. The length of the generator is the number of minibatches needed to loop over the data once (n_samples // batch_size + 1). In deep learning terminology, one full pass over the dataloader corresponds to one epoch.
For a supervised learning problem, use
```python
dataloader(
    (inputs, targets), batch_size=batch_size, key=key,
)
```
Arguments:
data: Union[PyTree, Array]. The data to be looped over. This must be a JAX-compatible PyTree; in the simplest case, a single array. Each leaf in the PyTree must be array-like with a leading batch axis, and this leading axis must have the same length for all leaves.
batch_size: int. The size of the minibatches. (keyword-only argument)
key: JAX PRNGKey. The key used for shuffling the data; required for reproducible randomness. (keyword-only argument)
Source code in pdequinox/_utils.py
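The shuffle-then-slice behavior described above can be illustrated with the following minimal JAX sketch. It is not the code from pdequinox/_utils.py: the name sketch_dataloader is hypothetical, and the sketch assumes every leaf can be indexed along its leading axis with an integer index array. It yields ceil(n_samples / batch_size) minibatches, so the final batch is smaller when batch_size does not divide n_samples; that is a choice of this sketch, not a claim about pdequinox.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def sketch_dataloader(data, *, batch_size, key):
    """Yield shuffled minibatches of `data` for one epoch (hypothetical sketch)."""
    # Every leaf must share the same leading (batch) axis length.
    n_samples_list = [leaf.shape[0] for leaf in jtu.tree_leaves(data)]
    if any(n != n_samples_list[0] for n in n_samples_list):
        raise ValueError("All leaves must have the same leading axis length.")
    n_samples = n_samples_list[0]

    # One random permutation per epoch, applied identically to every leaf.
    permutation = jax.random.permutation(key, n_samples)

    # Walk through the permutation in chunks of batch_size; the last chunk may be smaller.
    for start in range(0, n_samples, batch_size):
        batch_indices = permutation[start : start + batch_size]
        yield jtu.tree_map(lambda leaf: leaf[batch_indices], data)


# Usage with a supervised (inputs, targets) pair.
inputs = jnp.arange(30.0).reshape(10, 3)
targets = jnp.arange(10.0)
for x_batch, y_batch in sketch_dataloader((inputs, targets), batch_size=4, key=jax.random.PRNGKey(0)):
    print(x_batch.shape, y_batch.shape)
```

Because one permutation is applied to every leaf, inputs and targets stay aligned within each batch, which is what makes the (inputs, targets) tuple pattern work.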
pdequinox.cycling_dataloader
cycling_dataloader(
data: Union[PyTree, Array],
*,
batch_size: int,
num_steps: int,
key: PRNGKeyArray,
return_info: bool = False
)
A generator for looping over the data in batches for a fixed number of steps.
It performs as many epochs (one full iteration over the data) as needed to produce num_steps batches. Note that one batch never contains data from two epochs. Internally, this generator uses the dataloader generator. Hence, if batch_size is chosen larger than the length of the batch axis in the leaf arrays of data, each batch contains the entire dataset.
For a supervised learning problem, use
```python
cycling_dataloader(
    (inputs, targets), batch_size=batch_size, num_steps=num_steps, key=key,
)
```
Arguments:
data: Union[PyTree, Array]. The data to be looped over. This must be a JAX-compatible PyTree; in the simplest case, a single array. Each leaf in the PyTree must be array-like with a leading batch axis, and this leading axis must have the same length for all leaves.
batch_size: int. The size of the minibatches. (keyword-only argument)
num_steps: int. The total number of minibatches (steps) to produce. (keyword-only argument)
key: JAX PRNGKey. The key used for shuffling the data; required for reproducible randomness. (keyword-only argument)
return_info: bool. Whether to also return the epoch and batch indices alongside each batch. (keyword-only argument, default False)
Source code in pdequinox/_utils.py
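The epoch-cycling behavior can be sketched by wrapping a single-epoch generator in an outer loop that draws a fresh subkey per epoch and stops after num_steps batches. This is a hypothetical illustration, not pdequinox's implementation; the names _epoch_batches and sketch_cycling_dataloader, and the (batch, epoch_id, batch_id) tuple order used when return_info=True, are assumptions.

```python
import jax
import jax.numpy as jnp
import jax.tree_util as jtu


def _epoch_batches(data, *, batch_size, key):
    # One shuffled pass over the data; the last batch may be smaller.
    n_samples = jtu.tree_leaves(data)[0].shape[0]
    permutation = jax.random.permutation(key, n_samples)
    for start in range(0, n_samples, batch_size):
        idx = permutation[start : start + batch_size]
        yield jtu.tree_map(lambda leaf: leaf[idx], data)


def sketch_cycling_dataloader(data, *, batch_size, num_steps, key, return_info=False):
    """Yield exactly num_steps minibatches, reshuffling every epoch (hypothetical sketch)."""
    if jtu.tree_leaves(data)[0].shape[0] == 0:
        raise ValueError("Cannot cycle over empty data.")
    step = 0
    epoch_id = 0
    while step < num_steps:
        # A fresh subkey per epoch gives every epoch a different permutation.
        key, subkey = jax.random.split(key)
        for batch_id, batch in enumerate(_epoch_batches(data, batch_size=batch_size, key=subkey)):
            if step == num_steps:
                return
            # Assumed return_info format: (batch, epoch_id, batch_id).
            yield (batch, epoch_id, batch_id) if return_info else batch
            step += 1
        epoch_id += 1


# Usage: 7 steps over 10 samples with batch_size=4 spans three epochs.
data = jnp.arange(10.0)
for batch, epoch_id, batch_id in sketch_cycling_dataloader(
    data, batch_size=4, num_steps=7, key=jax.random.PRNGKey(0), return_info=True
):
    print(epoch_id, batch_id, batch.shape)
```

Since each epoch comes from its own inner generator, no batch mixes samples from two epochs, and with batch_size larger than the data every step yields the full, reshuffled dataset.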
pdequinox.extract_from_ensemble
extract_from_ensemble(ensemble: eqx.Module, i: int)
Given an ensemble of equinox Modules, extract its i-th element.
If you create an ensemble, e.g., with
```python
import jax
import equinox as eqx
keys = jax.random.split(jax.random.PRNGKey(0), 5)
ensemble = eqx.filter_vmap(lambda k: eqx.nn.Conv1d(1, 1, 3, key=k))(keys)  # 5 stacked Conv1d layers
```
then its weight arrays have an additional leading batch/ensemble axis, so the ensemble cannot be applied directly to the data a single network expects. This function extracts the i-th member of the ensemble.
Arguments:
ensemble: eqx.Module. The ensemble of networks.
i: int. The index of the network to be extracted. This can also be a slice!
Source code in pdequinox/_utils.py
pdequinox.combine_to_ensemble
combine_to_ensemble(
networks: list[eqx.Module],
) -> eqx.Module
Given a list of Equinox modules that share the same PyTree structure, combine them into an ensemble, i.e., a single module whose weight arrays have an additional leading batch/ensemble axis.
Arguments:
networks: list[eqx.Module]. The networks to be combined.
Returns:
ensemble: eqx.Module. The ensemble of networks.
Source code in pdequinox/_utils.py