Animate Spatio-Temporal

exponax.viz.animate_spatio_temporal

animate_spatio_temporal(
    trjs: Float[Array, "S T C N"],
    *,
    vlim: tuple[float, float] = (-1.0, 1.0),
    cmap: str = "RdBu_r",
    domain_extent: float = None,
    dt: float = None,
    include_init: bool = False,
    **kwargs
)

Animate a trajectory of spatio-temporal states. This makes it possible to visualize "two time dimensions": one is shown along the x-axis of each frame, the other advances with the animation. For instance, this can be used to present how neural predictors learn spatio-temporal dynamics over the course of training.

Requires the input to be a four-axis array with a leading (outer) time axis that the animation steps through, an inner time axis, a channel axis, and a spatial axis. Only the first channel (index 0) is plotted.
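As a minimal sketch of the expected layout, assuming NumPy and hypothetical random data standing in for real rollouts, one rollout of shape `(n_time_steps, n_channels, n_spatial)` per training checkpoint can be stacked along a new leading axis; that leading axis is the one the animation steps through. The function's type hints ask for a JAX array, and `jax.numpy.stack` (or `jax.numpy.asarray` on the result) should work the same way.

```python
import numpy as np

# Hypothetical sizes: 4 training checkpoints, 50 time steps, 1 channel, 64 grid points.
n_checkpoints, n_time_steps, n_channels, n_spatial = 4, 50, 1, 64

rng = np.random.default_rng(0)
# Stand-in rollouts; in practice each entry would be a predictor's rollout
# of shape (n_time_steps, n_channels, n_spatial) at one stage of training.
rollouts = [
    rng.standard_normal((n_time_steps, n_channels, n_spatial))
    for _ in range(n_checkpoints)
]

# Stack along a new leading axis:
# (n_timesteps_outer, n_time_steps, n_channels, n_spatial)
trjs = np.stack(rollouts, axis=0)

# Each animation frame i corresponds to trjs[i]; only channel 0 is plotted,
# i.e. a (time, space) image per frame.
frame0 = trjs[0][:, 0, :]
```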

Periodic boundary conditions will be applied to the spatial axis (the state is wrapped around).
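One way to picture the wrapping is to append the first grid point to the end of the spatial axis, so a periodic grid that stores N points on [0, L) is drawn over the closed interval [0, L]. The helper below is a hypothetical NumPy sketch of that idea, not exponax's own implementation; `domain_extent` plays the role of L.

```python
import numpy as np


def wrap_periodic(state: np.ndarray) -> np.ndarray:
    """Hypothetical helper: repeat the first spatial point at the end.

    The spatial axis is assumed to be the last one. For N points on [0, L)
    the result has N + 1 points covering [0, L].
    """
    return np.concatenate([state, state[..., :1]], axis=-1)


# One channel, 4 grid points on a periodic domain.
u = np.array([[0.0, 1.0, 0.0, -1.0]])
u_wrapped = wrap_periodic(u)  # [[0., 1., 0., -1., 0.]]

# Matching coordinates that include both endpoints of the domain.
domain_extent = 2 * np.pi  # hypothetical domain length L
x = np.linspace(0.0, domain_extent, u_wrapped.shape[-1])
```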

Arguments:

  • trjs: The trajectory of states to animate. Must be a four-axis array with shape (n_timesteps_outer, n_time_steps, n_channels, n_spatial); each animation frame shows one entry along the leading axis.
  • vlim: The limits of the colorbar. Default is (-1, 1).
  • cmap: The colormap to use. Default is "RdBu_r".
  • domain_extent: The extent of the spatial domain. Default is None. This affects the x-axis limits of the plot.
  • dt: The time step between consecutive states along the inner time axis. Default is None. If provided, a title will be displayed with the current time. If not provided, just the frames are counted.
  • include_init: Whether the trajectory starts at the initial condition (t=0) or at the first step after it. This determines whether the time range is [0, (T-1)dt] or [dt, T dt]. Default is False.
  • **kwargs: Additional keyword arguments passed on to plot_spatio_temporal.
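The effect of include_init on the inner time axis can be written out as a tiny helper. This is a hypothetical sketch of the arithmetic stated above, with T the length of the inner time axis; it is not a function exported by exponax.

```python
def time_range(n_time_steps: int, dt: float, include_init: bool) -> tuple[float, float]:
    """Hypothetical helper: (start, end) time spanned by the inner time axis.

    include_init=True  -> first state is the initial condition: [0, (T-1)*dt]
    include_init=False -> first state is one step later:        [dt, T*dt]
    """
    if include_init:
        return 0.0, (n_time_steps - 1) * dt
    return dt, n_time_steps * dt


# 10 steps of size 0.5 rolled out from an initial condition:
with_ic = time_range(11, 0.5, include_init=True)      # IC kept -> 11 states
without_ic = time_range(10, 0.5, include_init=False)  # IC dropped -> 10 states
```

Both calls describe the same rollout and end at the same final time; they differ only in whether the initial condition is part of the plotted array.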

Returns:

  • ani: The matplotlib FuncAnimation object.
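As the source code shows, the return value is a matplotlib FuncAnimation, so it can be saved or embedded like any other matplotlib animation. A minimal sketch, using a small hypothetical stand-in animation in place of the object that animate_spatio_temporal returns:

```python
import os
import tempfile

import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation

# Stand-in for the object returned by animate_spatio_temporal (hypothetical data).
fig, ax = plt.subplots(figsize=(2, 2))
x = np.linspace(0.0, 2 * np.pi, 32)
(line,) = ax.plot(x, np.sin(x))


def update(i):
    # Shift the sine wave a little on every frame.
    line.set_ydata(np.sin(x + 0.5 * i))
    return (line,)


ani = FuncAnimation(fig, update, frames=3, interval=100, blit=False)
plt.close(fig)

# Save as a GIF with the Pillow writer (Pillow is a matplotlib dependency).
out_path = os.path.join(tempfile.mkdtemp(), "trajectory.gif")
ani.save(out_path, writer="pillow", fps=10)

# In a Jupyter notebook one would typically display it inline instead:
#   from IPython.display import HTML
#   HTML(ani.to_jshtml())
```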
Source code in exponax/viz/_animate.py
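import matplotlib.pyplot as plt
from jaxtyping import Array, Float
from matplotlib.animation import FuncAnimation

from ._plot import plot_spatio_temporal  # assumed location of the sibling plotting helper


def animate_spatio_temporal(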
    trjs: Float[Array, "S T C N"],
    *,
    vlim: tuple[float, float] = (-1.0, 1.0),
    cmap: str = "RdBu_r",
    domain_extent: float = None,
    dt: float = None,
    include_init: bool = False,
    **kwargs,
):
    """
    Animate a trajectory of spatio-temporal states. This makes it possible to
    visualize "two time dimensions": one is shown along the x-axis of each
    frame, the other advances with the animation. For instance, this can be
    used to present how neural predictors learn spatio-temporal dynamics over
    the course of training.

    Requires the input to be a four-axis array with a leading (outer) time axis
    that the animation steps through, an inner time axis, a channel axis, and a
    spatial axis. Only the first channel (index 0) is plotted.

    Periodic boundary conditions will be applied to the spatial axis (the state
    is wrapped around).

    **Arguments**:

    - `trjs`: The trajectory of states to animate. Must be a four-axis array
        with shape `(n_timesteps_outer, n_time_steps, n_channels, n_spatial)`;
        each animation frame shows one entry along the leading axis.
    - `vlim`: The limits of the colorbar. Default is `(-1, 1)`.
    - `cmap`: The colormap to use. Default is `"RdBu_r"`.
    - `domain_extent`: The extent of the spatial domain. Default is `None`. This
        affects the x-axis limits of the plot.
    - `dt`: The time step between consecutive states along the inner time
        axis. Default is `None`. If provided, a title will be displayed with
        the current time. If not provided, just the frames are counted.
    - `include_init`: Whether the trajectory starts at the initial condition
        (t=0) or at the first step after it. This determines whether the time
        range is [0, (T-1)dt] or [dt, T dt]. Default is `False`.
    - `**kwargs`: Additional keyword arguments passed on to
        `plot_spatio_temporal`.

    **Returns**:

    - `ani`: The matplotlib `FuncAnimation` object.
    """
    if trjs.ndim != 4:
        raise ValueError("trjs must be a four-axis array.")

    fig, ax = plt.subplots()

    plot_spatio_temporal(
        trjs[0],
        vlim=vlim,
        cmap=cmap,
        domain_extent=domain_extent,
        dt=dt,
        include_init=include_init,
        ax=ax,
        **kwargs,
    )

    def animate(i):
        ax.clear()
        plot_spatio_temporal(
            trjs[i],
            vlim=vlim,
            cmap=cmap,
            domain_extent=domain_extent,
            dt=dt,
            include_init=include_init,
            ax=ax,
            **kwargs,
        )

    plt.close(fig)

    ani = FuncAnimation(fig, animate, frames=trjs.shape[0], interval=100, blit=False)

    return ani
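The clear-and-redraw pattern above can be reproduced with plain matplotlib. In this minimal sketch, plot_spatio_temporal is swapped for a hypothetical draw_frame helper that calls imshow on channel 0 with time on the x-axis and space on the y-axis; the real plotting function may differ in details such as labels, axis extents, and periodic wrapping.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend for scripts and CI
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation


def draw_frame(ax, trj, vlim=(-1.0, 1.0), cmap="RdBu_r"):
    """Hypothetical stand-in for plot_spatio_temporal.

    trj has shape (n_time_steps, n_channels, n_spatial); only channel 0 is
    drawn, transposed so time runs along the x-axis and space along the y-axis.
    """
    ax.imshow(
        trj[:, 0, :].T,
        vmin=vlim[0],
        vmax=vlim[1],
        cmap=cmap,
        origin="lower",
        aspect="auto",
    )
    ax.set_xlabel("time step")
    ax.set_ylabel("space")


def animate_sketch(trjs, **kwargs):
    """Animate over the leading (outer) axis of a (S, T, C, N) array."""
    if trjs.ndim != 4:
        raise ValueError("trjs must be a four-axis array.")
    fig, ax = plt.subplots()
    draw_frame(ax, trjs[0], **kwargs)

    def update(i):
        ax.clear()  # drop the previous frame's image before redrawing
        draw_frame(ax, trjs[i], **kwargs)

    plt.close(fig)  # avoid also displaying the static figure (e.g. in notebooks)
    return FuncAnimation(fig, update, frames=trjs.shape[0], interval=100, blit=False)


trjs = np.random.default_rng(0).uniform(-1.0, 1.0, size=(3, 20, 1, 16))
ani = animate_sketch(trjs)
```

Redrawing the whole axes with blit=False is simple and robust when every frame is a complete image; blitting mainly pays off when only a few artists change between frames.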

exponax.viz.animate_spatio_temporal_2d

animate_spatio_temporal_2d()

Placeholder for a 2D variant. It is not yet implemented and currently raises NotImplementedError.
Source code in exponax/viz/_animate.py
def animate_spatio_temporal_2d():
    raise NotImplementedError("This function is not yet implemented.")