Skip to content

Backends API

SolveResult

numen.bridge.runtime.SolveResult dataclass

Output from any backend solve() call.

Attributes:

Name Type Description
t ndarray

Time points, shape (n_steps,).

x ndarray

State matrix, shape (state_size, n_steps). Row i is the time series for state slot i.

timings_ms list[float]

Per-solve wall times in ms (Julia backends only). timings_ms[0] = first solve (JIT + dynamics); timings_ms[1:] = warm solves when reps > 1.

Properties

startup_ms: wall time minus the sum of timings_ms (subprocess start + package load).
jit_ms: the first timing (includes JIT), or None if no timings are available.
warm_ms: the minimum of the warm timings, or None if fewer than 2 reps were run.

Source code in src/numen/bridge/runtime.py
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
@dataclass
class SolveResult:
    """Output from any backend ``solve()`` call.

    Attributes:
        t:          Time points, shape ``(n_steps,)``.
        x:          State matrix, shape ``(state_size, n_steps)``.
                    Row ``i`` is the time series for state slot ``i``.
        timings_ms: Per-solve wall times in ms (Julia backends only).
                    ``timings_ms[0]`` = first solve (JIT + dynamics);
                    ``timings_ms[1:]`` = warm solves when ``reps > 1``.

    Properties:
        startup_ms: Wall time minus sum of ``timings_ms`` (subprocess + package load).
        jit_ms:     First timing (includes JIT), or ``None`` if not available.
        warm_ms:    Minimum of warm timings, or ``None`` if fewer than 2 reps.
    """
    t: np.ndarray
    x: np.ndarray
    timings_ms: list[float] = field(default_factory=list)

    @property
    def startup_ms(self) -> float:
        return self._startup_ms

    def _set_startup(self, wall_ms: float) -> None:
        self._startup_ms = wall_ms - sum(self.timings_ms)

    @property
    def jit_ms(self) -> float | None:
        return self.timings_ms[0] if self.timings_ms else None

    @property
    def warm_ms(self) -> float | None:
        warm = self.timings_ms[1:]
        return min(warm) if warm else None

ScipyBackend

numen.bridge.scipy_backend.ScipyBackend

Pure-Python solver backend using scipy.integrate.solve_ivp.

Good for development, debugging, and models that require control callbacks.

Parameters:

Name Type Description Default
method str

scipy solver method ("RK45", "RK23", "DOP853", "LSODA"). Default "RK45".

'RK45'
rtol float

Relative tolerance. Default 1e-6.

1e-06
atol float

Absolute tolerance. Default 1e-8.

1e-08
dtmax float | None

Maximum adaptive step size (passed as max_step).

None
dtsave float | None

Save output every dtsave time units. Mutually exclusive with n_save_points.

None
n_save_points int

Save exactly this many uniformly-spaced output points. Mutually exclusive with dtsave.

0

Example::

