
Plot States

All functions on this page expect arrays of shape (num_channels, *num_points): a leading channel axis followed by one, two, or three spatial axes.
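The convention can be made concrete with a small NumPy sketch. `num_spatial_dims` is a hypothetical helper (not an exponax function) that reads the number of spatial axes off the shape, mirroring the `ndim` checks each plotting function on this page performs.

```python
import numpy as np


def num_spatial_dims(state: np.ndarray) -> int:
    """Number of spatial axes of a (num_channels, *num_points) array.

    Hypothetical helper for illustration only; it is not part of exponax.
    """
    # Every axis after the leading channel axis is a spatial axis.
    d = state.ndim - 1
    if d not in (1, 2, 3):
        raise ValueError(
            f"expected 1, 2, or 3 spatial axes after the channel axis, got {d}"
        )
    return d


print(num_spatial_dims(np.zeros((1, 64))))          # 1: one channel, 1d
print(num_spatial_dims(np.zeros((2, 32, 32))))      # 2: two channels, 2d
print(num_spatial_dims(np.zeros((1, 16, 16, 16))))  # 3: one channel, 3d
```

A single-channel 1d state therefore has shape (1, N), not (N,).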

exponax.viz.plot_state_1d

plot_state_1d(
    state: Float[Array, "C N"],
    *,
    vlim: tuple[float, float] = (-1.0, 1.0),
    domain_extent: float = None,
    labels: list[str] = None,
    ax=None,
    xlabel: str = "Space",
    ylabel: str = "Value",
    **kwargs
)

Plot the state of a 1d field.

Requires the input to be a real array with two axes: a leading channel axis and a spatial axis.

Arguments:

  • state: The state to plot as a two-axis array. If the first axis has more than one entry (i.e., multiple channels), each channel is plotted as its own line in a different color. Use the labels argument to provide a legend.
  • vlim: The limits of the y-axis.
  • domain_extent: The extent of the spatial domain. If not provided, the domain extent will be the number of points in the spatial axis. This adjusts the x-axis.
  • labels: The labels for the legend. This should be a list of strings with the same length as the number of channels.
  • ax: The axis to plot on. If not provided, a new figure will be created.
  • **kwargs: Additional keyword arguments passed on to Matplotlib's ax.plot.

Returns:

  • If ax is not provided, returns the new figure. Otherwise, returns the list of Line2D objects created by ax.plot.
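The implementation pads the state with a periodic copy (via wrap_bc) and builds a grid that includes the right boundary (make_grid with full=True), so each line closes over the whole domain. A minimal NumPy sketch of that idea, assuming wrap_bc appends the first sample of every spatial axis at its end; `wrap_periodic` and `full_grid_1d` are hypothetical stand-ins, not exponax functions.

```python
import numpy as np


def wrap_periodic(state: np.ndarray) -> np.ndarray:
    """Append the first sample of every spatial axis at its end.

    Hypothetical stand-in for wrap_bc: the channel axis (axis 0) is left
    untouched, and each spatial axis grows from N to N + 1 entries.
    """
    pad = [(0, 0)] + [(0, 1)] * (state.ndim - 1)
    return np.pad(state, pad, mode="wrap")


def full_grid_1d(domain_extent: float, num_points: int) -> np.ndarray:
    """Grid points 0, dx, ..., domain_extent (inclusive), dx = extent / N."""
    return np.linspace(0.0, domain_extent, num_points + 1)


N = 8
x = np.arange(N) / N * 2 * np.pi
state = np.stack([np.sin(x), np.cos(x)])  # shape (2, 8): two channels
wrapped = wrap_periodic(state)             # shape (2, 9)
grid = full_grid_1d(float(N), N)           # default extent = number of points

print(wrapped.shape, grid.shape)                   # (2, 9) (9,)
print(np.allclose(wrapped[:, -1], wrapped[:, 0]))  # True: periodic copy
```

With the default extent the grid spacing is one, and the appended point lands at x = N.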
Source code in exponax/viz/_plot.py
def plot_state_1d(
    state: Float[Array, "C N"],
    *,
    vlim: tuple[float, float] = (-1.0, 1.0),
    domain_extent: float = None,
    labels: list[str] = None,
    ax=None,
    xlabel: str = "Space",
    ylabel: str = "Value",
    **kwargs,
):
    """
    Plot the state of a 1d field.

    Requires the input to be a real array with two axes: a leading channel axis
    and a spatial axis.

    **Arguments:**

    - `state`: The state to plot as a two-axis array. If the first axis has more
        than one entry (i.e., multiple channels), each channel is plotted in a
        different color. Use the `labels` argument to provide a legend.
    - `vlim`: The limits of the y-axis.
    - `domain_extent`: The extent of the spatial domain. If not provided, the
        domain extent will be the number of points in the spatial axis. This
        adjusts the x-axis.
    - `labels`: The labels for the legend. This should be a list of strings with
        the same length as the number of channels.
    - `ax`: The axis to plot on. If not provided, a new figure will be created.
    - `**kwargs`: Additional arguments to pass to the plot function.

    **Returns:**

    - If `ax` is not provided, returns the figure. Otherwise, returns the plot
        object.
    """
    if state.ndim != 2:
        raise ValueError("state must be a two-axis array.")

    state_wrapped = wrap_bc(state)

    num_points = state.shape[-1]

    if domain_extent is None:
        # Default to unit grid spacing. The full grid (full=True) has
        # num_points + 1 entries, matching the wrapped state.
        domain_extent = num_points

    grid = make_grid(1, domain_extent, num_points, full=True)

    if ax is None:
        return_all = True
        fig, ax = plt.subplots()
    else:
        return_all = False

    p = ax.plot(grid[0], state_wrapped.T, label=labels, **kwargs)
    ax.set_ylim(vlim)
    ax.grid()
    if labels is not None:
        ax.legend()
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    if return_all:
        plt.close(fig)
        return fig
    else:
        return p
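The return convention (a closed figure when ax is omitted, the created artists otherwise) can be reproduced with plain Matplotlib. The sketch below is a simplified re-implementation for illustration, not exponax code: `plot_state_1d_sketch` is a hypothetical name, it inlines the periodic wrap with np.pad, and it plots one line per channel so that each line gets its own label.

```python
import io

import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def plot_state_1d_sketch(state, *, vlim=(-1.0, 1.0), domain_extent=None,
                         labels=None, ax=None, **kwargs):
    """Hypothetical, simplified re-implementation of the 1d plotting logic."""
    if state.ndim != 2:
        raise ValueError("state must be a two-axis array.")
    num_points = state.shape[-1]
    if domain_extent is None:
        domain_extent = num_points  # unit grid spacing
    # Periodic copy of the first sample plus a grid that includes the right end.
    wrapped = np.pad(state, [(0, 0), (0, 1)], mode="wrap")
    grid = np.linspace(0.0, domain_extent, num_points + 1)

    return_fig = ax is None
    if return_fig:
        fig, ax = plt.subplots()
    lines = []
    for c, channel in enumerate(wrapped):
        label = None if labels is None else labels[c]
        lines += ax.plot(grid, channel, label=label, **kwargs)
    ax.set_ylim(vlim)
    ax.grid()
    if labels is not None:
        ax.legend()
    if return_fig:
        plt.close(fig)  # detach from pyplot; the Figure object stays usable
        return fig
    return lines


x = np.linspace(0, 2 * np.pi, 64, endpoint=False)
state = np.stack([np.sin(x), np.cos(x)])  # shape (2, 64): two channels

# Without ax: a new figure is created, closed, and returned.
fig = plot_state_1d_sketch(state, labels=["sin", "cos"])
fig.savefig(io.BytesIO(), format="png")  # a closed figure can still be saved

# With ax: the created Line2D objects are returned instead.
fig2, ax2 = plt.subplots()
lines = plot_state_1d_sketch(state, ax=ax2, linestyle="--")
```

Closing the figure before returning it keeps pyplot from holding on to it, while the caller can still save or display the returned object.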

exponax.viz.plot_state_2d

plot_state_2d(
    state: Float[Array, "1 N N"],
    *,
    vlim: tuple[float, float] = (-1.0, 1.0),
    cmap: str = "RdBu_r",
    domain_extent: float = None,
    ax=None,
    **kwargs
)

Visualizes a two-dimensional state as an image.

Requires the input to be a real array with three axes: a leading channel axis, and two subsequent spatial axes. This function will visualize the zeroth channel. For plotting multiple channels at the same time, see plot_state_2d_facet.

Periodic boundary conditions will be applied to the spatial axes (the state is wrapped around).

Arguments:

  • state: The state to plot as a three-axis array: the channel axis followed by the two spatial axes.
  • vlim: The limits of the color scale.
  • cmap: The colormap to use.
  • domain_extent: The extent of the spatial domain. If not provided, the domain extent will be the number of points in the spatial axes. This adjusts the x and y axes.
  • ax: The axis to plot on. If not provided, a new figure will be created.
  • **kwargs: Additional keyword arguments passed on to Matplotlib's imshow.

Returns:

  • If ax is not provided, returns the new figure. Otherwise, returns the AxesImage created by imshow.
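Because the spatial axes are stored in (x_0, x_1) order while an image is indexed as (row, column), the documented behavior takes the zeroth channel of the wrapped state and transposes it before calling imshow with origin="lower", so that x_0 runs horizontally and x_1 vertically over the extent (0, domain_extent) on both axes. A minimal NumPy check of that index mapping, assuming the periodic wrap appends the first row and column; the array names are illustrative only.

```python
import numpy as np

N = 4
# One-channel state u[0, i, j] = 10 * i + j, where i indexes x_0 and j indexes x_1.
i, j = np.meshgrid(np.arange(N), np.arange(N), indexing="ij")
state = (10 * i + j)[None, ...].astype(float)  # shape (1, 4, 4)

# Periodic wrap on both spatial axes, then pick channel 0 and transpose.
wrapped = np.pad(state, [(0, 0), (0, 1), (0, 1)], mode="wrap")  # (1, 5, 5)
image = wrapped[0].T                                             # (5, 5)

# imshow(image, origin="lower") draws image[row, col] with rows going up (x_1)
# and columns going right (x_0), so pixel (row=j, col=i) holds u[0, i, j].
print(image[1, 3])                                # 31.0: x_0 index 3, x_1 index 1
print(np.array_equal(image[:, -1], image[:, 0]))  # True: wrapped column
```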
Source code in exponax/viz/_plot.py
def plot_state_2d(
    state: Float[Array, "1 N N"],
    *,
    vlim: tuple[float, float] = (-1.0, 1.0),
    cmap: str = "RdBu_r",
    domain_extent: float = None,
    ax=None,
    **kwargs,
):
    """
    Visualizes a two-dimensional state as an image.

    Requires the input to be a real array with three axes: a leading channel
    axis, and two subsequent spatial axes. This function will visualize the
    zeroth channel. For plotting multiple channels at the same time, see
    `plot_state_2d_facet`.

    Periodic boundary conditions will be applied to the spatial axes (the state
    is wrapped around).

    **Arguments:**

    - `state`: The state to plot as a three-axis array. The first axis should be
        the channel axis, and the subsequent two axes the spatial axes.
    - `vlim`: The limits of the color scale.
    - `cmap`: The colormap to use.
    - `domain_extent`: The extent of the spatial domain. If not provided, the
        domain extent will be the number of points in the spatial axes. This
        adjusts the x and y axes.
    - `ax`: The axis to plot on. If not provided, a new figure will be created.
    - `**kwargs`: Additional arguments to pass to the imshow function.

    **Returns:**

    - If `ax` is not provided, returns the figure. Otherwise, returns the image
        object.
    """
    if state.ndim != 3:
        raise ValueError("state must be a three-axis array.")

    if domain_extent is not None:
        space_range = (0, domain_extent)
    else:
        # Default to an extent equal to the number of points per axis.
        space_range = (0, state.shape[-1])

    state_wrapped = wrap_bc(state)

    if ax is None:
        fig, ax = plt.subplots()
        return_all = True
    else:
        return_all = False

    im = ax.imshow(
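        # Select the zeroth channel (as documented) and transpose so that
        # x_0 runs horizontally and x_1 vertically.
        state_wrapped[0].T,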
        vmin=vlim[0],
        vmax=vlim[1],
        cmap=cmap,
        origin="lower",
        aspect="auto",
        extent=(*space_range, *space_range),
        **kwargs,
    )
    ax.set_xlabel("x_0")
    ax.set_ylabel("x_1")
    ax.set_aspect("equal")

    if return_all:
        plt.close(fig)
        return fig
    else:
        return im

exponax.viz.plot_state_3d

plot_state_3d(
    state: Float[Array, "1 N N N"],
    *,
    vlim: tuple[float, float] = (-1.0, 1.0),
    domain_extent: float = None,
    ax=None,
    bg_color: Union[
        Literal["black"],
        Literal["white"],
        tuple[jnp.int8, jnp.int8, jnp.int8, jnp.int8],
    ] = "white",
    resolution: int = 384,
    cmap: str = "RdBu_r",
    transfer_function: callable = zigzag_alpha,
    distance_scale: float = 10.0,
    gamma_correction: float = 2.4,
    **kwargs
)

Visualizes a three-dimensional state as a volume rendering.

Requires the input to be a real array with four axes: a leading channel axis, and three subsequent spatial axes. This function will visualize the zeroth channel. For plotting multiple channels at the same time, see plot_state_3d_facet.

Periodic boundary conditions will be applied to the spatial axes (the state is wrapped around).

Arguments:

  • state: The state to plot as a four-axis array: the channel axis followed by the three spatial axes.
  • vlim: The limits of the color scale.
  • domain_extent: Currently unused.
  • ax: The axis to plot on. If not provided, a new figure will be created.
  • bg_color: The background color. Either "black", "white", or a tuple of RGBA values.
  • resolution: The resolution of the rendered image; higher values increase render time.
  • cmap: The colormap to use.
  • transfer_function: The transfer function, i.e., how values within the vlim range are mapped to opacity (alpha) values.
  • distance_scale: The distance scale of the volume renderer.
  • gamma_correction: The gamma correction to apply to the image.

Returns:

  • If ax is not provided, returns the new figure. Otherwise, returns the AxesImage created by imshow.

Note:

  • This function requires the vape volume renderer package.
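The transfer_function parameter decides how opaque each value inside vlim appears in the rendering. The definition of the default zigzag_alpha is not reproduced here; the sketch below is a hypothetical V-shaped transfer function (the name `v_shaped_alpha` and its signature are assumptions, not part of exponax or vape) that makes values near the middle of vlim, which RdBu_r draws in light colors, transparent and values near either limit opaque.

```python
import numpy as np


def v_shaped_alpha(values, vlim=(-1.0, 1.0), max_alpha=1.0):
    """Hypothetical transfer function, not exponax's zigzag_alpha.

    Opacity grows linearly with the distance from the center of vlim.
    """
    vmin, vmax = vlim
    # Normalize into [0, 1]; values outside vlim are clipped to the limits.
    t = np.clip((np.asarray(values, dtype=float) - vmin) / (vmax - vmin), 0.0, 1.0)
    # 0 at the center of the range, max_alpha at either end.
    return max_alpha * np.abs(2.0 * t - 1.0)


alphas = v_shaped_alpha([-1.0, -0.5, 0.0, 0.5, 1.0, 3.0])
print(alphas)
```

Making the center of a diverging colormap transparent lets the large positive and negative structures of a field show through one another instead of being hidden behind near-zero values.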
Source code in exponax/viz/_plot.py
def plot_state_3d(
    state: Float[Array, "1 N N N"],
    *,
    vlim: tuple[float, float] = (-1.0, 1.0),
    domain_extent: float = None,
    ax=None,
    bg_color: Union[
        Literal["black"],
        Literal["white"],
        tuple[jnp.int8, jnp.int8, jnp.int8, jnp.int8],
    ] = "white",
    resolution: int = 384,
    cmap: str = "RdBu_r",
    transfer_function: callable = zigzag_alpha,
    distance_scale: float = 10.0,
    gamma_correction: float = 2.4,
    **kwargs,
):
    """
    Visualizes a three-dimensional state as a volume rendering.

    Requires the input to be a real array with four axes: a leading channel axis,
    and three subsequent spatial axes. This function will visualize the zeroth
    channel. For plotting multiple channels at the same time, see
    `plot_state_3d_facet`.

    Periodic boundary conditions will be applied to the spatial axes (the state
    is wrapped around).

    **Arguments:**

    - `state`: The state to plot as a four-axis array. The first axis should be
        the channel axis, and the subsequent three axes the spatial axes.
    - `vlim`: The limits of the color scale.
    - `domain_extent`: Currently unused.
    - `ax`: The axis to plot on. If not provided, a new figure will be created.
    - `bg_color`: The background color. Either `"black"`, `"white"`, or a tuple
        of RGBA values.
    - `resolution`: The resolution of the output image (affects render time).
    - `cmap`: The colormap to use.
    - `transfer_function`: The transfer function to use, i.e., how values within
        the `vlim` range are mapped to alpha values.
    - `distance_scale`: The distance scale of the volume renderer.
    - `gamma_correction`: The gamma correction to apply to the image.

    **Returns:**

    - If `ax` is not provided, returns the figure. Otherwise, returns the image
        object.

    **Note:**

    - This function requires the `vape` volume renderer package.
    """
    if state.ndim != 4:
        raise ValueError("state must be a four-axis array.")

    one_channel_state = state[0:1]
    one_channel_state_wrapped = wrap_bc(one_channel_state)

    imgs = volume_render_state_3d(
        one_channel_state_wrapped,
        vlim=vlim,
        bg_color=bg_color,
        resolution=resolution,
        cmap=cmap,
        transfer_function=transfer_function,
        distance_scale=distance_scale,
        gamma_correction=gamma_correction,
        **kwargs,
    )

    img = imgs[0]

    if ax is None:
        fig, ax = plt.subplots()
        return_all = True
    else:
        return_all = False

    im = ax.imshow(img)
    ax.axis("off")

    if return_all:
        plt.close(fig)
        return fig
    else:
        return im