Animate States

Visualize trajectories of 1d, 2d, or 3d states by animating the corresponding single-state plots.

exponax.viz.animate_state_1d

animate_state_1d(
    trj: Float[Array, "T C N"],
    *,
    vlim: tuple[float, float] = (-1, 1),
    domain_extent: float = None,
    dt: float = None,
    include_init: bool = False,
    **kwargs
)

Animate a trajectory of 1d states.

Requires the input to be a three-axis array with a leading time axis, a channel axis, and a spatial axis. If the channel axis has more than one entry, each channel is plotted in a different color.
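As a rough illustration of the expected layout, the sketch below stacks per-step states of shape (C, N) along a new leading time axis with NumPy. The helper stack_trajectory is hypothetical (not part of exponax); it only checks the number of axes, similar to the ndim checks that animate_state_2d and animate_state_3d perform.

```python
import numpy as np


def stack_trajectory(snapshots, num_spatial_dims):
    """Stack per-step states of shape (C, N, ..., N) into (T, C, N, ..., N).

    Hypothetical helper for illustration; not part of exponax.
    """
    trj = np.stack(snapshots, axis=0)  # new leading time axis
    expected_ndim = 2 + num_spatial_dims  # time + channel + spatial axes
    if trj.ndim != expected_ndim:
        raise ValueError(
            f"expected {expected_ndim} axes (T, C, spatial...), got shape {trj.shape}"
        )
    return trj


# Three snapshots of a two-channel 1d state with 64 points each
snapshots = [np.zeros((2, 64)) for _ in range(3)]
trj = stack_trajectory(snapshots, num_spatial_dims=1)
print(trj.shape)  # (3, 2, 64)
```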

Periodic boundary conditions will be applied to the spatial axis (the state is wrapped around).
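The wrapping can be pictured as appending the first spatial sample after the last one, so the plotted curve closes the periodic interval. The NumPy sketch below assumes that layout and a grid x_i = i·L/N for i = 0, ..., N with L = domain_extent; the helper wrap_periodic_1d is hypothetical and may differ in detail from what plot_state_1d does internally.

```python
import numpy as np


def wrap_periodic_1d(state, domain_extent):
    """Close a periodic 1d state of shape (C, N) for plotting.

    Hypothetical sketch: appends the first sample after the last one and
    builds the matching grid x_i = i * L / N for i = 0..N, so the curve
    spans the full interval [0, L].
    """
    num_points = state.shape[-1]
    wrapped = np.concatenate([state, state[..., :1]], axis=-1)  # (C, N + 1)
    x = np.linspace(0.0, domain_extent, num_points + 1)  # includes x = L
    return x, wrapped


u = np.array([[0.0, 1.0, 0.0, -1.0]])  # one channel, N = 4
x, u_wrapped = wrap_periodic_1d(u, domain_extent=2.0)
print(x.tolist())          # [0.0, 0.5, 1.0, 1.5, 2.0]
print(u_wrapped.tolist())  # [[0.0, 1.0, 0.0, -1.0, 0.0]]
```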

Arguments:

  • trj: The trajectory of states to animate. Must be a three-axis array with shape (n_timesteps, n_channels, n_spatial). If the channel axis has more than one entry, each channel is plotted in a different color.
  • vlim: The lower and upper limits of the value (y) axis. Default is (-1, 1).
  • domain_extent: The extent of the spatial domain. Default is None. This affects the x-axis limits of the plot.
  • dt: The time step between consecutive frames. Default is None. If provided, the title shows the physical time of each frame; otherwise, it shows the frame number.
  • include_init: Whether the first frame of the trajectory is the initial condition (t=0). If True, the frames cover the time range [0, (T-1)dt]; if False, the first frame lies one step after the initial condition and the range is [dt, T dt]. Default is False.
  • **kwargs: Additional keyword arguments to pass to the plotting function.
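The time labels follow directly from include_init and dt. This NumPy sketch reproduces the temporal_grid logic from the source code below (which uses jax.numpy); frame_times is a hypothetical name used only for illustration.

```python
import numpy as np


def frame_times(num_frames, dt=None, include_init=False):
    """Time label for each frame, mirroring the temporal_grid logic.

    include_init=True  -> 0, 1, ..., T-1   (frame 0 is the initial condition)
    include_init=False -> 1, 2, ..., T     (frame 0 is already one step in)
    If dt is given, the labels are scaled into physical time.
    """
    if include_init:
        grid = np.arange(num_frames)
    else:
        grid = np.arange(1, num_frames + 1)
    if dt is not None:
        grid = grid * dt
    return grid


print(frame_times(4, dt=0.5, include_init=True).tolist())   # [0.0, 0.5, 1.0, 1.5]
print(frame_times(4, dt=0.5, include_init=False).tolist())  # [0.5, 1.0, 1.5, 2.0]
print(frame_times(4).tolist())                              # [1, 2, 3, 4]
```

If the first entry of trj is the initial condition, passing include_init=True labels that frame t = 0.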

Returns:

  • ani: The matplotlib FuncAnimation object (100 ms between frames).
Source code in exponax/viz/_animate.py
def animate_state_1d(
    trj: Float[Array, "T C N"],
    *,
    vlim: tuple[float, float] = (-1, 1),
    domain_extent: float = None,
    dt: float = None,
    include_init: bool = False,
    **kwargs,
):
    """
    Animate a trajectory of 1d states.

    Requires the input to be a three-axis array with a leading time axis, a
    channel axis, and a spatial axis. If the channel axis has more than one
    entry, each channel is plotted in a different color.

    Periodic boundary conditions will be applied to the spatial axis (the state
    is wrapped around).

    **Arguments**:

    - `trj`: The trajectory of states to animate. Must be a three-axis array
        with shape `(n_timesteps, n_channels, n_spatial)`. If the channel axis
        has more than one entry, each channel is plotted in a different
        color.
    - `vlim`: The lower and upper limits of the value (y) axis. Default is
        `(-1, 1)`.
    - `domain_extent`: The extent of the spatial domain. Default is `None`. This
        affects the x-axis limits of the plot.
    - `dt`: The time step between consecutive frames. Default is `None`. If
        provided, the title shows the physical time of each frame; otherwise,
        it shows the frame number.
    - `include_init`: Whether the first frame of the trajectory is the initial
        condition (t=0). If `True`, the frames cover the time range
        [0, (T-1)dt]; if `False`, the first frame lies one step after the
        initial condition and the range is [dt, T dt]. Default is `False`.
    - `**kwargs`: Additional keyword arguments to pass to the plotting function.

    **Returns**:

    - `ani`: The animation object.
    """
    fig, ax = plt.subplots()

    plot_state_1d(
        trj[0],
        vlim=vlim,
        domain_extent=domain_extent,
        ax=ax,
        **kwargs,
    )

    if include_init:
        temporal_grid = jnp.arange(trj.shape[0])
    else:
        temporal_grid = jnp.arange(1, trj.shape[0] + 1)

    if dt is not None:
        temporal_grid *= dt

    ax.set_title(f"t = {temporal_grid[0]:.2f}")

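    def animate(i):
        # ax.clear() also removes the title, so redraw it for every frame
        ax.clear()
        plot_state_1d(
            trj[i],
            vlim=vlim,
            domain_extent=domain_extent,
            ax=ax,
            **kwargs,
        )
        ax.set_title(f"t = {temporal_grid[i]:.2f}")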

    plt.close(fig)

    ani = FuncAnimation(fig, animate, frames=trj.shape[0], interval=100, blit=False)

    return ani

exponax.viz.animate_state_2d

animate_state_2d(
    trj: Float[Array, "T 1 N N"],
    *,
    vlim: tuple[float, float] = (-1.0, 1.0),
    cmap: str = "RdBu_r",
    domain_extent: float = None,
    dt: float = None,
    include_init: bool = False,
    **kwargs
)

Animate a trajectory of 2d states.

Requires the input to be a four-axis array with a leading time axis, a channel axis, and two spatial axes. Only channel 0 of the channel axis is plotted.

Periodic boundary conditions will be applied to the spatial axes (the state is wrapped around).
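Because only channel 0 is drawn, another channel of a multi-channel trajectory can be animated by slicing it out first. This NumPy sketch keeps the singleton channel axis so the result still matches the (T, 1, N, N) layout; select_channel is a hypothetical helper.

```python
import numpy as np


def select_channel(trj, channel):
    """Keep a single channel while preserving the singleton channel axis.

    Slicing with channel:channel + 1 (rather than indexing with channel)
    keeps the (T, 1, N, N) layout the 2d animation expects.
    """
    return trj[:, channel : channel + 1]


# Hypothetical two-channel 2d trajectory: 5 frames on a 32 x 32 grid
trj = np.random.default_rng(0).standard_normal((5, 2, 32, 32))
second = select_channel(trj, 1)
print(second.shape)  # (5, 1, 32, 32)
```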

Arguments:

  • trj: The trajectory of states to animate. Must be a four-axis array with shape (n_timesteps, 1, n_spatial, n_spatial).
  • vlim: The limits of the colorbar. Default is (-1, 1).
  • cmap: The colormap to use. Default is "RdBu_r".
  • domain_extent: The extent of the spatial domain. Default is None. This affects the x- and y-axis limits of the plot.
  • dt: The time step between consecutive frames. Default is None. If provided, the title shows the physical time of each frame; otherwise, it shows the frame number.
  • include_init: Whether the first frame of the trajectory is the initial condition (t=0). If True, the frames cover the time range [0, (T-1)dt]; if False, the first frame lies one step after the initial condition and the range is [dt, T dt]. Default is False.
  • **kwargs: Additional keyword arguments to pass to the plotting function.
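The same vlim is passed to every frame, so values outside it are typically drawn with the colormap's end colors for the whole animation. One option, sketched here with NumPy and a hypothetical helper symmetric_vlim, is to derive a symmetric range from the data so that zero sits at the center of a diverging colormap such as "RdBu_r".

```python
import numpy as np


def symmetric_vlim(trj):
    """Return (-m, m) with m the largest absolute value over all frames.

    Hypothetical helper: one fixed range for the whole trajectory keeps
    colors comparable between frames, and a symmetric range puts zero at
    the center of a diverging colormap.
    """
    m = float(np.max(np.abs(trj)))
    if m == 0.0:
        m = 1.0  # avoid a degenerate (0, 0) range for an all-zero trajectory
    return (-m, m)


trj = np.array([[[[0.2, -0.7], [1.5, 0.0]]]])  # shape (1, 1, 2, 2)
print(symmetric_vlim(trj))  # (-1.5, 1.5)
```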

Returns:

  • ani: The matplotlib FuncAnimation object (100 ms between frames).
Source code in exponax/viz/_animate.py
def animate_state_2d(
    trj: Float[Array, "T 1 N N"],
    *,
    vlim: tuple[float, float] = (-1.0, 1.0),
    cmap: str = "RdBu_r",
    domain_extent: float = None,
    dt: float = None,
    include_init: bool = False,
    **kwargs,
):
    """
    Animate a trajectory of 2d states.

    Requires the input to be a four-axis array with a leading time axis, a
    channel axis, and two spatial axes. Only channel 0 of the channel axis is
    plotted.

    Periodic boundary conditions will be applied to the spatial axes (the state
    is wrapped around).

    **Arguments**:

    - `trj`: The trajectory of states to animate. Must be a four-axis array with
        shape `(n_timesteps, 1, n_spatial, n_spatial)`.
    - `vlim`: The limits of the colorbar. Default is `(-1, 1)`.
    - `cmap`: The colormap to use. Default is `"RdBu_r"`.
    - `domain_extent`: The extent of the spatial domain. Default is `None`. This
        affects the x- and y-axis limits of the plot.
    - `dt`: The time step between consecutive frames. Default is `None`. If
        provided, the title shows the physical time of each frame; otherwise,
        it shows the frame number.
    - `include_init`: Whether the first frame of the trajectory is the initial
        condition (t=0). If `True`, the frames cover the time range
        [0, (T-1)dt]; if `False`, the first frame lies one step after the
        initial condition and the range is [dt, T dt]. Default is `False`.
    - `**kwargs`: Additional keyword arguments to pass to the plotting function.

    **Returns**:

    - `ani`: The animation object.
    """
    if trj.ndim != 4:
        raise ValueError("trj must be a four-axis array.")

    fig, ax = plt.subplots()

    if include_init:
        temporal_grid = jnp.arange(trj.shape[0])
    else:
        temporal_grid = jnp.arange(1, trj.shape[0] + 1)

    if dt is not None:
        temporal_grid *= dt

    plot_state_2d(
        trj[0],
        vlim=vlim,
        cmap=cmap,
        domain_extent=domain_extent,
        ax=ax,
        **kwargs,  # forward extra options, as documented for **kwargs
    )
    ax.set_title(f"t = {temporal_grid[0]:.2f}")

    def animate(i):
        ax.clear()
        plot_state_2d(
            trj[i],
            vlim=vlim,
            cmap=cmap,
            domain_extent=domain_extent,
            ax=ax,
            **kwargs,
        )
        ax.set_title(f"t = {temporal_grid[i]:.2f}")

    plt.close(fig)

    ani = FuncAnimation(fig, animate, frames=trj.shape[0], interval=100, blit=False)

    return ani

exponax.viz.animate_state_3d

animate_state_3d(
    trj: Float[Array, "T 1 N N N"],
    *,
    vlim: tuple[float, float] = (-1.0, 1.0),
    domain_extent: float = None,
    dt: float = None,
    include_init: bool = False,
    bg_color: Union[
        Literal["black"],
        Literal["white"],
        tuple[jnp.int8, jnp.int8, jnp.int8, jnp.int8],
    ] = "white",
    resolution: int = 384,
    cmap: str = "RdBu_r",
    transfer_function: callable = zigzag_alpha,
    distance_scale: float = 10.0,
    gamma_correction: float = 2.4,
    chunk_size: int = 64,
    **kwargs
)

Animate a trajectory of 3d states as volume renderings.

Requires the input to be a five-axis array with a leading time axis, a channel axis, and three spatial axes. Only channel 0 of the channel axis is rendered.

Periodic boundary conditions will be applied to the spatial axes (the state is wrapped around).

Arguments:

  • trj: The trajectory of states to animate. Must be a five-axis array with shape (n_timesteps, 1, n_spatial, n_spatial, n_spatial).
  • vlim: The limits of the colorbar. Default is (-1, 1).
  • domain_extent: Currently unused.
  • dt: The time step between consecutive frames. Default is None. If provided, the title shows the physical time of each frame; otherwise, it shows the frame number.
  • include_init: Whether the first frame of the trajectory is the initial condition (t=0). If True, the frames cover the time range [0, (T-1)dt]; if False, the first frame lies one step after the initial condition and the range is [dt, T dt]. Default is False.
  • bg_color: The background color. Either "black", "white", or a tuple of RGBA values. Default is "white".
  • resolution: The resolution of the output image (affects render time). Default is 384.
  • cmap: The colormap to use. Default is "RdBu_r".
  • transfer_function: The transfer function to use. Default is zigzag_alpha.
  • distance_scale: The distance scale. Default is 10.0.
  • gamma_correction: The gamma correction. Default is 2.4.
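  • chunk_size: The chunk size. Default is 64.
  • **kwargs: Additional keyword arguments passed, together with the rendering options above, to volume_render_state_3d.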

Returns:

  • ani: The matplotlib FuncAnimation object (100 ms between frames).

Note:

  • This function requires the vape volume renderer package.
Source code in exponax/viz/_animate.py
def animate_state_3d(
    trj: Float[Array, "T 1 N N N"],
    *,
    vlim: tuple[float, float] = (-1.0, 1.0),
    domain_extent: float = None,
    dt: float = None,
    include_init: bool = False,
    bg_color: Union[
        Literal["black"],
        Literal["white"],
        tuple[jnp.int8, jnp.int8, jnp.int8, jnp.int8],
    ] = "white",
    resolution: int = 384,
    cmap: str = "RdBu_r",
    transfer_function: callable = zigzag_alpha,
    distance_scale: float = 10.0,
    gamma_correction: float = 2.4,
    chunk_size: int = 64,
    **kwargs,
):
    """
    Animate a trajectory of 3d states as volume renderings.

    Requires the input to be a five-axis array with a leading time axis, a
    channel axis, and three spatial axes. Only channel 0 of the channel axis
    is rendered.

    Periodic boundary conditions will be applied to the spatial axes (the state
    is wrapped around).

    **Arguments**:

    - `trj`: The trajectory of states to animate. Must be a five-axis array with
        shape `(n_timesteps, 1, n_spatial, n_spatial, n_spatial)`.
    - `vlim`: The limits of the colorbar. Default is `(-1, 1)`.
    - `domain_extent`: Currently unused.
    - `dt`: The time step between consecutive frames. Default is `None`. If
        provided, the title shows the physical time of each frame; otherwise,
        it shows the frame number.
    - `include_init`: Whether the first frame of the trajectory is the initial
        condition (t=0). If `True`, the frames cover the time range
        [0, (T-1)dt]; if `False`, the first frame lies one step after the
        initial condition and the range is [dt, T dt]. Default is `False`.
    - `bg_color`: The background color. Either `"black"`, `"white"`, or a tuple
        of RGBA values. Default is `"white"`.
    - `resolution`: The resolution of the output image (affects render time).
        Default is `384`.
    - `cmap`: The colormap to use. Default is `"RdBu_r"`.
    - `transfer_function`: The transfer function to use. Default is `zigzag_alpha`.
    - `distance_scale`: The distance scale. Default is `10.0`.
    - `gamma_correction`: The gamma correction. Default is `2.4`.
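    - `chunk_size`: The chunk size. Default is `64`.
    - `**kwargs`: Additional keyword arguments passed, together with the
        rendering options above, to `volume_render_state_3d`.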

    **Returns**:

    - `ani`: The animation object.

    **Note:**

    - This function requires the `vape` volume renderer package.
    """
    if trj.ndim != 5:
        raise ValueError("trj must be a five-axis array.")

    fig, ax = plt.subplots()

    if include_init:
        temporal_grid = jnp.arange(trj.shape[0])
    else:
        temporal_grid = jnp.arange(1, trj.shape[0] + 1)

    if dt is not None:
        temporal_grid *= dt

    trj_wrapped = jax.vmap(wrap_bc)(trj)
    trj_wrapped_no_channel = trj_wrapped[:, 0]

    imgs = volume_render_state_3d(
        trj_wrapped_no_channel,
        vlim=vlim,
        bg_color=bg_color,
        resolution=resolution,
        cmap=cmap,
        transfer_function=transfer_function,
        distance_scale=distance_scale,
        gamma_correction=gamma_correction,
        chunk_size=chunk_size,
        **kwargs,
    )

    ax.imshow(imgs[0])
    ax.axis("off")
    ax.set_title(f"t = {temporal_grid[0]:.2f}")

    def animate(i):
        ax.clear()
        ax.imshow(imgs[i])
        ax.axis("off")
        ax.set_title(f"t = {temporal_grid[i]:.2f}")

    ani = FuncAnimation(fig, animate, frames=trj.shape[0], interval=100, blit=False)

    plt.close(fig)

    return ani
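The two preprocessing lines in animate_state_3d (jax.vmap(wrap_bc) over time, then [:, 0]) can be sketched in NumPy as below, assuming wrap_bc appends the first entry of each spatial axis to its end. Padding only the three trailing axes applies the same wrap to every frame, which is what mapping over the time axis achieves; prepare_volume_frames is a hypothetical helper.

```python
import numpy as np


def prepare_volume_frames(trj):
    """Wrap every frame periodically and drop the channel axis.

    NumPy sketch of the preprocessing in animate_state_3d, assuming the
    periodic wrap appends the first entry of each spatial axis at its end.
    """
    if trj.ndim != 5:
        raise ValueError("trj must be a five-axis array.")
    # No padding on the time and channel axes, one wrapped entry per spatial axis
    pad_width = ((0, 0), (0, 0), (0, 1), (0, 1), (0, 1))
    wrapped = np.pad(trj, pad_width, mode="wrap")  # (T, 1, N+1, N+1, N+1)
    return wrapped[:, 0]  # keep channel 0 -> (T, N+1, N+1, N+1)


trj = np.random.default_rng(0).standard_normal((3, 1, 8, 8, 8))
frames = prepare_volume_frames(trj)
print(frames.shape)  # (3, 9, 9, 9)
```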