from numen.bridge.scipy_backend import ScipyBackend
result = ScipyBackend(rtol=1e-9, atol=1e-9).solve(spec, tspan=(0.0, 5.0))
Source code in src/numen/bridge/scipy_backend.py
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
class ScipyBackend:
    """Pure-Python solver backend using ``scipy.integrate.solve_ivp``.

    Good for development, debugging, and models that require control callbacks.

    Args:
        method:        scipy solver method (``"RK45"``, ``"RK23"``, ``"DOP853"``,
                       ``"LSODA"``, or any other method ``solve_ivp`` accepts).
                       Default ``"RK45"``.
        rtol:          Relative tolerance. Default ``1e-6``.
        atol:          Absolute tolerance. Default ``1e-8``.
        dtmax:         Maximum adaptive step size (passed as ``max_step``).
        dtsave:        Save output every ``dtsave`` time units.
                       Mutually exclusive with ``n_save_points``.
        n_save_points: Save exactly this many uniformly-spaced output points.
                       Mutually exclusive with ``dtsave``.

    Raises:
        ValueError: If both ``n_save_points > 0`` and ``dtsave`` are given.

    Example::

        from numen.bridge.scipy_backend import ScipyBackend
        result = ScipyBackend(rtol=1e-9, atol=1e-9).solve(spec, tspan=(0.0, 5.0))
    """

    supported_features: ClassVar[frozenset[str]] = frozenset({
        "vector_fields",
        "discrete_fields",
        "continuous_fields",
        "control_callbacks",
    })

    def __init__(self, method: str = "RK45", rtol: float = 1e-6, atol: float = 1e-8,
                 dtmax: float | None = None, dtsave: float | None = None,
                 n_save_points: int = 0):
        if n_save_points > 0 and dtsave is not None:
            raise ValueError("Specify either n_save_points or dtsave, not both.")
        self.method = method
        self.rtol = rtol
        self.atol = atol
        self.dtmax = dtmax
        self.dtsave = dtsave
        self.n_save_points = n_save_points

    def solve(
        self,
        compiled_spec: CompiledSpec,
        tspan: tuple[float, float],
        t_eval: "np.ndarray | None" = None,
        progress: bool = False,
    ) -> SolveResult:
        """Integrate ``compiled_spec`` over ``tspan``.

        Uses a segmented integration (one ``solve_ivp`` call per segment) when
        the spec has compiled callbacks, otherwise a single ``solve_ivp`` call.

        Args:
            compiled_spec: Compiled simulation spec.
            tspan:         ``(t0, tf)`` integration interval.
            t_eval:        Optional explicit output times; combined with
                           ``dtsave`` / ``n_save_points`` before solving.
            progress:      Enable progress reporting while integrating.

        Returns:
            A ``SolveResult`` with ``t`` and ``x`` (``timings_ms`` left empty).

        Raises:
            ValueError:   If any compiled callback has no ``python_fn``.
            KeyError:     If a callback returns an unknown state key.
            RuntimeError: If ``solve_ivp`` reports failure.
        """
        check_backend_features(compiled_spec, "ScipyBackend", self.supported_features)
        check_python_fns(compiled_spec, "ScipyBackend")

        _log.debug(
            "solve: state_size=%d param_size=%d tspan=%s method=%s rtol=%g atol=%g",
            compiled_spec.state_size, compiled_spec.param_size,
            tspan, self.method, self.rtol, self.atol,
        )

        # Callbacks run in Python on this backend, so each needs a python_fn.
        cb_missing = [c.name for c in compiled_spec.compiled_callbacks if c.python_fn is None]
        if cb_missing:
            raise ValueError(
                f"ScipyBackend: callbacks missing python_fn: {cb_missing}\n"
                f"Declare 'python_fn: ClassVar = staticmethod(your_fn)' on the Callback class."
            )

        if compiled_spec.compiled_callbacks:
            return self._solve_segmented(compiled_spec, tspan, t_eval, progress)
        return self._solve_simple(compiled_spec, tspan, t_eval, progress)

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _make_rhs(compiled_spec: CompiledSpec, p: np.ndarray):
        """Build ``rhs(t, x)``: each system's ``python_fn`` writes into a fresh zeroed ``DxBuffer``."""
        def rhs(t: float, x: np.ndarray) -> np.ndarray:
            dx = DxBuffer(np.zeros_like(x))
            for sys in compiled_spec.systems:
                sys.python_fn(dx, x, p, t, compiled_spec, sys)
            return dx.array
        return rhs

    def _ivp_kwargs(self, t_eval: "np.ndarray | None") -> dict:
        """Keyword arguments for ``solve_ivp`` from this backend's settings."""
        kw: dict = dict(method=self.method, t_eval=t_eval, rtol=self.rtol, atol=self.atol)
        if self.dtmax is not None:
            kw["max_step"] = self.dtmax
        return kw

    # ------------------------------------------------------------------
    # Simple path (no callbacks) — single solve_ivp call
    # ------------------------------------------------------------------

    def _solve_simple(
        self,
        compiled_spec: CompiledSpec,
        tspan: tuple[float, float],
        t_eval: "np.ndarray | None",
        progress: bool,
    ) -> SolveResult:
        """Integrate the whole interval with one ``solve_ivp`` call."""
        p = np.array(compiled_spec.p)
        tstops = _build_tstops(compiled_spec.discrete_dts, tspan)
        t0_sim, tf_sim = tspan

        rhs = self._make_rhs(compiled_spec, p)
        if progress:
            rhs = _wrap_rhs_progress(rhs, t0_sim, tf_sim)

        t_eval = _apply_save_density(t_eval, tspan, self.dtsave, self.n_save_points)
        # DiscreteField tstops are merged into the requested save times.
        dense_eval = _merge_t_eval(t_eval, tstops)

        t0_wall = time.perf_counter()
        try:
            sol = solve_ivp(rhs, tspan, compiled_spec.x0, **self._ivp_kwargs(dense_eval))
        finally:
            if progress:
                _close_rhs_progress(rhs)

        elapsed_ms = (time.perf_counter() - t0_wall) * 1000
        if not sol.success:
            _log.error("solve failed after %.0f ms: %s", elapsed_ms, sol.message)
            raise RuntimeError(f"Solver failed: {sol.message}")

        _log.debug("solve done in %.1f ms, %d steps", elapsed_ms, len(sol.t))
        return SolveResult(t=sol.t, x=sol.y)

    # ------------------------------------------------------------------
    # Segmented path (with callbacks)
    # ------------------------------------------------------------------

    @staticmethod
    def _open_progress_bar(t0_sim: float, tf_sim: float):
        """Return a tqdm bar measured in simulated seconds, or ``None`` if tqdm is missing."""
        try:
            from tqdm.auto import tqdm
        except ImportError:
            return None
        return tqdm(
            total=tf_sim - t0_sim, unit="s", desc="scipy",
            bar_format="{desc}: {percentage:.0f}%|{bar}| {n:.3g}/{total:.3g} s  [{elapsed}<{remaining}]",
            dynamic_ncols=True,
        )

    @staticmethod
    def _fire_callbacks(
        callbacks: "list[CompiledCallback]",
        t_stop: float,
        x_cur: np.ndarray,
        p: np.ndarray,
        compiled_spec: CompiledSpec,
    ) -> None:
        """Run each callback at ``t_stop`` and write its returned state updates into ``x_cur`` in place.

        Raises:
            KeyError: If a callback returns a key not in ``state_index_map``.
        """
        index_map = compiled_spec.state_index_map
        for cb in callbacks:
            updates = cb.python_fn(t_stop, x_cur, p, compiled_spec)
            if updates:
                for key, val in updates.items():
                    if key not in index_map:
                        raise KeyError(
                            f"Callback '{cb.name}' returned unknown state key '{key}'. "
                            f"Available: {list(index_map)}"
                        )
                    s, e = index_map[key]
                    x_cur[s:e] = np.atleast_1d(val) if e - s > 1 else val
            _log.debug("callback '%s' fired at t=%.6f", cb.name, t_stop)

    def _solve_segmented(
        self,
        compiled_spec: CompiledSpec,
        tspan: tuple[float, float],
        t_eval: "np.ndarray | None",
        progress: bool,
    ) -> SolveResult:
        """Integrate segment by segment, firing callbacks at their scheduled tstops.

        Segment boundaries are the union of callback fire times, DiscreteField
        tstops and ``tf``. After each segment, the callbacks scheduled at its
        end may overwrite parts of the state before the next segment starts.
        The saved sample at a boundary is the state *before* callbacks fire.
        """
        p = np.array(compiled_spec.p)
        t0_sim, tf_sim = tspan

        t_eval = _apply_save_density(t_eval, tspan, self.dtsave, self.n_save_points)

        # Callback fire schedule, plus DiscreteField tstops (save points only).
        cb_schedule = _build_tstop_callbacks(compiled_spec.compiled_callbacks, tspan)
        disc_tstops = set(_build_tstops(compiled_spec.discrete_dts, tspan))
        cb_times = {t for t, _ in cb_schedule}
        all_boundaries = sorted(cb_times | disc_tstops | {tf_sim})
        cb_at: dict[float, list[CompiledCallback]] = {t: cbs for t, cbs in cb_schedule}

        rhs = self._make_rhs(compiled_spec, p)

        t_cur = t0_sim
        x_cur = np.array(compiled_spec.x0, dtype=float)
        all_t: list[np.ndarray] = []
        all_x: list[np.ndarray] = []  # each entry shape (state_size, n_steps_in_segment)

        t0_wall = time.perf_counter()
        pbar = self._open_progress_bar(t0_sim, tf_sim) if progress else None
        # try/finally so the progress bar is closed even when a segment fails
        # or a callback raises.
        try:
            for t_stop in all_boundaries:
                if t_stop <= t_cur + 1e-15:
                    continue

                # Requested save times inside (t_cur, t_stop], always ending at t_stop.
                if t_eval is not None:
                    seg_eval_mask = (t_eval > t_cur) & (t_eval <= t_stop)
                    seg_eval = np.sort(np.unique(np.concatenate([t_eval[seg_eval_mask], [t_stop]])))
                else:
                    seg_eval = np.array([t_stop])

                sol = solve_ivp(rhs, (t_cur, t_stop), x_cur, **self._ivp_kwargs(seg_eval))
                if not sol.success:
                    elapsed_ms = (time.perf_counter() - t0_wall) * 1000
                    _log.error("segment solve failed at t=%.6f after %.0f ms: %s", t_stop, elapsed_ms, sol.message)
                    raise RuntimeError(f"Solver failed at t={t_stop:.6f}: {sol.message}")

                # Collect — drop the first point if it duplicates the last saved point
                t_seg = sol.t
                x_seg = sol.y
                if all_t and len(t_seg) > 0 and abs(t_seg[0] - all_t[-1][-1]) < 1e-15:
                    t_seg = t_seg[1:]
                    x_seg = x_seg[:, 1:]
                if len(t_seg) > 0:
                    all_t.append(t_seg)
                    all_x.append(x_seg)

                x_cur = sol.y[:, -1].copy()
                t_prev, t_cur = t_cur, t_stop

                self._fire_callbacks(cb_at.get(t_stop, []), t_stop, x_cur, p, compiled_spec)

                if pbar is not None:
                    # Advance by the span actually integrated in this segment.
                    pbar.update(t_stop - t_prev)
        finally:
            if pbar is not None:
                pbar.close()

        elapsed_ms = (time.perf_counter() - t0_wall) * 1000
        _log.debug("segmented solve done in %.1f ms, %d segments", elapsed_ms, len(all_t))

        t_out = np.concatenate(all_t) if all_t else np.array([t0_sim])
        x_out = np.concatenate(all_x, axis=1) if all_x else np.array(compiled_spec.x0)[:, None]
        return SolveResult(t=t_out, x=x_out)

