
Spatio-Temporal Visualization

Uses imshow to visualize trajectories of 1d states and a volume render to visualize trajectories of 2d states. Trajectories of 3d states cannot be displayed.
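The choice between the two routines depends only on how many spatial axes follow the time and channel axes. The sketch below illustrates that rule. `choose_spatio_temporal_plot` is a hypothetical helper, not part of exponax. It assumes trajectories are shaped (T, C, *spatial), as the signatures below require.

```python
# Hypothetical helper (not part of exponax): pick the visualization routine
# from the shape of a trajectory laid out as (T, C, *spatial).
def choose_spatio_temporal_plot(trj_shape: tuple[int, ...]) -> str:
    num_spatial_dims = len(trj_shape) - 2  # drop the time and channel axes
    if num_spatial_dims == 1:
        return "plot_spatio_temporal"  # imshow: space vs. time
    if num_spatial_dims == 2:
        return "plot_spatio_temporal_2d"  # volume render: time as depth
    raise ValueError(
        f"cannot display a trajectory with {num_spatial_dims} spatial axes"
    )


print(choose_spatio_temporal_plot((50, 1, 64)))      # plot_spatio_temporal
print(choose_spatio_temporal_plot((50, 1, 64, 64)))  # plot_spatio_temporal_2d
```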

exponax.viz.plot_spatio_temporal

plot_spatio_temporal(
    trj: Float[Array, "T 1 N"],
    *,
    vlim: tuple[float, float] = (-1.0, 1.0),
    cmap: str = "RdBu_r",
    ax=None,
    domain_extent: float = None,
    dt: float = None,
    include_init: bool = False,
    **kwargs
)

Plot a trajectory of a 1d state as a spatio-temporal plot (space on the y-axis, time on the x-axis).

Requires the input to be a real array with three axes: a leading time axis, a channel axis, and a spatial axis. Only the first channel (index 0 of the channel axis) is plotted. See plot_spatio_temporal_facet for plotting multiple trajectories.

Periodic boundary conditions will be applied to the spatial axis (the state is wrapped around).
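Wrapping means the first grid point of the spatial axis is repeated at its end, so that an N-point state becomes N + 1 points and the plot closes over the periodic domain. The NumPy sketch below shows this. It assumes exponax's `wrap_bc` behaves like `np.pad(..., mode="wrap")` with one extra point per spatial axis, which is consistent with the "One more because we wrapped the BC" comment in the source. `wrap_periodic` is a hypothetical name.

```python
import numpy as np


def wrap_periodic(state: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for wrap_bc: append the first point of every
    spatial axis at its end. `state` is shaped (C, *spatial); the channel
    axis is left untouched."""
    num_spatial = state.ndim - 1
    pad_width = [(0, 0)] + [(0, 1)] * num_spatial
    return np.pad(state, pad_width, mode="wrap")


trj = np.arange(12, dtype=float).reshape(3, 1, 4)  # (T=3, C=1, N=4)
# Apply per time step, like jax.vmap(wrap_bc)(trj) does in the source
wrapped = np.stack([wrap_periodic(u) for u in trj])
print(wrapped.shape)  # (3, 1, 5)
print(wrapped[0, 0])  # [0. 1. 2. 3. 0.]
```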

Arguments:

  • trj: The trajectory to plot as a three-axis array. The first axis should be the time axis, the second axis the channel axis, and the third axis the spatial axis.
  • vlim: The limits of the color scale.
  • cmap: The colormap to use.
  • ax: The axis to plot on. If not provided, a new figure will be created.
  • domain_extent: The extent of the spatial domain. If not provided, the domain extent will be the number of points in the spatial axis. This adjusts the y-axis.
  • dt: The time step. This adjusts the extent of the x-axis. If not provided, the x-axis is labeled by time-step index.
  • include_init: Affects the ticks of the time axis when dt is given. If True, they start at zero. If False, they start at dt.
  • **kwargs: Additional keyword arguments passed to matplotlib's imshow.
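domain_extent, dt, and include_init together determine the extent passed to imshow. The sketch below reproduces that computation from the source as a standalone function, so the resulting axis ranges can be checked without plotting. `spatio_temporal_extent` is a hypothetical name. Following the "One more because we wrapped the BC" comment, the sketch assumes the fallback y-range is the wrapped point count N + 1.

```python
from typing import Optional


def spatio_temporal_extent(
    num_time_steps: int,
    num_points_wrapped: int,
    domain_extent: Optional[float] = None,
    dt: Optional[float] = None,
    include_init: bool = False,
) -> tuple:
    """Return (t_start, t_end, x_start, x_end) for imshow's `extent`."""
    # y-axis: the physical domain if given, else the number of wrapped points
    if domain_extent is not None:
        space_range = (0, domain_extent)
    else:
        space_range = (0, num_points_wrapped)
    # x-axis: physical time if dt is given, else time-step indices
    if dt is not None:
        t_end = dt * num_time_steps
        t_start = 0 if include_init else dt
    else:
        t_start, t_end = 0, num_time_steps - 1
    return (t_start, t_end, *space_range)


print(spatio_temporal_extent(51, 65))  # (0, 50, 0, 65)
print(spatio_temporal_extent(51, 65, domain_extent=1.0, dt=0.5))  # (0.5, 25.5, 0, 1.0)
print(spatio_temporal_extent(51, 65, dt=0.5, include_init=True))  # (0, 25.5, 0, 65)
```

include_init only shifts the left edge of the time axis. It has no effect when dt is omitted.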

Returns:

  • If ax is not provided, returns the figure. Otherwise, returns the image object.
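The plotting steps and the return convention can be illustrated with a minimal matplotlib re-implementation. It wraps the state, transposes the first channel so space lies on the y-axis, calls imshow, and returns either the Figure or the AxesImage. This is a sketch for illustration, not exponax's code. `plot_1d_trajectory_sketch` is a hypothetical name, and NumPy concatenation stands in for `wrap_bc`.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def plot_1d_trajectory_sketch(trj, *, vlim=(-1.0, 1.0), cmap="RdBu_r", ax=None):
    """Hypothetical stand-in for plot_spatio_temporal (not exponax code)."""
    if trj.ndim != 3:
        raise ValueError("trj must be a three-axis array.")
    # Periodic wrap: append the first spatial point at the end of each state
    trj_wrapped = np.concatenate([trj, trj[..., :1]], axis=-1)
    num_steps, _, num_points = trj_wrapped.shape
    return_fig = ax is None
    if return_fig:
        fig, ax = plt.subplots()
    im = ax.imshow(
        trj_wrapped[:, 0, :].T,  # first channel only; space on the y-axis
        vmin=vlim[0],
        vmax=vlim[1],
        cmap=cmap,
        origin="lower",
        aspect="auto",
        extent=(0, num_steps - 1, 0, num_points),
    )
    ax.set_xlabel("Time")
    ax.set_ylabel("Space")
    if return_fig:
        plt.close(fig)  # detach from pyplot; the caller receives the Figure
        return fig
    return im


# Travelling sine wave: T=40 steps, 1 channel, N=64 points
x = np.linspace(0, 2 * np.pi, 64, endpoint=False)
trj = np.stack([np.sin(x - 0.1 * t) for t in range(40)])[:, None, :]
fig = plot_1d_trajectory_sketch(trj)        # no ax given: returns the Figure
_, ax = plt.subplots()
im = plot_1d_trajectory_sketch(trj, ax=ax)  # ax given: returns the AxesImage
```

Returning the image object when an ax is passed makes it easy to attach a shared colorbar, for example with fig.colorbar(im, ax=ax).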
Source code in exponax/viz/_plot.py
def plot_spatio_temporal(
    trj: Float[Array, "T 1 N"],
    *,
    vlim: tuple[float, float] = (-1.0, 1.0),
    cmap: str = "RdBu_r",
    ax=None,
    domain_extent: float = None,
    dt: float = None,
    include_init: bool = False,
    **kwargs,
):
    """
    Plot a trajectory of a 1d state as a spatio-temporal plot (space in y-axis,
    and time in x-axis).

    Requires the input to be a real array with three axis: a leading time axis,
    a channel axis, and a spatial axis. Only the leading dimension in the
    channel axis will be plotted. See `plot_spatio_temporal_facet` for plotting
    multiple trajectories.

    Periodic boundary conditions will be applied to the spatial axis (the state
    is wrapped around).

    **Arguments:**

    - `trj`: The trajectory to plot as a three-axis array. The first axis should
        be the time axis, the second axis the channel axis, and the third axis
        the spatial axis.
    - `vlim`: The limits of the color scale.
    - `cmap`: The colormap to use.
    - `ax`: The axis to plot on. If not provided, a new figure will be created.
    - `domain_extent`: The extent of the spatial domain. If not provided, the
        domain extent will be the number of points in the spatial axis. This
        adjusts the y-axis.
    - `dt`: The time step. This adjusts the extent of the x-axis. If not
        provided, the time axis will be the number of time steps.
    - `include_init`: Will affect the ticks of the time axis. If `True`, they
        will start at zero. If `False`, they will start at the time step.
    - `**kwargs`: Additional arguments to pass to the imshow function.

    **Returns:**

    - If `ax` is not provided, returns the figure. Otherwise, returns the image
        object.
    """
    if trj.ndim != 3:
        raise ValueError("trj must be a three-axis array.")

    trj_wrapped = jax.vmap(wrap_bc)(trj)

    if domain_extent is not None:
        space_range = (0, domain_extent)
    else:
        # One more because we wrapped the BC; use the (wrapped) spatial axis,
        # not the channel axis at index 1
        space_range = (0, trj_wrapped.shape[-1])

    if dt is not None:
        time_range = (0, dt * trj_wrapped.shape[0])
        if not include_init:
            time_range = (dt, time_range[1])
    else:
        time_range = (0, trj_wrapped.shape[0] - 1)

    if ax is None:
        fig, ax = plt.subplots()
        return_all = True
    else:
        return_all = False

    im = ax.imshow(
        trj_wrapped[:, 0, :].T,
        vmin=vlim[0],
        vmax=vlim[1],
        cmap=cmap,
        origin="lower",
        aspect="auto",
        extent=(*time_range, *space_range),
        **kwargs,
    )
    ax.set_xlabel("Time")
    ax.set_ylabel("Space")

    if return_all:
        plt.close(fig)
        return fig
    else:
        return im

exponax.viz.plot_spatio_temporal_2d

plot_spatio_temporal_2d(
    trj: Float[Array, "T 1 N N"],
    *,
    vlim: tuple[float, float] = (-1.0, 1.0),
    ax=None,
    domain_extent: float = None,
    dt: float = None,
    include_init: bool = False,
    bg_color: Union[
        Literal["black"],
        Literal["white"],
        tuple[jnp.int8, jnp.int8, jnp.int8, jnp.int8],
    ] = "white",
    resolution: int = 384,
    cmap: str = "RdBu_r",
    transfer_function: callable = zigzag_alpha,
    distance_scale: float = 10.0,
    gamma_correction: float = 2.4,
    **kwargs
)

Plot a trajectory of a 2d state as a spatio-temporal plot visualized by a volume render (space in the plane parallel to the screen, and time in the direction into the screen).

Requires the input to be a real array with four axes: a leading time axis, a channel axis, and two subsequent spatial axes. Only the first channel (index 0 of the channel axis) is plotted. See plot_spatio_temporal_facet for plotting multiple trajectories (e.g., for problems consisting of multiple channels, such as Burgers simulations).

Periodic boundary conditions will be applied to the spatial axes (the state is wrapped around).
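Before rendering, the source selects the first channel, wraps both spatial axes, moves time to the last axis, and flips it, which yields a 3d volume with time as depth. The NumPy sketch below reproduces those array operations so the resulting shape can be inspected. It uses NumPy instead of JAX and assumes `wrap_bc` is equivalent to `np.pad(..., mode="wrap")` with one extra point per spatial axis. Which end of the flipped time axis appears nearest the viewer depends on the renderer and is not asserted here.

```python
import numpy as np

T, N = 6, 8
rng = np.random.default_rng(0)
trj = rng.standard_normal((T, 3, N, N))  # (T, C, N, N); only channel 0 is used

one_channel = trj[:, 0:1]  # (T, 1, N, N), keeps the channel axis
# Periodic wrap of both spatial axes: (T, 1, N+1, N+1)
wrapped = np.pad(one_channel, [(0, 0), (0, 0), (0, 1), (0, 1)], mode="wrap")
# Move time to the last axis, then reverse it: (1, N+1, N+1, T)
volume = np.flip(wrapped.transpose(1, 2, 3, 0), axis=3)

print(volume.shape)  # (1, 9, 9, 6)
```

After the flip, index 0 along the time (depth) axis holds the last time step of the trajectory.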

Arguments:

  • trj: The trajectory to plot as a four-axis array. The first axis should be the time axis, the second axis the channel axis, and the third and fourth axes the spatial axes.
  • vlim: The limits of the color scale.
  • ax: The axis to plot on. If not provided, a new figure will be created.
  • domain_extent: (Unused as of now)
  • dt: (Unused as of now)
  • include_init: (Unused as of now)
  • bg_color: The background color. Either "black", "white", or a tuple of RGBA values.
  • resolution: The resolution of the output image (affects render time).
  • cmap: The colormap to use.
  • transfer_function: The transfer function to use, i.e., how values within the vlim range are mapped to alpha values.
  • distance_scale: The distance scale of the volume renderer.
  • gamma_correction: The gamma correction to apply to the image.
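A transfer function maps values inside vlim to opacity. The usual approach is to normalize the values to [0, 1] using vlim and then apply an alpha curve. The sketch below shows one such curve, a triangle ("zigzag") wave that alternates between transparent and semi-opaque bands. It is hypothetical. It is not the library's zigzag_alpha, and the call signature that vape expects for transfer_function is not documented here. `normalize_to_unit`, `zigzag_like_alpha`, `num_teeth`, and `max_alpha` are illustrative names.

```python
import numpy as np


def normalize_to_unit(values, vlim=(-1.0, 1.0)):
    """Map values in [vmin, vmax] to [0, 1], clipping anything outside."""
    vmin, vmax = vlim
    return np.clip((values - vmin) / (vmax - vmin), 0.0, 1.0)


def zigzag_like_alpha(u, num_teeth=4, max_alpha=0.5):
    """Hypothetical transfer function: triangle wave over normalized values.

    Alpha is 0 at the edges of each tooth and max_alpha at its center, so
    the rendered volume shows alternating transparent and opaque bands.
    """
    phase = (u * num_teeth) % 1.0  # position within the current tooth
    return max_alpha * (1.0 - np.abs(2.0 * phase - 1.0))


values = np.linspace(-1.0, 1.0, 9)
alpha = zigzag_like_alpha(normalize_to_unit(values))
```

With the default vlim, the midpoint value 0.0 lands on a tooth edge and is fully transparent. Shifting or narrowing vlim therefore changes which value ranges become visible.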

Returns:

  • If ax is not provided, returns the figure. Otherwise, returns the image object.

Note:

  • This function requires the vape volume renderer package.
Source code in exponax/viz/_plot.py
def plot_spatio_temporal_2d(
    trj: Float[Array, "T 1 N N"],
    *,
    vlim: tuple[float, float] = (-1.0, 1.0),
    ax=None,
    domain_extent: float = None,
    dt: float = None,
    include_init: bool = False,
    bg_color: Union[
        Literal["black"],
        Literal["white"],
        tuple[jnp.int8, jnp.int8, jnp.int8, jnp.int8],
    ] = "white",
    resolution: int = 384,
    cmap: str = "RdBu_r",
    transfer_function: callable = zigzag_alpha,
    distance_scale: float = 10.0,
    gamma_correction: float = 2.4,
    **kwargs,
):
    """
    Plot a trajectory of a 2d state as a spatio-temporal plot visualized by a
    volume render (space in the plane parallel to the screen, and time in the
    direction into the screen).

    Requires the input to be a real array with four axes: a leading time axis, a
    channel axis, and two subsequent spatial axes. Only the leading dimension in
    the channel axis will be plotted. See `plot_spatio_temporal_facet` for
    plotting multiple trajectories (e.g. for problems consisting of multiple
    channels like Burgers simulations).

    Periodic boundary conditions will be applied to the spatial axes (the state
    is wrapped around).

    **Arguments:**

    - `trj`: The trajectory to plot as a four-axis array. The first axis should
        be the time axis, the second axis the channel axis, and the third and
        fourth axes the spatial axes.
    - `vlim`: The limits of the color scale.
    - `ax`: The axis to plot on. If not provided, a new figure will be created.
    - `domain_extent`: (Unused as of now)
    - `dt`: (Unused as of now)
    - `include_init`: (Unused as of now)
    - `bg_color`: The background color. Either `"black"`, `"white"`, or a tuple
        of RGBA values.
    - `resolution`: The resolution of the output image (affects render time).
    - `cmap`: The colormap to use.
    - `transfer_function`: The transfer function to use, i.e., how values within
        the `vlim` range are mapped to alpha values.
    - `distance_scale`: The distance scale of the volume renderer.
    - `gamma_correction`: The gamma correction to apply to the image.

    **Returns:**

    - If `ax` is not provided, returns the figure. Otherwise, returns the image
        object.

    **Note:**

    - This function requires the `vape` volume renderer package.
    """
    if trj.ndim != 4:
        raise ValueError("trj must be a four-axis array.")

    trj_one_channel = trj[:, 0:1]
    trj_one_channel_wrapped = jax.vmap(wrap_bc)(trj_one_channel)

    trj_reshaped_to_3d = jnp.flip(
        jnp.array(trj_one_channel_wrapped.transpose(1, 2, 3, 0)), 3
    )

    imgs = volume_render_state_3d(
        trj_reshaped_to_3d,
        vlim=vlim,
        bg_color=bg_color,
        resolution=resolution,
        cmap=cmap,
        transfer_function=transfer_function,
        distance_scale=distance_scale,
        gamma_correction=gamma_correction,
        **kwargs,
    )

    img = imgs[0]

    if ax is None:
        fig, ax = plt.subplots()
        return_all = True
    else:
        return_all = False

    im = ax.imshow(img)
    ax.axis("off")

    if return_all:
        plt.close(fig)
        return fig
    else:
        return im