
Plot states in facet grid

These functions display multiple states (or channels) at the same time, one per subplot of a grid.

exponax.viz.plot_state_1d_facet

plot_state_1d_facet(
    states: Float[Array, "B C N"],
    *,
    vlim: tuple[float, float] = (-1.0, 1.0),
    labels: list[str] = None,
    titles: list[str] = None,
    domain_extent: float = None,
    grid: tuple[int, int] = (3, 3),
    figsize: tuple[float, float] = (10, 10),
    **kwargs
)

Plot a facet of 1d states.

Requires the input to be a real array with three axes: a leading batch axis, a channel axis, and a spatial axis. Each entry along the batch axis gets its own subplot; the channels of a state are plotted in different colors within that subplot.
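As a sketch of the expected input layout, assuming NumPy arrays (jax.numpy.stack works the same way for JAX arrays), several single-channel snapshots can be stacked into the B × C × N shape. The sine snapshots below are synthetic placeholders, not exponax output.

```python
import numpy as np

N = 64  # number of spatial points (illustrative)
x = np.linspace(0.0, 2.0 * np.pi, N, endpoint=False)

# Four synthetic snapshots, each a single-channel state of shape (C=1, N)
snapshots = [np.sin(x + shift)[None, :] for shift in (0.0, 0.5, 1.0, 1.5)]

# Stack along a new leading batch axis -> shape (B=4, C=1, N=64)
states = np.stack(snapshots, axis=0)

# One title per batch entry, i.e. per subplot
titles = [f"snapshot {i}" for i in range(states.shape[0])]
```

The result could then be passed as plot_state_1d_facet(states, titles=titles, grid=(2, 2)), giving one snapshot per subplot.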

Arguments:

  • states: The states to plot as a three-axis array. If there is more than one channel, each channel is plotted in a different color. Use the labels argument to provide a legend and the titles argument to provide a title for each plot.
  • vlim: The limits of the y-axis.
  • labels: The labels for the legend. This should be a list of strings with the same length as the number of channels.
  • titles: The titles for each plot. This should be a list of strings with the same length as the number of states.
  • domain_extent: The extent of the spatial domain. If not provided, the domain extent will be the number of points in the spatial axis. This adjusts the x-axis.
  • grid: The grid layout for the facet plot, as a tuple of two integers (rows, columns). If there are fewer states than grid cells, the unused axes are removed; states beyond the number of grid cells are not plotted.
  • figsize: The size of the figure.
  • **kwargs: Additional keyword arguments forwarded to plot_state_1d.
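In the source code of this function, the plotting loop visits only the grid cells, and titles[i] is looked up for every plotted state. A grid that is too small therefore drops states without an error, and a titles list that is too short raises an IndexError. The two helpers below are hypothetical (not part of exponax) and sketch one way to choose a grid that fits and to check the arguments up front.

```python
import math


def fit_grid(num_states: int, ncols: int = 3) -> tuple[int, int]:
    """Hypothetical helper: smallest (rows, cols) grid with at most `ncols`
    columns that has a cell for every state."""
    if num_states < 1:
        raise ValueError("need at least one state")
    cols = min(ncols, num_states)
    rows = math.ceil(num_states / cols)
    return rows, cols


def check_facet_args(num_states: int, grid: tuple[int, int], titles=None) -> None:
    """Hypothetical pre-flight check mirroring how the facet loop indexes its inputs."""
    capacity = grid[0] * grid[1]
    if num_states > capacity:
        # The loop only visits grid cells, so extra states would be dropped silently
        raise ValueError(f"{num_states} states do not fit a {grid} grid")
    if titles is not None and len(titles) < num_states:
        # titles[i] is looked up for every plotted state
        raise ValueError("need one title per state")
```

For example, grid=fit_grid(states.shape[0]) always leaves room for every batch entry.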

Returns:

  • The figure.
Source code in exponax/viz/_plot_facet.py
def plot_state_1d_facet(
    states: Float[Array, "B C N"],
    *,
    vlim: tuple[float, float] = (-1.0, 1.0),
    labels: list[str] = None,
    titles: list[str] = None,
    domain_extent: float = None,
    grid: tuple[int, int] = (3, 3),
    figsize: tuple[float, float] = (10, 10),
    **kwargs,
):
    """
    Plot a facet of 1d states.

    Requires the input to be a real array with three axes: a leading batch axis,
    a channel axis, and a spatial axis. Each entry along the batch axis gets its
    own subplot; the channels of a state are plotted in different colors within
    that subplot.

    **Arguments:**

    - `states`: The states to plot as a three-axis array. If there is more than
        one channel, each channel is plotted in a different color. Use the
        `labels` argument to provide a legend and the `titles` argument to
        provide a title for each plot.
    - `vlim`: The limits of the y-axis.
    - `labels`: The labels for the legend. This should be a list of strings with
        the same length as the number of channels.
    - `titles`: The titles for each plot. This should be a list of strings with
        the same length as the number of states.
    - `domain_extent`: The extent of the spatial domain. If not provided, the
        domain extent will be the number of points in the spatial axis. This
        adjusts the x-axis.
    - `grid`: The grid layout for the facet plot. This should be a tuple with
        two integers. If the number of states is less than the product of the
        grid, the remaining axes will be removed.
    - `figsize`: The size of the figure.
    - `**kwargs`: Additional keyword arguments forwarded to `plot_state_1d`.

    **Returns:**

    - The figure.
    """
    if states.ndim != 3:
        raise ValueError("states must be a three-axis array.")

    fig, ax_s = plt.subplots(*grid, figsize=figsize)

    num_batches = states.shape[0]

    if grid[0] * grid[1] == 1:
        ax_s = np.array([[ax_s]])
    for i, ax in enumerate(ax_s.flatten()):
        if i < num_batches:
            plot_state_1d(
                states[i],
                vlim=vlim,
                domain_extent=domain_extent,
                labels=labels,
                ax=ax,
                **kwargs,
            )
            if titles is not None:
                ax.set_title(titles[i])
        else:
            ax.remove()

    plt.close(fig)

    return fig

exponax.viz.plot_state_2d_facet

plot_state_2d_facet(
    states: Union[
        Float[Array, "C N N"], Float[Array, "B 1 N N"]
    ],
    *,
    facet_over_channels: bool = True,
    vlim: tuple[float, float] = (-1.0, 1.0),
    cmap: str = "RdBu_r",
    grid: tuple[int, int] = (3, 3),
    figsize: tuple[float, float] = (10, 10),
    titles: list[str] = None,
    domain_extent: float = None,
    **kwargs
)

Plot a facet of 2d states.

Requires the input to be a real array with three or four axes:

  • Three axes: a leading channel axis and two subsequent spatial axes. The facet is done over the channel axis; this requires facet_over_channels=True (the default).
  • Four axes: a leading batch axis, a channel axis, and two subsequent spatial axes. The facet is done over the batch axis; this requires facet_over_channels=False. Only the zeroth channel of each state is plotted.
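The two input modes differ only in how the leading axes are interpreted. The sketch below reproduces the shape handling at the top of this function's source code; the helper name to_facet_batch is hypothetical and not part of exponax.

```python
import numpy as np


def to_facet_batch(states: np.ndarray, facet_over_channels: bool = True) -> np.ndarray:
    """Hypothetical helper mirroring the input handling of plot_state_2d_facet.

    Returns an array of shape (num_subplots, channels, N, N).
    """
    if facet_over_channels:
        if states.ndim != 3:
            raise ValueError("states must be a three-axis array.")
        # (C, N, N) -> (C, 1, N, N): every channel becomes its own subplot
        return states[:, None, :, :]
    if states.ndim != 4:
        raise ValueError("states must be a four-axis array.")
    # (B, C, N, N) is used as-is; each subplot then shows only channel 0
    return states
```

So, to show several channels of one state, pass its (C, N, N) array directly; to compare one channel across a batch, pass (B, C, N, N) with facet_over_channels=False.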

Periodic boundary conditions will be applied to the spatial axes (the state is wrapped around).
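A minimal sketch of what "wrapped around" means, assuming the wrap appends the first row and first column of the state at the far end so that the plot covers the closed periodic domain. NumPy's np.pad with mode="wrap" does exactly this; exponax's own helper may be implemented differently.

```python
import numpy as np

N = 4
state = np.arange(N * N, dtype=float).reshape(N, N)  # toy 2D state

# Pad one entry at the end of each spatial axis with values taken from the start
wrapped = np.pad(state, pad_width=((0, 1), (0, 1)), mode="wrap")
```

The padded array has shape (N + 1, N + 1); its last row repeats the first row and its last column repeats the first column.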

Arguments:

  • states: The states to plot as a three- or four-axis array. See above for the requirements.
  • facet_over_channels: Whether to facet over the channel axis (three axes) or the batch axis (four axes).
  • vlim: The limits of the color scale.
  • cmap: The colormap to use.
  • grid: The grid layout for the facet plot, as a tuple of two integers (rows, columns). If there are fewer states than grid cells, the unused axes are removed; states beyond the number of grid cells are not plotted.
  • figsize: The size of the figure.
  • titles: The titles for each plot. This should be a list of strings with the same length as the number of states.
  • domain_extent: The extent of the spatial domain. If not provided, the domain extent will be the number of points in the spatial axes. This adjusts the x and y axes.
  • **kwargs: Additional keyword arguments forwarded through plot_state_2d to the imshow function.
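The default vlim of (-1.0, 1.0) suits states of order one. Assuming vlim is passed to matplotlib as the color limits, values outside the range are drawn with the end colors of the colormap. For a diverging colormap such as the default RdBu_r, a common choice is a range symmetric about zero taken from the data, so that zero stays at the neutral center color. A small sketch (symmetric_vlim is a hypothetical helper):

```python
import numpy as np


def symmetric_vlim(states: np.ndarray) -> tuple[float, float]:
    """Color limits symmetric about zero that cover every value in `states`."""
    m = float(np.max(np.abs(states)))
    if m == 0.0:
        m = 1.0  # all-zero input: fall back to the documented default range
    return (-m, m)
```

It would be used as plot_state_2d_facet(states, vlim=symmetric_vlim(states)).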

Returns:

  • The figure.
Source code in exponax/viz/_plot_facet.py
def plot_state_2d_facet(
    states: Union[Float[Array, "C N N"], Float[Array, "B 1 N N"]],
    *,
    facet_over_channels: bool = True,
    vlim: tuple[float, float] = (-1.0, 1.0),
    cmap: str = "RdBu_r",
    grid: tuple[int, int] = (3, 3),
    figsize: tuple[float, float] = (10, 10),
    titles: list[str] = None,
    domain_extent: float = None,
    **kwargs,
):
    """
    Plot a facet of 2d states.

    Requires the input to be a real array with three or four axes:

    * Three axes: a leading channel axis and two subsequent spatial axes. The
        facet is done over the channel axis; this requires
        `facet_over_channels=True` (the default).
    * Four axes: a leading batch axis, a channel axis, and two subsequent
        spatial axes. The facet is done over the batch axis; this requires
        `facet_over_channels=False`. Only the zeroth channel of each state is
        plotted.

    Periodic boundary conditions will be applied to the spatial axes (the state
    is wrapped around).

    **Arguments:**

    - `states`: The states to plot as a three- or four-axis array. See above
        for the requirements.
    - `facet_over_channels`: Whether to facet over the channel axis (three axes)
        or the batch axis (four axes).
    - `vlim`: The limits of the color scale.
    - `cmap`: The colormap to use.
    - `grid`: The grid layout for the facet plot. This should be a tuple with
        two integers. If the number of states is less than the product of the
        grid, the remaining axes will be removed.
    - `figsize`: The size of the figure.
    - `titles`: The titles for each plot. This should be a list of strings with
        the same length as the number of states.
    - `domain_extent`: The extent of the spatial domain. If not provided, the
        domain extent will be the number of points in the spatial axes. This
        adjusts the x and y axes.
    - `**kwargs`: Additional arguments to pass to the imshow function.

    **Returns:**

    - The figure.
    """
    if facet_over_channels:
        if states.ndim != 3:
            raise ValueError("states must be a three-axis array.")
        states = states[:, None, :, :]
    else:
        if states.ndim != 4:
            raise ValueError("states must be a four-axis array.")

    fig, ax_s = plt.subplots(*grid, sharex=True, sharey=True, figsize=figsize)

    num_subplots = states.shape[0]

    if grid[0] * grid[1] == 1:
        ax_s = np.array([[ax_s]])
    for i, ax in enumerate(ax_s.flatten()):
        if i >= num_subplots:
            # More grid cells than states: drop the unused axis instead of
            # indexing past the end of `states`
            ax.remove()
            continue
        plot_state_2d(
            states[i],
            vlim=vlim,
            cmap=cmap,
            ax=ax,
            domain_extent=domain_extent,
            **kwargs,
        )
        if titles is not None:
            ax.set_title(titles[i])

    plt.close(fig)

    return fig

exponax.viz.plot_state_3d_facet

plot_state_3d_facet(
    states: Union[
        Float[Array, "C N N N"], Float[Array, "B 1 N N N"]
    ],
    *,
    facet_over_channels: bool = True,
    vlim: tuple[float, float] = (-1.0, 1.0),
    grid: tuple[int, int] = (3, 3),
    figsize: tuple[float, float] = (10, 10),
    titles: list[str] = None,
    domain_extent: float = None,
    bg_color: Union[
        Literal["black"],
        Literal["white"],
        tuple[jnp.int8, jnp.int8, jnp.int8, jnp.int8],
    ] = "white",
    resolution: int = 384,
    cmap: str = "RdBu_r",
    transfer_function: callable = zigzag_alpha,
    distance_scale: float = 10.0,
    gamma_correction: float = 2.4,
    **kwargs
)

Plot a facet of 3d states as volume renders.

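Requires the input to be a real array with four or five axes:

  • Four axes: a leading channel axis and three subsequent spatial axes. The facet is done over the channel axis; this requires facet_over_channels=True (the default).
  • Five axes: a leading batch axis, a channel axis, and three subsequent spatial axes. The facet is done over the batch axis; this requires facet_over_channels=False. Only the zeroth channel of each state is plotted.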

Periodic boundary conditions will be applied to the spatial axes (the state is wrapped around).

Arguments:

  • states: The states to plot as a four- or five-axis array. See above for the requirements.
  • facet_over_channels: Whether to facet over the channel axis (four axes) or the batch axis (five axes).
  • vlim: The limits of the color scale.
  • grid: The grid layout for the facet plot, as a tuple of two integers (rows, columns). If there are fewer states than grid cells, the unused axes are removed; states beyond the number of grid cells are not plotted.
  • figsize: The size of the figure.
  • titles: The titles for each plot. This should be a list of strings with the same length as the number of states.
  • domain_extent: (Unused as of now)
  • bg_color: The background color. Either "black", "white", or a tuple of RGBA values.
  • resolution: The resolution of the output image (affects render time).
  • cmap: The colormap to use.
  • transfer_function: The transfer function to use, i.e., how values within the vlim range are mapped to alpha values.
  • distance_scale: The distance scale of the volume renderer.
  • gamma_correction: The gamma correction to apply to the image.
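The definition of the default zigzag_alpha, and the exact call signature vape expects from a transfer function, are not shown on this page. As a purely hypothetical illustration of the idea, the function below maps values in the vlim range to opacities so that values near the center of the range (zero for a symmetric range) are transparent and the extremes are opaque. This lets positive and negative structures of a diverging field show through each other. It is not exponax's zigzag_alpha.

```python
import numpy as np


def v_shaped_alpha(values: np.ndarray, vlim=(-1.0, 1.0), max_alpha: float = 0.8) -> np.ndarray:
    """Hypothetical transfer function (not exponax's zigzag_alpha):
    opacity grows linearly with the distance from the center of `vlim`."""
    lo, hi = vlim
    t = np.clip((values - lo) / (hi - lo), 0.0, 1.0)  # normalize to [0, 1]
    return max_alpha * np.abs(2.0 * t - 1.0)  # 0 at the center, max_alpha at both ends
```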

Returns:

  • The figure.

Note:

  • This function requires the vape volume renderer package.
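Because vape is an extra dependency, a caller may want to check for it before attempting a render. A minimal sketch using only the standard library, assuming the package is importable under the module name vape:

```python
import importlib.util


def volume_rendering_available(module_name: str = "vape") -> bool:
    """Return True if `module_name` can be imported in the current environment."""
    return importlib.util.find_spec(module_name) is not None
```

A notebook could, for example, skip the 3D facet cell when this returns False.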
Source code in exponax/viz/_plot_facet.py
def plot_state_3d_facet(
    states: Union[Float[Array, "C N N N"], Float[Array, "B 1 N N N"]],
    *,
    facet_over_channels: bool = True,
    vlim: tuple[float, float] = (-1.0, 1.0),
    grid: tuple[int, int] = (3, 3),
    figsize: tuple[float, float] = (10, 10),
    titles: list[str] = None,
    domain_extent: float = None,
    bg_color: Union[
        Literal["black"],
        Literal["white"],
        tuple[jnp.int8, jnp.int8, jnp.int8, jnp.int8],
    ] = "white",
    resolution: int = 384,
    cmap: str = "RdBu_r",
    transfer_function: callable = zigzag_alpha,
    distance_scale: float = 10.0,
    gamma_correction: float = 2.4,
    **kwargs,
):
    """
    Plot a facet of 3d states as volume renders.

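    Requires the input to be a real array with four or five axes:

    * Four axes: a leading channel axis and three subsequent spatial axes. The
        facet is done over the channel axis; this requires
        `facet_over_channels=True` (the default).
    * Five axes: a leading batch axis, a channel axis, and three subsequent
        spatial axes. The facet is done over the batch axis; this requires
        `facet_over_channels=False`. Only the zeroth channel of each state is
        plotted.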

    Periodic boundary conditions will be applied to the spatial axes (the state
    is wrapped around).

    **Arguments:**

    - `states`: The states to plot as a four- or five-axis array. See above
        for the requirements.
    - `facet_over_channels`: Whether to facet over the channel axis (four axes)
        or the batch axis (five axes).
    - `vlim`: The limits of the color scale.
    - `grid`: The grid layout for the facet plot. This should be a tuple with
        two integers. If the number of states is less than the product of the
        grid, the remaining axes will be removed.
    - `figsize`: The size of the figure.
    - `titles`: The titles for each plot. This should be a list of strings with
        the same length as the number of states.
    - `domain_extent`: (Unused as of now)
    - `bg_color`: The background color. Either `"black"`, `"white"`, or a tuple
        of RGBA values.
    - `resolution`: The resolution of the output image (affects render time).
    - `cmap`: The colormap to use.
    - `transfer_function`: The transfer function to use, i.e., how values within
        the `vlim` range are mapped to alpha values.
    - `distance_scale`: The distance scale of the volume renderer.
    - `gamma_correction`: The gamma correction to apply to the image.

    **Returns:**

    - The figure.

    **Note:**

    - This function requires the `vape` volume renderer package.
    """
    if facet_over_channels:
        if states.ndim != 4:
            raise ValueError("states must be a four-axis array.")
        states = states[:, None, :, :, :]
    else:
        if states.ndim != 5:
            raise ValueError("states must be a five-axis array.")

    fig, ax_s = plt.subplots(*grid, figsize=figsize)

    num_subplots = states.shape[0]

    if grid[0] * grid[1] == 1:
        ax_s = np.array([[ax_s]])
    for i, ax in enumerate(ax_s.flatten()):
        if i >= num_subplots:
            # More grid cells than states: drop the unused axis instead of
            # indexing past the end of `states` (and rendering a volume for it)
            ax.remove()
            continue
        plot_state_3d(
            states[i],
            vlim=vlim,
            domain_extent=domain_extent,
            ax=ax,
            bg_color=bg_color,
            resolution=resolution,
            cmap=cmap,
            transfer_function=transfer_function,
            distance_scale=distance_scale,
            gamma_correction=gamma_correction,
            **kwargs,
        )
        if titles is not None:
            ax.set_title(titles[i])

    plt.close(fig)

    return fig