JAXBackend

numen.bridge.jax_backend.JAXBackend

JAX + diffrax solver backend.

The entire diffeqsolve call — ODE steps, RHS evaluations, and adaptive step-size control — is wrapped in jax.jit (or equinox.filter_jit if equinox is installed, which gives clearer runtime error messages).

The compiled XLA program is cached per (compiled_spec id, tspan) so repeated solves with the same problem (e.g. Monte Carlo over initial conditions) reuse the compiled kernel. Only x0 is a dynamic input: parameters p, save times, and tolerances are static constants baked into the compiled program.

During JIT tracing, Python for loops over entity_groups and dict lookups in state_index_map run once; the XLA kernel contains only integer-indexed array operations with no dict overhead at execution time.

Parameters:

Name Type Description Default
rtol float

Relative tolerance (diffrax PIDController).

1e-06
atol float

Absolute tolerance.

1e-08
n_saves int

Number of evenly-spaced save points when no discrete fields are present.

500
max_steps int

Maximum ODE solver steps. Increase if you get "maximum number of solver steps was reached". For stiff problems consider switching to an implicit solver (solver="Kvaerno5") or using JuliaServerBackend with method="Rodas5P".

100000
solver str

diffrax solver class name. Explicit (non-stiff): "Dopri5" (default), "Tsit5". Implicit (stiff): "Kvaerno3", "Kvaerno4", "Kvaerno5", "ImplicitEuler".

