
Animate Spatio-Temporal Facet

exponax.viz.animate_spatio_temporal_facet

animate_spatio_temporal_facet(
    trjs: Union[
        Float[Array, "S T C N"], Float[Array, "B S T 1 N"]
    ],
    *,
    facet_over_channels: bool = True,
    vlim: tuple[float, float] = (-1.0, 1.0),
    cmap: str = "RdBu_r",
    domain_extent: float = None,
    dt: float = None,
    include_init: bool = False,
    grid: tuple[int, int] = (3, 3),
    figsize: tuple[float, float] = (10, 10),
    **kwargs
)

Animate a facet of trajectories of spatio-temporal states. This makes it possible to visualize "two time dimensions": one time dimension runs along the x-axis of each subplot, and the other is traversed by the animation. For instance, this can be used to show how neural predictors learn spatio-temporal dynamics over the course of training. The additional faceting dimension can be used to compare multiple networks with one another.
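To make the "two time dimensions" concrete, the sketch below assembles input arrays from several rollouts, for instance one rollout per training checkpoint. It uses NumPy only; the toy predictor `toy_rollout` (a damped shift on a periodic grid) and the damping values standing in for checkpoints are hypothetical placeholders, not part of exponax.

```python
import numpy as np


def toy_rollout(u0: np.ndarray, n_steps: int, damping: float) -> np.ndarray:
    """Hypothetical stand-in for a learned predictor.

    Advances a (C, N) state by shifting it one cell to the right per step
    (periodic grid) and scaling it by `damping`. Returns shape (T, C, N);
    the initial condition itself is not included.
    """
    states = []
    u = u0
    for _ in range(n_steps):
        u = damping * np.roll(u, shift=1, axis=-1)
        states.append(u)
    return np.stack(states, axis=0)


N = 64
x = np.linspace(0.0, 2 * np.pi, N, endpoint=False)
u0 = np.sin(x)[None, :]  # one channel -> shape (1, N)

# Outer time axis S: one rollout per (hypothetical) training checkpoint.
checkpoint_dampings = [0.80, 0.90, 0.95, 0.99]
trjs = np.stack([toy_rollout(u0, 50, d) for d in checkpoint_dampings], axis=0)
print(trjs.shape)  # (4, 50, 1, 64) -> "S T C N"

# Batch axis B: e.g. two different networks, each with its own S rollouts.
trjs_batched = np.stack([trjs, 0.5 * trjs], axis=0)
print(trjs_batched.shape)  # (2, 4, 50, 1, 64) -> "B S T 1 N"
```

The first array matches the layout for `facet_over_channels=True`, the second the layout for `facet_over_channels=False`.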

Requires the input to be either a four-axis array or a five-axis array:

  • If facet_over_channels is True, the input must be a four-axis array with a leading outer time axis, a time axis, a channel axis, and a spatial axis. Each faceted subplot displays a different channel.
  • If facet_over_channels is False, the input must be a five-axis array with a leading batch axis, an outer time axis, a time axis, a channel axis, and a spatial axis. Each faceted subplot displays a different batch entry; only the zeroth entry of the channel axis is plotted.
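The shape rules above can be expressed as a small check. The sketch mirrors the `ndim` test in the source listing and additionally reports how many facets (subplots) the input implies: one per channel in the first layout and one per batch entry in the second. The helper name `count_facets` is hypothetical and not part of exponax.

```python
import numpy as np


def count_facets(trjs: np.ndarray, facet_over_channels: bool = True) -> int:
    """Hypothetical helper: validate the layout and return the number of facets."""
    if facet_over_channels:
        if trjs.ndim != 4:  # expected (S, T, C, N)
            raise ValueError(f"trjs must be a four-axis array, got {trjs.ndim} axes.")
        return trjs.shape[2]  # one subplot per channel
    if trjs.ndim != 5:  # expected (B, S, T, C, N)
        raise ValueError(f"trjs must be a five-axis array, got {trjs.ndim} axes.")
    return trjs.shape[0]  # one subplot per batch entry (channel 0 only)


print(count_facets(np.zeros((10, 50, 3, 64))))  # 3
print(count_facets(np.zeros((9, 10, 50, 1, 64)), facet_over_channels=False))  # 9
```

With the default `grid=(3, 3)`, nine facets exactly fill the figure; how the function treats more facets than grid slots is not documented here.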

Periodic boundary conditions will be applied to the spatial axis (the state is wrapped around).
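Wrapping a periodic field for plotting typically means appending a copy of the first spatial sample at the end of the spatial axis. The NumPy sketch below does this for the last axis; it is an assumption about what "wrapped around" means here, not a copy of exponax's internal code, and the name `wrap_periodic` is hypothetical.

```python
import numpy as np


def wrap_periodic(u: np.ndarray) -> np.ndarray:
    """Append the first spatial sample at the end of the last (spatial) axis."""
    return np.concatenate([u, u[..., :1]], axis=-1)


trjs = np.random.default_rng(0).normal(size=(4, 50, 1, 64))  # (S, T, C, N)
wrapped = wrap_periodic(trjs)
print(wrapped.shape)  # (4, 50, 1, 65)
```

Assuming a grid x_i = i L / N for i = 0, ..., N-1 (right endpoint excluded), the appended sample sits at x = L, so the wrapped field covers the closed domain [0, L].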

Arguments:

  • trjs: The trajectory of states to animate. Must be a four-axis array with shape (n_timesteps_outer, n_time_steps, n_channels, n_spatial) if facet_over_channels is True, or a five-axis array with shape (n_batches, n_timesteps_outer, n_time_steps, n_channels, n_spatial) if facet_over_channels is False.
  • facet_over_channels: Whether to facet over the channel axis or the batch axis. Default is True.
  • vlim: The limits of the colorbar. Default is (-1, 1).
  • cmap: The colormap to use. Default is "RdBu_r".
  • domain_extent: The extent of the spatial domain. Default is None. This sets the limits of the spatial axis of each subplot.
  • dt: The time step between consecutive frames. Default is None. If provided, a title displaying the current time is shown. If not provided, the frames are simply counted instead.
  • include_init: Whether the trajectory starts at the initial condition (t=0) or at the first frame after it. This determines whether the time range is [0, (T-1) dt] (if True) or [dt, T dt] (if False). Default is False.
  • grid: The grid of subplots. Default is (3, 3).
  • figsize: The size of the figure. Default is (10, 10).
  • **kwargs: Additional keyword arguments to pass to the plotting function.
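The effect of `dt` and `include_init` on the time axis follows directly from the two ranges given above. The sketch computes the time of each of the `T` inner frames of one trajectory; the function name `frame_times` is hypothetical and only illustrates the arithmetic.

```python
import numpy as np


def frame_times(n_steps: int, dt: float, include_init: bool = False) -> np.ndarray:
    """Times of the inner frames of one trajectory.

    include_init=True:  first frame is the initial condition -> [0, (T-1)*dt]
    include_init=False: first frame is one step after it     -> [dt, T*dt]
    """
    offset = 0 if include_init else 1
    return (np.arange(n_steps) + offset) * dt


t_with_init = frame_times(5, 0.1, include_init=True)      # first/last: 0 and 4*dt
t_without_init = frame_times(5, 0.1, include_init=False)  # first/last: dt and 5*dt
```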

Returns:

  • ani: The animation object.
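Conceptually, each animation frame corresponds to one index `s` of the outer time axis, and each subplot then shows a space-time image of the inner trajectory with time along the x-axis. The source listing shows the function is not implemented yet, so the NumPy sketch below only illustrates the indexing described above, not exponax's rendering code. The helper `facet_images` is hypothetical, and transposing to (space, time) assumes an `imshow`-style plot in which rows run along the vertical axis.

```python
import numpy as np


def facet_images(
    trjs: np.ndarray, s: int, facet_over_channels: bool = True
) -> list[np.ndarray]:
    """Return one (N, T) space-time image per facet for outer time index s."""
    if facet_over_channels:
        frame = trjs[s]  # (T, C, N)
        return [frame[:, c, :].T for c in range(frame.shape[1])]
    frame = trjs[:, s]  # (B, T, C, N)
    # Only the zeroth channel is shown when faceting over the batch axis.
    return [frame[b, :, 0, :].T for b in range(frame.shape[0])]


imgs = facet_images(np.zeros((4, 50, 3, 64)), s=2)
print(len(imgs), imgs[0].shape)  # 3 (64, 50)

imgs_b = facet_images(np.zeros((2, 4, 50, 1, 64)), s=0, facet_over_channels=False)
print(len(imgs_b), imgs_b[0].shape)  # 2 (64, 50)
```

An animation would loop over `s` and update every subplot with the corresponding image.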
Source code in exponax/viz/_animate_facet.py
from typing import Union

from jaxtyping import Array, Float


def animate_spatio_temporal_facet(
    trjs: Union[Float[Array, "S T C N"], Float[Array, "B S T 1 N"]],
    *,
    facet_over_channels: bool = True,
    vlim: tuple[float, float] = (-1.0, 1.0),
    cmap: str = "RdBu_r",
    domain_extent: float = None,
    dt: float = None,
    include_init: bool = False,
    grid: tuple[int, int] = (3, 3),
    figsize: tuple[float, float] = (10, 10),
    **kwargs,
):
    """
    Animate a facet of trajectories of spatio-temporal states. This makes it
    possible to visualize "two time dimensions": one time dimension runs along
    the x-axis of each subplot, and the other is traversed by the animation.
    For instance, this can be used to show how neural predictors learn
    spatio-temporal dynamics over the course of training. The additional
    faceting dimension can be used to compare multiple networks with one
    another.

    Requires the input to be either a four-axis array or a five-axis array:

    - If `facet_over_channels` is `True`, the input must be a four-axis array
        with a leading outer time axis, a time axis, a channel axis, and a
        spatial axis. Each faceted subplot displays a different channel.
    - If `facet_over_channels` is `False`, the input must be a five-axis array
        with a leading batch axis, an outer time axis, a time axis, a channel
        axis, and a spatial axis. Each faceted subplot displays a different
        batch entry; only the zeroth entry of the channel axis is plotted.

    Periodic boundary conditions will be applied to the spatial axis (the state
    is wrapped around).

    **Arguments**:

    - `trjs`: The trajectory of states to animate. Must be a four-axis array
        with shape `(n_timesteps_outer, n_time_steps, n_channels, n_spatial)` if
        `facet_over_channels` is `True`, or a five-axis array with shape
        `(n_batches, n_timesteps_outer, n_time_steps, n_channels, n_spatial)` if
        `facet_over_channels` is `False`.
    - `facet_over_channels`: Whether to facet over the channel axis or the batch
        axis. Default is `True`.
    - `vlim`: The limits of the colorbar. Default is `(-1, 1)`.
    - `cmap`: The colormap to use. Default is `"RdBu_r"`.
    - `domain_extent`: The extent of the spatial domain. Default is `None`. This
        sets the limits of the spatial axis of each subplot.
    - `dt`: The time step between consecutive frames. Default is `None`. If
        provided, a title displaying the current time is shown. If not
        provided, the frames are simply counted instead.
    - `include_init`: Whether the trajectory starts at the initial condition
        (t=0) or at the first frame after it. This determines whether the time
        range is [0, (T-1) dt] (if `True`) or [dt, T dt] (if `False`). Default
        is `False`.
    - `grid`: The grid of subplots. Default is `(3, 3)`.
    - `figsize`: The size of the figure. Default is `(10, 10)`.
    - `**kwargs`: Additional keyword arguments to pass to the plotting function.

    **Returns**:

    - `ani`: The animation object.
    """
    if facet_over_channels:
        if trjs.ndim != 4:
            raise ValueError("trjs must be a four-axis array.")
    else:
        if trjs.ndim != 5:
            raise ValueError("trjs must be a five-axis array.")
    # TODO
    raise NotImplementedError("Not implemented yet.")

exponax.viz.animate_spatio_temporal_2d_facet

animate_spatio_temporal_2d_facet()
Source code in exponax/viz/_animate_facet.py
def animate_spatio_temporal_2d_facet():
    # TODO
    raise NotImplementedError("Not implemented yet.")