
Spatio-Temporal Plots in Facet Grid

exponax.viz.plot_spatio_temporal_facet

plot_spatio_temporal_facet(
    trjs: Union[
        Float[Array, "T C N"], Float[Array, "B T 1 N"]
    ],
    *,
    facet_over_channels: bool = True,
    vlim: tuple[float, float] = (-1.0, 1.0),
    cmap: str = "RdBu_r",
    grid: tuple[int, int] = (3, 3),
    figsize: tuple[float, float] = (10, 10),
    titles: list[str] = None,
    domain_extent: float = None,
    dt: float = None,
    include_init: bool = False,
    **kwargs
)

Plot a facet of spatio-temporal trajectories.

Requires the input to be a real array with either three or four axes:

  • Three axes: a leading time axis, a channel axis, and a spatial axis. The faceting is performed over the channel axis. Requires the facet_over_channels argument to be True (default).
  • Four axes: a leading batch axis, a time axis, a channel axis, and a spatial axis. The faceting is performed over the batch axis. Requires the facet_over_channels argument to be False. Only the zeroth channel for each trajectory will be plotted.
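In the three-axis case, the source swaps the time and channel axes and inserts a singleton channel axis, so each channel becomes its own single-channel trajectory and both cases end up in a (B, T, 1, N) layout. The sketch below reproduces that rearrangement in NumPy; exponax itself uses jax.numpy, whose swapaxes and None-indexing behave the same way. The helper name to_facet_layout is hypothetical and not part of exponax.

```python
import numpy as np


def to_facet_layout(trjs: np.ndarray, facet_over_channels: bool = True) -> np.ndarray:
    """Return trajectories in the (B, T, 1, N) layout used for faceting.

    Hypothetical helper mirroring the shape handling of
    plot_spatio_temporal_facet; it is not part of exponax.
    """
    if facet_over_channels:
        if trjs.ndim != 3:
            raise ValueError("trjs must be a three-axis array (T, C, N).")
        # (T, C, N) -> (C, T, N): each channel becomes one panel.
        trjs = np.swapaxes(trjs, 0, 1)
        # (C, T, N) -> (C, T, 1, N): add a singleton channel axis.
        return trjs[:, :, None, :]
    if trjs.ndim != 4:
        raise ValueError("trjs must be a four-axis array (B, T, C, N).")
    return trjs


# Three channels, 50 time steps, 64 spatial points.
trj = np.random.default_rng(0).standard_normal((50, 3, 64))
print(to_facet_layout(trj).shape)  # (3, 50, 1, 64)
```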

Periodic boundary conditions will be applied to the spatial axis (the state is wrapped around).
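One common way to realize this wrap-around, assumed here rather than taken from the exponax source, is to append the first spatial sample at the end, so a state on N points is drawn on N + 1 points and the two edges of the image agree. The helper wrap_periodic is hypothetical.

```python
import numpy as np


def wrap_periodic(u: np.ndarray, axis: int = -1) -> np.ndarray:
    """Append the first sample along `axis` to its end.

    Assumed illustration of periodic wrapping; not the exponax implementation.
    """
    first = np.take(u, [0], axis=axis)  # list index keeps the axis (length 1)
    return np.concatenate([u, first], axis=axis)


u = np.array([[0.0, 1.0, 2.0, 3.0]])  # one time step, N = 4
print(wrap_periodic(u))  # [[0. 1. 2. 3. 0.]]
```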

Arguments:

  • trjs: The trajectories to plot as a three or four axis array. See above for the requirements.
  • facet_over_channels: Whether to facet over the channel axis (three axes) or the batch axis (four axes).
  • vlim: The limits of the color scale.
  • cmap: The colormap to use.
  • grid: The grid layout of the facet plot as a tuple of two integers (rows, columns). If there are fewer trajectories than grid cells, the unused axes are removed; if there are more, the extra trajectories are not plotted.
  • figsize: The size of the figure.
  • titles: The titles for each plot. This should be a list of strings with the same length as the number of trajectories.
  • domain_extent: The extent of the spatial domain. If not provided, the domain extent will be the number of points in the spatial axis. This adjusts the y-axis.
  • dt: The time step. This adjusts the extent of the x-axis. If not provided, the time axis is measured in time steps.
  • include_init: Affects the ticks of the time axis. If True, they start at zero. If False, they start at the time step dt.
  • **kwargs: Additional keyword arguments passed on to the imshow function.
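The descriptions of domain_extent, dt, and include_init imply simple arithmetic for the axis ranges. The hypothetical helper below spells out one reading of them: the time axis holds one snapshot per stored time step, spaced dt apart, and the spatial axis runs from 0 to domain_extent. It is an illustration only; the exact extent exponax passes to imshow may differ.

```python
def axis_ranges(num_steps, num_points, *, dt=None, domain_extent=None, include_init=False):
    """Hypothetical helper: ((first time, last time), (space min, space max)).

    Falls back to step and point counts when dt or domain_extent is None,
    following the documented defaults.
    """
    if dt is None:
        dt = 1.0  # time axis measured in time steps
    if domain_extent is None:
        domain_extent = float(num_points)  # space measured in grid points
    t_first = 0.0 if include_init else dt
    t_last = t_first + (num_steps - 1) * dt
    return (t_first, t_last), (0.0, domain_extent)


print(axis_ranges(11, 64, dt=0.5, domain_extent=1.0, include_init=True))
# ((0.0, 5.0), (0.0, 1.0))
```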

Returns:

  • The figure.
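Because the function calls plt.close(fig) before returning, pyplot will not display the figure automatically (for example, at the end of a notebook cell); the returned Figure object can still be saved or displayed explicitly. The sketch below demonstrates this with a stand-in figure rather than an actual call into exponax.

```python
import io

import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

# Stand-in for the figure returned by plot_spatio_temporal_facet.
fig, ax = plt.subplots()
ax.imshow(np.zeros((8, 16)), cmap="RdBu_r", vmin=-1.0, vmax=1.0)
plt.close(fig)  # the same call the facet function makes before returning

# A closed figure can still be rendered to a file or an in-memory buffer.
buf = io.BytesIO()
fig.savefig(buf, format="png")
png_bytes = buf.getvalue()
print(png_bytes[:8] == b"\x89PNG\r\n\x1a\n")  # True
```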
Source code in exponax/viz/_plot_facet.py
def plot_spatio_temporal_facet(
    trjs: Union[Float[Array, "T C N"], Float[Array, "B T 1 N"]],
    *,
    facet_over_channels: bool = True,
    vlim: tuple[float, float] = (-1.0, 1.0),
    cmap: str = "RdBu_r",
    grid: tuple[int, int] = (3, 3),
    figsize: tuple[float, float] = (10, 10),
    titles: list[str] = None,
    domain_extent: float = None,
    dt: float = None,
    include_init: bool = False,
    **kwargs,
):
    """
    Plot a facet of spatio-temporal trajectories.

    Requires the input to be a real array with either three or four axes:

    * Three axes: a leading time axis, a channel axis, and a spatial axis. The
        faceting is performed over the channel axis. Requires the
        `facet_over_channels` argument to be `True` (default).
    * Four axes: a leading batch axis, a time axis, a channel axis, and a
        spatial axis. The faceting is performed over the batch axis. Requires
        the `facet_over_channels` argument to be `False`. Only the zeroth
        channel for each trajectory will be plotted.

    Periodic boundary conditions will be applied to the spatial axis (the state
    is wrapped around).

    **Arguments:**

    - `trjs`: The trajectories to plot as a three or four axis array. See above
        for the requirements.
    - `facet_over_channels`: Whether to facet over the channel axis (three axes)
        or the batch axis (four axes).
    - `vlim`: The limits of the color scale.
    - `cmap`: The colormap to use.
    - `grid`: The grid layout of the facet plot as a tuple of two integers
        (rows, columns). If there are fewer trajectories than grid cells, the
        unused axes are removed; if there are more, the extra trajectories are
        not plotted.
    - `figsize`: The size of the figure.
    - `titles`: The titles for each plot. This should be a list of strings with
        the same length as the number of trajectories.
    - `domain_extent`: The extent of the spatial domain. If not provided, the
        domain extent will be the number of points in the spatial axis. This
        adjusts the y-axis.
    - `dt`: The time step. This adjusts the extent of the x-axis. If not
        provided, the time axis is measured in time steps.
    - `include_init`: Affects the ticks of the time axis. If `True`, they
        start at zero. If `False`, they start at the time step `dt`.
    - `**kwargs`: Additional keyword arguments passed on to the `imshow`
        function.

    **Returns:**

    - The figure.
    """
    if facet_over_channels:
        if trjs.ndim != 3:
            raise ValueError("trjs must be a three-axis array.")
    else:
        if trjs.ndim != 4:
            raise ValueError("trjs must be a four-axis array.")

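    # squeeze=False keeps ax_s a 2D array, so .flatten() also works for a
    # (1, 1) grid.
    fig, ax_s = plt.subplots(
        *grid, sharex=True, sharey=True, figsize=figsize, squeeze=False
    )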

    if facet_over_channels:
        trjs = jnp.swapaxes(trjs, 0, 1)
        trjs = trjs[:, :, None, :]

    num_subplots = trjs.shape[0]

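    for i, ax in enumerate(ax_s.flatten()):
        if i >= num_subplots:
            # More grid cells than trajectories: drop the unused axes
            # without indexing past the end of `trjs`.
            ax.remove()
            continue
        single_trj = trjs[i]
        plot_spatio_temporal(
            single_trj,
            vlim=vlim,
            cmap=cmap,
            ax=ax,
            domain_extent=domain_extent,
            dt=dt,
            include_init=include_init,
            **kwargs,
        )
        if titles is not None:
            ax.set_title(titles[i])

    # Detach the figure from pyplot so it is not displayed automatically;
    # the caller receives it as the return value.
    plt.close(fig)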

    return fig
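The faceting loop above can be reproduced with plain matplotlib. This sketch substitutes a bare imshow call for exponax's plot_spatio_temporal (so it applies no periodic wrapping or axis scaling), passes squeeze=False so that even a (1, 1) grid yields a 2D array of axes, and removes the panels left over when there are fewer trajectories than grid cells. As in the loop above, trajectories beyond rows times columns are not drawn. The name facet_imshow is hypothetical.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import numpy as np


def facet_imshow(trjs, grid=(2, 2), titles=None, vlim=(-1.0, 1.0), cmap="RdBu_r"):
    """Hypothetical stand-in: one imshow panel per (T, 1, N) trajectory."""
    fig, ax_s = plt.subplots(*grid, sharex=True, sharey=True, squeeze=False)
    num_subplots = trjs.shape[0]
    for i, ax in enumerate(ax_s.flatten()):
        if i >= num_subplots:
            ax.remove()  # no trajectory left for this grid cell
            continue
        # Channel 0 of trajectory i, transposed so time runs along x.
        ax.imshow(
            trjs[i, :, 0, :].T,
            vmin=vlim[0],
            vmax=vlim[1],
            cmap=cmap,
            origin="lower",
            aspect="auto",
        )
        if titles is not None:
            ax.set_title(titles[i])
    plt.close(fig)
    return fig


trjs = np.random.default_rng(0).uniform(-1.0, 1.0, size=(3, 20, 1, 32))
fig = facet_imshow(trjs, grid=(2, 2), titles=["a", "b", "c"])
print(len(fig.axes))  # 3
```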

exponax.viz.plot_spatio_temporal_2d_facet

plot_spatio_temporal_2d_facet(
    trjs: Union[
        Float[Array, "T C N N"], Float[Array, "B T 1 N N"]
    ],
    *,
    facet_over_channels: bool = True,
    vlim: tuple[float, float] = (-1.0, 1.0),
    grid: tuple[int, int] = (3, 3),
    figsize: tuple[float, float] = (10, 10),
    titles: list[str] = None,
    domain_extent: float = None,
    dt: float = None,
    include_init: bool = False,
    bg_color: Union[
        Literal["black"],
        Literal["white"],
        tuple[jnp.int8, jnp.int8, jnp.int8, jnp.int8],
    ] = "white",
    resolution: int = 384,
    cmap: str = "RdBu_r",
    transfer_function: callable = zigzag_alpha,
    distance_scale: float = 10.0,
    gamma_correction: float = 2.4,
    **kwargs
)

Plot a facet of spatio-temporal trajectories of two-dimensional states, each rendered as a volume.

Requires the input to be a real array with either four or five axes:

  • Four axes: a leading time axis, a channel axis, and two subsequent spatial axes. The faceting is performed over the channel axis. Requires the facet_over_channels argument to be True (default).
  • Five axes: a leading batch axis, a time axis, a channel axis, and two subsequent spatial axes. The faceting is performed over the batch axis. Requires the facet_over_channels argument to be False.

Periodic boundary conditions will be applied to the spatial axes (the state is wrapped around).
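The same rearrangement extends to the two-dimensional case: the channel axis is swapped to the front and a singleton channel axis is inserted, giving a (C, T, 1, N, N) array. Periodic wrapping then applies to both spatial axes. As for the one-dimensional function, the append-the-first-sample form of wrapping is an assumption for illustration, and both helper names are hypothetical rather than exponax code.

```python
import numpy as np


def to_facet_layout_2d(trjs: np.ndarray) -> np.ndarray:
    """Hypothetical helper: (T, C, N, N) -> (C, T, 1, N, N)."""
    if trjs.ndim != 4:
        raise ValueError("trjs must be a four-axis array (T, C, N, N).")
    return np.swapaxes(trjs, 0, 1)[:, :, None, :, :]


def wrap_periodic_2d(u: np.ndarray) -> np.ndarray:
    """Append the first row and first column of the last two axes (assumed wrapping)."""
    u = np.concatenate([u, u[..., :1, :]], axis=-2)  # wrap first spatial axis
    return np.concatenate([u, u[..., :1]], axis=-1)  # wrap second spatial axis


trj = np.zeros((10, 2, 16, 16))
faceted = to_facet_layout_2d(trj)
print(faceted.shape)  # (2, 10, 1, 16, 16)
print(wrap_periodic_2d(faceted).shape)  # (2, 10, 1, 17, 17)
```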

Arguments:

  • trjs: The trajectories to plot as a four or five axis array. See above for the requirements.
  • facet_over_channels: Whether to facet over the channel axis (four axes) or the batch axis (five axes).
  • vlim: The limits of the color scale.
  • grid: The grid layout of the facet plot as a tuple of two integers (rows, columns). If there are fewer trajectories than grid cells, the unused axes are removed; if there are more, the extra trajectories are not plotted.
  • figsize: The size of the figure.
  • titles: The titles for each plot. This should be a list of strings with the same length as the number of trajectories.
  • domain_extent: (Unused as of now)
  • dt: (Unused as of now)
  • include_init: (Unused as of now)
  • bg_color: The background color. Either "black", "white", or a tuple of RGBA values.
  • resolution: The resolution of the output image (affects render time).
  • cmap: The colormap to use.
  • transfer_function: The transfer function to use, i.e., how values within the vlim range are mapped to alpha values.
  • distance_scale: The distance scale of the volume renderer.
  • gamma_correction: The gamma correction to apply to the image.
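To make transfer_function and gamma_correction concrete, the sketch below uses a hypothetical V-shaped transfer function that makes values near either end of vlim opaque and values near the middle transparent. It is not exponax's zigzag_alpha, whose exact shape is not documented here. It is followed by a conventional power-law gamma correction, assumed to take the form c ** (1 / gamma); the renderer's actual formulas may differ.

```python
import numpy as np


def v_shaped_alpha(values, vlim=(-1.0, 1.0)):
    """Hypothetical transfer function: alpha grows with distance from the vlim midpoint."""
    vmin, vmax = vlim
    # Normalize to [0, 1], clipping values outside vlim.
    x = np.clip((np.asarray(values, dtype=float) - vmin) / (vmax - vmin), 0.0, 1.0)
    return np.abs(2.0 * x - 1.0)  # 1 at vmin and vmax, 0 at the midpoint


def gamma_correct(rgb, gamma=2.4):
    """Power-law gamma correction for intensities in [0, 1] (assumed form)."""
    return np.clip(rgb, 0.0, 1.0) ** (1.0 / gamma)


alpha = v_shaped_alpha([-1.0, 0.0, 0.5, 1.0])
print(gamma_correct(np.array([0.0, 1.0])))  # [0. 1.]
```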

Returns:

  • The figure.

Note:

  • This function requires the vape volume renderer package.
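Since vape is an optional dependency, a caller may want to check for it before requesting a volume-rendered facet. The helper below is a generic sketch built on importlib.util.find_spec; the importable module name "vape" is assumed from the note above and may differ from the name under which the package is installed.

```python
import importlib.util


def has_module(name: str) -> bool:
    """Return True if a top-level module with this name can be imported."""
    return importlib.util.find_spec(name) is not None


if not has_module("vape"):  # assumed module name for the vape renderer
    print("vape not found; 2D spatio-temporal facets are unavailable.")
```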
Source code in exponax/viz/_plot_facet.py
def plot_spatio_temporal_2d_facet(
    trjs: Union[Float[Array, "T C N N"], Float[Array, "B T 1 N N"]],
    *,
    facet_over_channels: bool = True,
    vlim: tuple[float, float] = (-1.0, 1.0),
    grid: tuple[int, int] = (3, 3),
    figsize: tuple[float, float] = (10, 10),
    titles: list[str] = None,
    domain_extent: float = None,
    dt: float = None,
    include_init: bool = False,
    bg_color: Union[
        Literal["black"],
        Literal["white"],
        tuple[jnp.int8, jnp.int8, jnp.int8, jnp.int8],
    ] = "white",
    resolution: int = 384,
    cmap: str = "RdBu_r",
    transfer_function: callable = zigzag_alpha,
    distance_scale: float = 10.0,
    gamma_correction: float = 2.4,
    **kwargs,
):
    """
    Plot a facet of spatio-temporal trajectories of two-dimensional states,
    each rendered as a volume.

    Requires the input to be a real array with either four or five axes:

    * Four axes: a leading time axis, a channel axis, and two subsequent spatial
        axes. The faceting is performed over the channel axis. Requires the
        `facet_over_channels` argument to be `True` (default).
    * Five axes: a leading batch axis, a time axis, a channel axis, and two
        subsequent spatial axes. The faceting is performed over the batch axis.
        Requires the `facet_over_channels` argument to be `False`.

    Periodic boundary conditions will be applied to the spatial axes (the state
    is wrapped around).

    **Arguments:**

    - `trjs`: The trajectories to plot as a four or five axis array. See above
        for the requirements.
    - `facet_over_channels`: Whether to facet over the channel axis (four axes)
        or the batch axis (five axes).
    - `vlim`: The limits of the color scale.
    - `grid`: The grid layout of the facet plot as a tuple of two integers
        (rows, columns). If there are fewer trajectories than grid cells, the
        unused axes are removed; if there are more, the extra trajectories are
        not plotted.
    - `figsize`: The size of the figure.
    - `titles`: The titles for each plot. This should be a list of strings with
        the same length as the number of trajectories.
    - `domain_extent`: (Unused as of now)
    - `dt`: (Unused as of now)
    - `include_init`: (Unused as of now)
    - `bg_color`: The background color. Either `"black"`, `"white"`, or a tuple
        of RGBA values.
    - `resolution`: The resolution of the output image (affects render time).
    - `cmap`: The colormap to use.
    - `transfer_function`: The transfer function to use, i.e., how values within
        the `vlim` range are mapped to alpha values.
    - `distance_scale`: The distance scale of the volume renderer.
    - `gamma_correction`: The gamma correction to apply to the image.

    **Returns:**

    - The figure.

    **Note:**

    - This function requires the `vape` volume renderer package.
    """
    if facet_over_channels:
        if trjs.ndim != 4:
            raise ValueError("trjs must be a four-axis array.")
    else:
        if trjs.ndim != 5:
            raise ValueError("trjs must be a five-axis array.")

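    # squeeze=False keeps ax_s a 2D array, so .flatten() also works for a
    # (1, 1) grid.
    fig, ax_s = plt.subplots(*grid, figsize=figsize, squeeze=False)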

    if facet_over_channels:
        trjs = jnp.swapaxes(trjs, 0, 1)
        trjs = trjs[:, :, None, :, :]

    num_subplots = trjs.shape[0]

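    for i, ax in enumerate(ax_s.flatten()):
        if i >= num_subplots:
            # More grid cells than trajectories: drop the unused axes
            # without indexing past the end of `trjs` (and without paying
            # for a volume render that would be discarded).
            ax.remove()
            continue
        single_trj = trjs[i]
        plot_spatio_temporal_2d(
            single_trj,
            vlim=vlim,
            ax=ax,
            domain_extent=domain_extent,
            dt=dt,
            include_init=include_init,
            bg_color=bg_color,
            resolution=resolution,
            cmap=cmap,
            transfer_function=transfer_function,
            distance_scale=distance_scale,
            gamma_correction=gamma_correction,
            **kwargs,
        )
        if titles is not None:
            ax.set_title(titles[i])

    # Detach the figure from pyplot so it is not displayed automatically;
    # the caller receives it as the return value.
    plt.close(fig)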

    return fig