'Dopri5'
Source code in src/numen/bridge/jax_backend.py
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
class JAXBackend:
    """JAX + diffrax solver backend.

    The *entire* ``diffeqsolve`` call — ODE steps, RHS evaluations, and
    adaptive step-size control — is wrapped in ``jax.jit`` (or
    ``equinox.filter_jit`` if equinox is installed, which gives clearer
    runtime error messages).

    The compiled XLA program is cached per ``(compiled_spec id, tspan)`` so
    repeated solves with the same problem (e.g. Monte Carlo over initial
    conditions) reuse the compiled kernel.  Only ``x0`` is a dynamic input:
    parameters ``p``, save times, and tolerances are static constants baked
    into the compiled program.  Each cached kernel closes over its
    ``compiled_spec``, so cached specs stay alive as long as this backend.

    During JIT tracing, Python ``for`` loops over ``entity_groups`` and dict
    lookups in ``state_index_map`` run once; the XLA kernel contains only
    integer-indexed array operations with no dict overhead at execution time.

    Args:
        rtol:      Relative tolerance (diffrax PIDController).
        atol:      Absolute tolerance.
        n_saves:   Number of evenly-spaced save points when no discrete fields
                   are present.
        max_steps: Maximum ODE solver steps.  Increase if you get
                   "maximum number of solver steps was reached".  For stiff
                   problems consider switching to an implicit solver
                   (``solver="Kvaerno5"``) or using JuliaServerBackend with
                   ``method="Rodas5P"``.
        solver:    diffrax solver class name.
                   Explicit (non-stiff): ``"Dopri5"`` (default), ``"Tsit5"``.
                   Implicit (stiff):     ``"Kvaerno3"``, ``"Kvaerno4"``,
                                         ``"Kvaerno5"``, ``"ImplicitEuler"``.
    """

    supported_features: ClassVar[frozenset[str]] = frozenset({
        "vector_fields",
        "discrete_fields",
        "continuous_fields",
    })

    def __init__(
        self,
        rtol: float = 1e-6,
        atol: float = 1e-8,
        n_saves: int = 500,
        max_steps: int = 100_000,
        solver: str = "Dopri5",
    ) -> None:
        self.rtol      = rtol
        self.atol      = atol
        self.n_saves   = n_saves
        self.max_steps = max_steps
        self.solver    = solver
        # (id(compiled_spec), tuple(tspan)) -> jitted run(x0) function
        self._cache: dict[tuple, Callable] = {}

    def _build_run_fn(self, compiled_spec: CompiledSpec, tspan: tuple[float, float]) -> Callable:
        """Build and JIT-compile the full ODE solve for this spec + tspan.

        Save times are the DiscreteField tstops when there are any, otherwise
        ``n_saves`` evenly spaced points over ``tspan``.
        """
        import jax.numpy as jnp
        import diffrax

        t0, tf   = tspan
        p        = jnp.array(compiled_spec.p)
        systems  = compiled_spec.systems

        tstops   = _build_tstops(compiled_spec.discrete_dts, tspan)
        save_ts  = jnp.array(tstops) if tstops else jnp.linspace(t0, tf, self.n_saves)

        max_steps = self.max_steps
        ctrl      = diffrax.PIDController(rtol=self.rtol, atol=self.atol)
        saveat    = diffrax.SaveAt(ts=save_ts)
        _solver   = getattr(diffrax, self.solver)()
        _jit      = _get_jit()

        @_jit
        def run(x0: jnp.ndarray) -> Any:
            def rhs(t: float, y: jnp.ndarray, _args: None) -> jnp.ndarray:
                dx = DxBuffer(jnp.zeros_like(y))
                for sys in systems:
                    sys.python_fn(dx, y, p, t, compiled_spec, sys)
                return dx.array

            return diffrax.diffeqsolve(
                diffrax.ODETerm(rhs),
                _solver,
                t0=t0,
                t1=tf,
                dt0=None,
                y0=x0,
                args=None,
                saveat=saveat,
                stepsize_controller=ctrl,
                max_steps=max_steps,
            )

        return run

    def _run_checked(self, run: Callable, x0: Any) -> Any:
        """Call ``run(x0)``, letting ``_reraise_jax_error`` translate failures."""
        try:
            return run(x0)
        except Exception as e:
            _reraise_jax_error(e, self.max_steps, self.solver)
            raise  # defensive: only reached if the helper returns instead of raising

    def _run_with_progress(self, run: Callable, x0: Any, label: str) -> Any:
        """Run ``run(x0)`` on a worker thread while a tqdm elapsed-time bar ticks.

        Exceptions from the worker are re-raised on the calling thread through
        ``_reraise_jax_error``, exactly as on the non-progress path.
        """
        import threading

        outcome: dict[str, Any] = {}

        def _worker() -> None:
            try:
                outcome["sol"] = run(x0)
            except Exception as e:  # handed back to the calling thread below
                outcome["exc"] = e

        worker = threading.Thread(target=_worker, daemon=True)
        worker.start()
        try:
            from tqdm.auto import tqdm
            pbar = tqdm(bar_format=f"{label} {{elapsed}}", total=0, dynamic_ncols=True)
        except ImportError:
            pbar = None
        try:
            while worker.is_alive():
                if pbar is not None:
                    pbar.update(0)
                worker.join(timeout=0.1)
        finally:
            if pbar is not None:
                pbar.close()

        exc = outcome.get("exc")
        if exc is not None:
            # Re-raise inside an except clause so the translator sees the same
            # exception context as in _run_checked.
            try:
                raise exc
            except Exception as e:
                _reraise_jax_error(e, self.max_steps, self.solver)
                raise
        return outcome.get("sol")

    def solve(
        self,
        compiled_spec: CompiledSpec,
        tspan: tuple[float, float],
        progress: bool = False,
    ) -> SolveResult:
        """Integrate ``compiled_spec`` over ``tspan`` with a cached JIT kernel.

        Args:
            compiled_spec: Compiled simulation spec.
            tspan:         ``(t0, tf)`` integration interval.
            progress:      Show an elapsed-time bar (tqdm, if installed)
                           while the solve / JIT compilation runs.

        Returns:
            A ``SolveResult`` with ``x`` of shape ``(state_size, n_saves)``.
        """
        import jax.numpy as jnp

        check_backend_features(compiled_spec, "JAXBackend", self.supported_features)
        check_python_fns(compiled_spec, "JAXBackend")

        _log.debug(
            "solve: state_size=%d param_size=%d tspan=%s solver=%s rtol=%g atol=%g max_steps=%d",
            compiled_spec.state_size, compiled_spec.param_size,
            tspan, self.solver, self.rtol, self.atol, self.max_steps,
        )

        # tuple() makes list-valued tspans hashable and share entries with tuples.
        key = (id(compiled_spec), tuple(tspan))
        is_first_call = key not in self._cache
        if is_first_call:
            _log.debug("JIT compiling for spec id=%d tspan=%s", id(compiled_spec), tspan)
            self._cache[key] = self._build_run_fn(compiled_spec, tspan)

        run = self._cache[key]
        x0  = jnp.array(compiled_spec.x0)

        label = "JAX (JIT compiling...)" if is_first_call else "JAX"
        t0_wall = time.perf_counter()

        if progress:
            sol = self._run_with_progress(run, x0, label)
        else:
            sol = self._run_checked(run, x0)

        elapsed_ms = (time.perf_counter() - t0_wall) * 1000
        _log.debug("solve done in %.1f ms%s", elapsed_ms, " (includes JIT)" if is_first_call else "")

        t_out = np.array(sol.ts)
        x_out = np.array(sol.ys).T   # (n_saves, state_size) → (state_size, n_saves)
        return SolveResult(t=t_out, x=x_out)

JuliaBackend

numen.bridge.runtime.JuliaBackend

Julia + OrdinaryDiffEq.jl solver backend via subprocess.

Spawns a fresh Julia process for each solve call. Startup (~2–3 s) includes Julia boot, package loading, and user dynamics file include. Within that process, reps solves are timed individually: the first includes JIT compilation of user dynamics; subsequent ones are warm.

Parameters:

Name Type Description Default
julia_file str | None

Path to the .jl file that defines all dynamics_fn modules.

None
method str

Julia solver name (default "Tsit5").

'Tsit5'
rtol float

Relative tolerance.

1e-06
atol float

Absolute tolerance.

1e-08

Example::

backend = JuliaBackend(julia_file="examples/oscillator/dynamics.jl")
result  = backend.solve(spec, tspan=(0.0, 5.0), reps=6)
print(f"startup {result.startup_ms:.0f} ms  "
      f"JIT {result.jit_ms:.0f} ms  "
      f"warm {result.warm_ms:.0f} ms")
Source code in src/numen/bridge/runtime.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
class JuliaBackend:
    """Julia + OrdinaryDiffEq.jl solver backend via subprocess.

    Spawns a fresh Julia process for each ``solve`` call. Startup (~2–3 s)
    includes Julia boot, package loading, and user dynamics file ``include``.
    Within that process, ``reps`` solves are timed individually: the first
    includes JIT compilation of user dynamics; subsequent ones are warm.

    Args:
        julia_file:    Path to the .jl file that defines all ``dynamics_fn`` modules.
        method:        Julia solver name (default ``"Tsit5"``).
        rtol:          Relative tolerance.
        atol:          Absolute tolerance.
        n_save_points: Forwarded to the Julia runner as ``n_save_points``.
                       Mutually exclusive with ``dtsave``.
        dtsave:        Forwarded to the Julia runner as ``dtsave``.
        dtmax:         Forwarded to the Julia runner as ``dtmax``.
                       (Presumably these three mirror ``ScipyBackend``'s
                       options of the same name — confirm in runner.jl.)

    Raises:
        ValueError: If both ``n_save_points > 0`` and ``dtsave`` are given.

    Example::

        backend = JuliaBackend(julia_file="examples/oscillator/dynamics.jl")
        result  = backend.solve(spec, tspan=(0.0, 5.0), reps=6)
        print(f"startup {result.startup_ms:.0f} ms  "
              f"JIT {result.jit_ms:.0f} ms  "
              f"warm {result.warm_ms:.0f} ms")
    """

    supported_features: ClassVar[frozenset[str]] = frozenset({
        "vector_fields",
        "discrete_fields",
        "continuous_fields",
        "control_callbacks",
        "dae_constraints",
    })

    def __init__(
        self,
        julia_file: str | None = None,
        method: str = "Tsit5",
        rtol: float = 1e-6,
        atol: float = 1e-8,
        n_save_points: int = 0,
        dtsave: float | None = None,
        dtmax: float | None = None,
    ) -> None:
        if n_save_points > 0 and dtsave is not None:
            raise ValueError("Specify either n_save_points or dtsave, not both.")
        self._julia_file = str(Path(julia_file).resolve()) if julia_file else None
        self.method = method
        self.rtol = rtol
        self.atol = atol
        self.n_save_points = n_save_points
        self.dtsave = dtsave
        self.dtmax = dtmax

    def solve(
        self,
        compiled_spec: CompiledSpec,
        tspan: tuple[float, float],
        reps: int = 1,
    ) -> SolveResult:
        """Run the ODE solver via Julia subprocess.

        Args:
            compiled_spec: Compiled simulation spec.
            tspan:         (t0, tf) integration interval.
            reps:          Number of solves to run inside the subprocess.
                           reps=1 → single solve (no warm timing).
                           reps>1 → first solve is JIT, rest are warm.

        Returns:
            A ``SolveResult`` with ``timings_ms`` from the runner and
            ``startup_ms`` set from the measured subprocess wall time.

        Raises:
            RuntimeError: If the Julia process exits with a non-zero code.
        """
        check_backend_features(compiled_spec, "JuliaBackend", self.supported_features)
        check_julia_fns(compiled_spec, "JuliaBackend")

        _log.debug(
            "solve: state_size=%d param_size=%d tspan=%s method=%s rtol=%g atol=%g reps=%d",
            compiled_spec.state_size, compiled_spec.param_size,
            tspan, self.method, self.rtol, self.atol, reps,
        )

        payload = {
            "spec":          compiled_spec.to_dict(),
            "tspan":         list(tspan),
            "reps":          reps,
            "method":        self.method,
            "rtol":          self.rtol,
            "atol":          self.atol,
            "n_save_points": self.n_save_points,
            "dtsave":        self.dtsave,
            "dtmax":         self.dtmax,
        }

        payload_path: Path | None = None
        result_path: Path | None = None
        try:
            # NamedTemporaryFile creates each file atomically (unlike the
            # deprecated, race-prone tempfile.mktemp); delete=False keeps it on
            # disk after closing so the Julia process can open it.
            with tempfile.NamedTemporaryFile(
                "w", suffix=".json", delete=False, encoding="utf-8",
            ) as fh:
                payload_path = Path(fh.name)
                fh.write(json.dumps(payload))
            with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as fh:
                result_path = Path(fh.name)

            cmd = [
                "julia",
                f"--project={_JULIA_PKG_DIR}",
                str(_RUNNER_JL),
                str(payload_path),
                str(result_path),
            ]
            if self._julia_file:
                cmd.append(self._julia_file)

            t0 = time.perf_counter()
            proc = subprocess.run(cmd, capture_output=True, text=True)
            wall_ms = (time.perf_counter() - t0) * 1000

            # Always route Julia output to logger, not just on failure
            for line in proc.stderr.splitlines():
                if line.strip():
                    _log.debug("[julia stderr] %s", line)
            for line in proc.stdout.splitlines():
                if line.strip():
                    _log.debug("[julia stdout] %s", line)

            if proc.returncode != 0:
                _log.error(
                    "Julia subprocess failed (exit %d) after %.0f ms",
                    proc.returncode, wall_ms,
                )
                raise RuntimeError(
                    f"Julia subprocess failed (exit {proc.returncode}):\n"
                    f"STDOUT:\n{proc.stdout}\nSTDERR:\n{proc.stderr}"
                )

            data = json.loads(result_path.read_text())
        finally:
            for tmp in (payload_path, result_path):
                if tmp is not None:
                    tmp.unlink(missing_ok=True)

        _log.debug("solve done in %.0f ms wall time", wall_ms)
        t = np.array(data["t"])
        x = np.array(data["x"])   # shape (state_size, n_steps) — runner.jl serializes row-wise
        result = SolveResult(t=t, x=x, timings_ms=data.get("timings_ms", []))
        result._set_startup(wall_ms)
        return result

solve(compiled_spec, tspan, reps=1)

Run the ODE solver via Julia subprocess.

Parameters:

Name Type Description Default
compiled_spec CompiledSpec

Compiled simulation spec.

required
tspan tuple[float, float]

(t0, tf) integration interval.

required
reps int

Number of solves to run inside the subprocess. reps=1 → single solve (no warm timing). reps>1 → first solve is JIT, rest are warm.

1
Source code in src/numen/bridge/runtime.py
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
def solve(
    self,
    compiled_spec: CompiledSpec,
    tspan: tuple[float, float],
    reps: int = 1,
) -> SolveResult:
    """Run the ODE solver via Julia subprocess.

    Args:
        compiled_spec: Compiled simulation spec.
        tspan:         (t0, tf) integration interval.
        reps:          Number of solves to run inside the subprocess.
                       reps=1 → single solve (no warm timing).
                       reps>1 → first solve is JIT, rest are warm.

    Returns:
        A ``SolveResult`` with ``timings_ms`` from the runner and
        ``startup_ms`` set from the measured subprocess wall time.

    Raises:
        RuntimeError: If the Julia process exits with a non-zero code.
    """
    check_backend_features(compiled_spec, "JuliaBackend", self.supported_features)
    check_julia_fns(compiled_spec, "JuliaBackend")

    _log.debug(
        "solve: state_size=%d param_size=%d tspan=%s method=%s rtol=%g atol=%g reps=%d",
        compiled_spec.state_size, compiled_spec.param_size,
        tspan, self.method, self.rtol, self.atol, reps,
    )

    payload = {
        "spec":          compiled_spec.to_dict(),
        "tspan":         list(tspan),
        "reps":          reps,
        "method":        self.method,
        "rtol":          self.rtol,
        "atol":          self.atol,
        "n_save_points": self.n_save_points,
        "dtsave":        self.dtsave,
        "dtmax":         self.dtmax,
    }

    payload_path: Path | None = None
    result_path: Path | None = None
    try:
        # NamedTemporaryFile creates each file atomically (unlike the
        # deprecated, race-prone tempfile.mktemp); delete=False keeps it on
        # disk after closing so the Julia process can open it.
        with tempfile.NamedTemporaryFile(
            "w", suffix=".json", delete=False, encoding="utf-8",
        ) as fh:
            payload_path = Path(fh.name)
            fh.write(json.dumps(payload))
        with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as fh:
            result_path = Path(fh.name)

        cmd = [
            "julia",
            f"--project={_JULIA_PKG_DIR}",
            str(_RUNNER_JL),
            str(payload_path),
            str(result_path),
        ]
        if self._julia_file:
            cmd.append(self._julia_file)

        t0 = time.perf_counter()
        proc = subprocess.run(cmd, capture_output=True, text=True)
        wall_ms = (time.perf_counter() - t0) * 1000

        # Always route Julia output to logger, not just on failure
        for line in proc.stderr.splitlines():
            if line.strip():
                _log.debug("[julia stderr] %s", line)
        for line in proc.stdout.splitlines():
            if line.strip():
                _log.debug("[julia stdout] %s", line)

        if proc.returncode != 0:
            _log.error(
                "Julia subprocess failed (exit %d) after %.0f ms",
                proc.returncode, wall_ms,
            )
            raise RuntimeError(
                f"Julia subprocess failed (exit {proc.returncode}):\n"
                f"STDOUT:\n{proc.stdout}\nSTDERR:\n{proc.stderr}"
            )

        data = json.loads(result_path.read_text())
    finally:
        for tmp in (payload_path, result_path):
            if tmp is not None:
                tmp.unlink(missing_ok=True)

    _log.debug("solve done in %.0f ms wall time", wall_ms)
    t = np.array(data["t"])
    x = np.array(data["x"])   # shape (state_size, n_steps) — runner.jl serializes row-wise
    result = SolveResult(t=t, x=x, timings_ms=data.get("timings_ms", []))
    result._set_startup(wall_ms)
    return result

Errors

numen.errors.NumenError

Bases: RuntimeError

Base class for all Numen framework errors.

Source code in src/numen/errors.py
20
21
class NumenError(RuntimeError):
    """Root of Numen's own exception hierarchy.

    Derives from ``RuntimeError``, so generic ``RuntimeError`` handlers
    also catch it.
    """

numen.errors.NumenFeatureError

Bases: NumenError

A CompiledSpec requires features not supported by the selected backend.

Raised by check_backend_features before the solve begins, with a hint about which backend to use instead.

Source code in src/numen/errors.py
24
25
26
27
28
29
class NumenFeatureError(NumenError):
    """Signals that the chosen backend cannot handle what a ``CompiledSpec`` needs.

    ``check_backend_features`` raises it before any integration starts; the
    message points to a backend that should be used instead.
    """

numen.errors.NumenMissingFnError

Bases: NumenError

A System is missing the required dynamics function for the selected backend.

Raised by check_python_fns (scipy/JAX) or check_julia_fns (Julia) before the solve begins.

Source code in src/numen/errors.py
32
33
34
35
36
37
class NumenMissingFnError(NumenError):
    """Signals that a ``System`` lacks the dynamics function the chosen backend needs.

    Raised ahead of integration: by ``check_python_fns`` for the scipy/JAX
    backends, or by ``check_julia_fns`` for Julia.
    """