nstat.extras.em.dynamax_bridge — EM-trained state-space models via Dynamax
Wraps Dynamax (JAX-based, MIT-licensed) to provide EM-trained linear-Gaussian state-space models without nstat having to re-implement the EM machinery.
This is the foundation for closing the gap left by the unported MATLAB nSTAT KF_EM / PP_EM / mPPCO_EM families documented in AUDIT_REPORT.md §3.2 (19 methods, ~7,500 LOC of MATLAB if ported verbatim).
Install
pip install nstat-toolbox[dynamax] # pulls Dynamax (~50 MB) + JAX (~200 MB)
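Before calling the EM trainers it can help to confirm that the optional dependencies are importable. The following is a minimal sketch using only the standard library; the helper name missing_optional_deps is hypothetical and not part of nstat, and it assumes the import names are dynamax and jax.

```python
import importlib.util


def missing_optional_deps(modules=("dynamax", "jax")):
    """Return the names in `modules` that cannot be found in this environment.

    Hypothetical helper for illustration; not part of nstat.
    """
    return [name for name in modules if importlib.util.find_spec(name) is None]


missing = missing_optional_deps()
if missing:
    print("Missing optional dependencies:", ", ".join(missing))
else:
    print("dynamax and jax are importable.")
```

The check only looks up top-level module names and does not import JAX, so it is cheap to run at program start.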
API
EM trainers
| Symbol | MATLAB counterpart | Notes |
|---|---|---|
| | | LG state-space EM via Dynamax |
| | | Poisson-LGSSM EM (CMGF E-step + closed-form/Newton M-step) → |
| | | Mixed Poisson + Gaussian EM (IRLS-pseudo-obs augmented smoother E-step) → |
Point-process inference (known model)
| Symbol | MATLAB counterpart | Notes |
|---|---|---|
| | | CMGF (EKF-integration) point-process filter → |
| | | CMGF forward-backward smoother → |
Held-out predictive log-likelihood (quality diagnostic)
| Symbol | Notes |
|---|---|
| | True one-step-ahead predictive log-likelihood of a Poisson-LGSSM → |
| | Same for the hybrid model; |
These are pure NumPy and do not require dynamax — they score the
true Poisson (and Gaussian) likelihood of observations under the
one-step-ahead predictive state, the right metric for convergence
checks, model/restart comparison, and held-out scoring (pass a test
segment + train-fitted parameters). They replace the surrogate
Gaussian-smoother marginal_log_likelihoods trace, which is not a
valid objective (it re-linearizes each iteration).
Multi-restart selection (recommended workflow for real data)
| Symbol | Notes |
|---|---|
| | Runs PP_EM with |
| | Hybrid counterpart; scored by |
Single-fit fit_point_process_em can collapse to a degenerate solution
under weak observability (see caveat below); multi-restart selection on
the true held-out predictive log-likelihood is the production-quality
mitigation. Use these *_best_of(...) entry points instead of single-fit
on real data.
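The selection logic described above can be sketched generically. This is an illustration of the idea only, not the implementation of the *_best_of(...) functions: best_of, fit_fn, and score_fn are hypothetical names, and the toy scores stand in for held-out predictive log-likelihoods.

```python
import math


def best_of(fit_fn, score_fn, seeds):
    """Fit once per seed and keep the fit with the highest finite score.

    fit_fn(seed) -> fit object; score_fn(fit) -> float (higher is better).
    Both callables are placeholders supplied by the caller.
    """
    best_seed, best_fit, best_score = None, None, -math.inf
    for seed in seeds:
        fit = fit_fn(seed)
        score = score_fn(fit)
        if not math.isfinite(score):
            continue  # discard runs whose score is NaN or infinite
        if score > best_score:
            best_seed, best_fit, best_score = seed, fit, score
    if best_fit is None:
        raise RuntimeError("every restart produced a non-finite score")
    return best_seed, best_fit, best_score


# Toy stand-ins: four "fits" with made-up held-out scores, one of them NaN.
toy_scores = {0: -1250.0, 1: float("nan"), 2: -980.5, 3: -1100.0}
seed, fit, score = best_of(
    fit_fn=lambda s: {"seed": s},
    score_fn=lambda f: toy_scores[f["seed"]],
    seeds=range(4),
)
print(seed, score)  # prints: 2 -980.5
```

Scoring on a held-out segment rather than on the training data is what lets the loop reject a degenerate run that merely tracks the in-sample mean rate.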
Optional EM hardening (v0.4.2 — opt-in)
fit_point_process_em / fit_point_process_em_best_of gain two new keyword arguments, init and ridge_lambda; fit_hybrid_em / fit_hybrid_em_best_of gain ridge_lambda only. Both default to the v0.4.1 behavior, so existing code is unchanged.

- init="log_empirical_rate" (PP_EM only): seeds x0 from pinv(C) @ log(empirical_mean_rate) so the implied initial firing rate matches what the data shows, removing one bad-init mode. The random C draw is unchanged. Recommended on weakly-observable data.
- ridge_lambda=λ (PP_EM and hybrid; default 0.0): biases the A M-step toward the identity via a Gaussian prior at A=I: A = (S10 + λI)(S11 + λI)⁻¹. When S10, S11 → 0 (the weak-observability collapse mode) the limit becomes I rather than 0. Try 0.1–1.0 if you see the A → 0 collapse in *_best_of traces.
result = fit_point_process_em_best_of(
    spike_counts, state_dim=3, n_restarts=8,
    init="log_empirical_rate",  # data-driven x0
    ridge_lambda=0.5,           # bias A toward identity
)
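The effect of ridge_lambda on the A update can be checked numerically. The sketch below implements only the formula A = (S10 + λI)(S11 + λI)⁻¹ quoted above; the function name is hypothetical, and treating S10 and S11 as the lag-one and lag-zero smoothed second-moment sums is an assumption about their meaning.

```python
import numpy as np


def ridge_transition_update(S10, S11, ridge_lambda=0.0):
    """Return A = (S10 + lam*I) @ inv(S11 + lam*I) without forming the inverse.

    Hypothetical helper; S10 and S11 are assumed to be (d, d) sufficient
    statistics from the E-step.
    """
    eye = np.eye(S11.shape[0])
    lhs = S11 + ridge_lambda * eye
    rhs = S10 + ridge_lambda * eye
    # A @ lhs = rhs  <=>  lhs.T @ A.T = rhs.T
    return np.linalg.solve(lhs.T, rhs.T).T


d = 2
# Well-observed case: the statistics dominate and the ridge is off.
A_plain = ridge_transition_update(1.8 * np.eye(d), 2.0 * np.eye(d))
# Collapse mode: vanishing statistics, so the ridge term decides the answer.
A_ridge = ridge_transition_update(np.zeros((d, d)), np.zeros((d, d)), ridge_lambda=0.5)
print(A_plain)
print(A_ridge)
```

With ridge_lambda=0.0 and vanishing statistics the linear system is singular, which is the numerical face of the collapse the ridge is meant to prevent.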
Result dataclasses
Dataclass |
Fields |
|---|---|
|
|
|
|
|
|
|
|
|
|
|
|
All result arrays are plain NumPy — callers don’t need to know about JAX or pytrees.
Recipe
import numpy as np
from nstat.extras.em.dynamax_bridge import fit_linear_gaussian_em
# Simulate a 2-state linear-Gaussian process
rng = np.random.default_rng(0)
T, state_dim, emission_dim = 300, 2, 2
A_true = np.eye(state_dim) * 0.9
Q_true = np.eye(state_dim) * 0.02
R_true = np.eye(emission_dim) * 0.1
x = np.zeros((T, state_dim))
y = np.zeros((T, emission_dim))
x[0] = rng.multivariate_normal(np.zeros(state_dim), np.eye(state_dim))
y[0] = x[0] + rng.multivariate_normal(np.zeros(emission_dim), R_true)
for t in range(1, T):
    x[t] = A_true @ x[t - 1] + rng.multivariate_normal(np.zeros(state_dim), Q_true)
    y[t] = x[t] + rng.multivariate_normal(np.zeros(emission_dim), R_true)
# Fit via EM
result = fit_linear_gaussian_em(y, state_dim=2, n_iter=30)
print(f"Final log-likelihood: {result.log_likelihoods[-1]:.2f}")
print(f"Learned Â:\n{result.transition_matrix}")
Scope
| Feature | Status |
|---|---|
| | shipped: KF_EM equivalent |
| | shipped: point-process Kalman filter under Gaussian approximation |
| | shipped: point-process forward-backward smoother |
| | shipped: PP_EM equivalent (CMGF E-step + closed-form/Newton M-step, Smith & Brown 2003 PPLDS) |
| | shipped: mPPCO_EM equivalent (IRLS-pseudo-obs augmented LG smoother E-step + closed-form / Newton M-step) |
| | shipped: true one-step-ahead held-out predictive log-likelihood (pure NumPy, no dynamax) |
| | shipped: multi-restart EM + held-out-predictive-LL selection (Tier 0.3 mitigation for the weak-observability |
PP_EM and mPPCO_EM — experimental status & caveats
⚠️ fit_point_process_em and fit_hybrid_em are EXPERIMENTAL. They fit the observation model (firing rates, and for the hybrid the Gaussian noise R) correctly. As of the Tier 0.1 identifiability pass the latent A, C are returned in a canonical gauge (whiten + SVD-rotate + sign-fix), removing the scale/rotation drift. EM can still converge to different local optima across seeds, however, so a single fit’s A/C should be interpreted with care. Both functions emit a UserWarning to this effect.
The gauge freedom, and how it is now pinned. A Poisson LDS has a
gauge freedom: the transform (A, C, x) → (T A T⁻¹, C T⁻¹, T x) leaves
the observable log-rate C x — and hence the likelihood — exactly
invariant for any invertible T (the full GL(d) group, d² degrees
of freedom). EM has no reason to prefer any point on this orbit, so an
unconstrained fit lets the absolute scale and rotation of A/C drift
freely (the original PR showed |C| of 5–100 on fits whose rates
were perfectly sensible).
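The invariance is easy to verify numerically. In this minimal sketch the matrix G plays the role of T in the text (renamed to avoid clashing with the number of timesteps); all values are arbitrary illustration data.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_neurons, n_steps = 3, 5, 50

A = 0.9 * np.eye(d)
C = rng.normal(size=(n_neurons, d))
X = rng.normal(size=(n_steps, d))      # one latent state per row

# Any invertible matrix works; this one has determinant 7.
G = np.array([[2.0, 1.0, 0.0],
              [0.0, 1.0, 1.0],
              [1.0, 0.0, 3.0]])
G_inv = np.linalg.inv(G)

A_g = G @ A @ G_inv                    # T A T^-1
C_g = C @ G_inv                        # C T^-1
X_g = X @ G.T                          # T x, applied row-wise

log_rate = X @ C.T
log_rate_g = X_g @ C_g.T
print(np.max(np.abs(log_rate - log_rate_g)))   # round-off only
print(np.linalg.norm(C), np.linalg.norm(C_g))  # parameter norms differ
```

The log-rates agree to floating-point precision while the norm of C changes, which is the drift an unconstrained fit is free to exhibit.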
This is the role PP_EMCreateConstraints plays in MATLAB nSTAT. The Python port
pins the gauge to the standard LDS canonical form (cf. Macke et al. 2011;
Buesing et al. 2012) with a single transform applied once after EM
convergence. The full transform is never applied per iteration, because that
fights the Newton trust-region and destabilizes the fit. It has three steps:
1. Whiten the latent so the empirical state second moment becomes the identity (T = M^{-1/2}); this removes the symmetric gauge DOF.
2. SVD-rotate so the stacked emission matrix has orthogonal columns ordered by descending singular value; this removes the residual O(d).
3. Sign-fix each axis so the largest-magnitude entry of each emission column is positive; this removes the 2^d sign flips.
The returned emission matrix therefore satisfies CᵀC = diag(S²)
(a machine-precision-exact, seed-stable invariant the tests assert).
What remains is local-optima multiplicity — distinct fits with
genuinely different likelihoods — not gauge freedom; pinning that would
require multi-restart model selection (tracked separately).
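The whiten, SVD-rotate, sign-fix recipe can be sketched in plain NumPy. This is an illustrative reconstruction from the description above, not the bridge's internal code: canonical_gauge is a hypothetical name, a single emission matrix C stands in for the stacked one, and the sketch assumes at least as many emission rows as latent dimensions and a positive-definite state second moment.

```python
import numpy as np


def canonical_gauge(A, C, X):
    """Return (A, C, X) in a canonical frame. X is (T, d), C is (n, d), n >= d."""
    # 1. Whiten: M^{-1/2} from the empirical second moment of the latent.
    M = X.T @ X / X.shape[0]
    evals, evecs = np.linalg.eigh(M)
    W = evecs @ np.diag(evals ** -0.5) @ evecs.T
    W_inv = evecs @ np.diag(evals ** 0.5) @ evecs.T
    # 2. SVD-rotate: right singular vectors of the whitened emission matrix.
    _, _, Vt = np.linalg.svd(C @ W_inv, full_matrices=False)
    C_rot = C @ W_inv @ Vt.T
    # 3. Sign-fix: make the largest-magnitude entry of each column positive.
    rows = np.argmax(np.abs(C_rot), axis=0)
    signs = np.sign(C_rot[rows, np.arange(C_rot.shape[1])])
    signs[signs == 0] = 1.0
    D = np.diag(signs)
    G, G_inv = D @ Vt @ W, W_inv @ Vt.T @ D
    return G @ A @ G_inv, C @ G_inv, X @ G.T


rng = np.random.default_rng(1)
X = rng.normal(size=(200, 3)) @ np.array([[2.0, 0.5, 0.0],
                                          [0.0, 1.0, 0.3],
                                          [0.0, 0.0, 0.4]])
C = rng.normal(size=(6, 3))
A = 0.9 * np.eye(3)
A_c, C_c, X_c = canonical_gauge(A, C, X)
gram = C_c.T @ C_c
print(np.round(gram, 6))   # diagonal, entries in descending order
```

Because the combined transform is a single invertible matrix, the log-rates X @ C.T are unchanged by it; only the coordinate frame of the latent moves.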
What IS reliable (use these):
- The fitted firing rate exp(C x) / smoothed log-rate: the identifiable observable. Re-smooth at the returned parameters with cmgf_poisson_smoother to obtain it.
- For fit_hybrid_em, the Gaussian observation noise R (recovers the true value within a small factor; it lives in observation space and is gauge-invariant).
- The held-out predictive log-likelihood (point_process_predictive_ll / hybrid_predictive_ll): a true, gauge-invariant quality score; see below.
Checking fit quality — use the predictive log-likelihood, not the EM
trace. The marginal_log_likelihoods returned by the trainers is a
surrogate (the Gaussian-smoother likelihood of the re-linearized IRLS
pseudo-observations); it changes basis every iteration and is not
monotonic or comparable across fits. For a real metric, score the
observations with the one-step-ahead predictive log-likelihood:
import numpy as np
from nstat.extras.em.dynamax_bridge import (
    fit_point_process_em, point_process_predictive_ll,
)
# y: (T, n_neurons) integer spike-count array, loaded elsewhere
y_train, y_test = y[:800], y[800:]  # held-out split
fit = fit_point_process_em(y_train, state_dim=3, n_iter=30, seed=0)
score = point_process_predictive_ll(
    y_test, fit.transition_matrix, fit.observation_matrix,
    fit.transition_covariance, fit.initial_state_mean,
    fit.initial_state_covariance,
)
print(score.total)         # higher = better; compare seeds/state_dims
print(score.per_timestep)  # locate where a fit predicts poorly
Because it is gauge-invariant and pure NumPy, it is the right tool to
pick state_dim, compare EM restarts, or detect a bad fit.
⚠️ Observability caveat (a real limitation, now mitigated). PP_EM’s held-out predictive performance depends strongly on how much the spikes constrain the latent. With weak observability (few neurons and/or small loadings) a single PP_EM fit can converge to a degenerate solution (dynamics A → 0, inflated C/Q) that tracks the in-sample mean rate but generalizes worse than a constant-rate model (the predictive LL can be sharply negative). With strong observability (many informative neurons) A is recovered and the held-out predictive LL improves over the initialization.

The recommended workflow on real data is therefore fit_point_process_em_best_of(...) (or fit_hybrid_em_best_of(...)), not a single-seed fit: those run several restarts and pick the seed with the best held-out predictive LL, automatically discarding degenerate runs. Single-fit fit_point_process_em remains available as a low-level primitive for cases where you need a specific seed (e.g. reproducing a paper figure).
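A constant-rate baseline gives a concrete reference point for the "worse than a constant-rate model" comparison. The sketch below is an assumption about how such a baseline could be computed, not a function shipped by the bridge: it fits one Poisson rate per neuron on the training segment and scores the test segment with the exact Poisson log-pmf.

```python
import math

import numpy as np


def constant_rate_poisson_ll(y_train, y_test, eps=1e-12):
    """Held-out Poisson log-likelihood of a per-neuron constant-rate model.

    Hypothetical baseline helper; y_train and y_test are (T, n_neurons)
    non-negative count arrays.
    """
    rates = np.maximum(y_train.mean(axis=0), eps)           # MLE rate per neuron
    log_factorial = np.vectorize(math.lgamma)(y_test + 1.0)  # log(y!)
    return float(np.sum(y_test * np.log(rates) - rates - log_factorial))


rng = np.random.default_rng(0)
counts = rng.poisson(lam=[2.0, 0.5, 4.0], size=(1000, 3))
baseline = constant_rate_poisson_ll(counts[:800], counts[800:])
print(baseline)
```

A fitted model whose held-out predictive log-likelihood on the same test segment falls below this number is a candidate for the degenerate solution described above.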
What was fixed in the deep-dive pass (improvements over the initial PR):
- Lag-one cross-covariance is now exact: the E-step uses an IRLS pseudo-observation linearization + a purpose-built time-varying-R RTS smoother (_kalman_rts_smoother_tv) that returns the lag-one smoothed cross-covariances. This stopped the previous A→0 collapse (the moment-matching approximation that dropped the cross-cov term biased A toward zero).
- Time-varying pseudo-observation noise R_t = 1/λ_t: substituting a fixed R (forced by a batched smoother) breaks the IRLS weight cancellation and was numerically unstable (SVD non-convergence) at low rates. The new smoother accepts per-timestep R_t.
- Gaussian R M-step trace correction: now includes the C_g Σ_t C_g' latent-uncertainty term, without which R collapsed toward zero over iterations.
- Gauge + step bounding: a cheap per-iteration unit-RMS diagonal scale pin plus a Newton trust-region keep |C| finite during the fit, and a single full canonical-gauge transform after convergence (whiten + SVD-rotate + sign-fix) pins the remaining rotational and sign freedom. This is the Tier 0.1 identifiability pass, the PP_EMCreateConstraints equivalent; A/C are now returned in a unique canonical frame (CᵀC diagonal, descending).
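The R_t = 1/λ_t relation comes from the textbook IRLS working response for a Poisson observation with a log link. The sketch below shows that standard construction only; the function name is hypothetical and the bridge's internal linearization may differ in detail.

```python
import numpy as np


def irls_pseudo_observations(y, C, x_lin):
    """Gaussian pseudo-observations for y ~ Poisson(exp(C x)).

    y: (T, n) counts; C: (n, d); x_lin: (T, d) linearization points.
    Returns the working response z and its per-timestep noise variance R_t.
    """
    eta = x_lin @ C.T            # log-rate at the linearization point
    lam = np.exp(eta)            # rate λ_t
    z = eta + (y - lam) / lam    # working response
    R_t = 1.0 / lam              # diagonal pseudo-observation variance
    return z, R_t


C = np.eye(2)
x_lin = np.log(np.array([[1.0, 2.0], [4.0, 0.05]]))
y = np.array([[1.0, 2.0], [4.0, 0.0]])
z, R_t = irls_pseudo_observations(y, C, x_lin)
print(R_t)   # variance grows as the rate shrinks
```

The last entry illustrates the low-rate regime: a rate of 0.05 yields a pseudo-observation variance of 20, which is why replacing R_t with one fixed R distorts the weighting.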
Still approximate / deferred to a future release:
- Multi-restart model selection: the canonical gauge makes a single fit’s A/C well-defined, but EM can still reach different local optima across seeds (genuinely different likelihoods, not gauge copies). Picking the best of several restarts is the remaining step toward fully reproducible A/C.
- The Laplace E[exp(C x_t)] uses the diagonal quadratic correction; sufficient for moderate rates, may underestimate variance at high rates.
- The reported marginal_log_likelihoods are the surrogate Gaussian smoother log-likelihoods (the IRLS pseudo-observations are re-linearized each iteration), not the true Poisson marginal likelihood. Do not use the trace as a convergence diagnostic.
CMGF Poisson recipe
For inference on a known model (filter or smoother), the bridge is a
thin wrapper around Dynamax’s
conditional_moments_gaussian_filter / smoother for the
Poisson-LGSSM:
import numpy as np
from nstat.extras.em.dynamax_bridge import (
    cmgf_poisson_filter, cmgf_poisson_smoother,
)
# Known model: x_t = A x_{t-1} + w_t, y_t ~ Poisson(exp(C x_t))
A = np.eye(2) * 0.95
C = np.eye(2) * 0.3
Q = np.eye(2) * 0.05
x0 = np.zeros(2)
P0 = np.eye(2) * 0.1
# y is a (T, emission_dim) integer-valued spike-count array.
filtered = cmgf_poisson_filter(y, A, C, Q, x0, P0)
smoothed = cmgf_poisson_smoother(y, A, C, Q, x0, P0)
print(filtered.state_means.shape) # (T, 2)
print(smoothed.marginal_log_likelihood)
Counterpart to MATLAB nSTAT’s PPDecodeFilter / PP_fixedIntervalSmoother.
Gotchas
- JAX install footprint is ~200 MB. The [dynamax] extra is intentionally not in [all-extras] until a CI-functional run validates Dynamax compatibility with the rest of the test matrix.
- Pytree → NumPy conversion. Dynamax represents parameters as nested JAX pytrees (ParamsLGSSMDynamics, ParamsLGSSMEmissions, ParamsLGSSMInitial). The bridge unpacks these into plain NumPy arrays so users don’t need to learn the pytree convention.
- EM monotonicity. EM log-likelihood is theoretically non-decreasing. The bridge tests assert this with a tolerance of -1e-6 to absorb floating-point noise; if Dynamax’s optimizer produces a decrement larger than that, the test will surface it.
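The monotonicity check can be reproduced on any log-likelihood trace. This is a minimal sketch of such a check, not the bridge's test code; first_em_decrease is a hypothetical name. It is meaningful for the linear-Gaussian trainer, whose trace is a true likelihood, and not for the surrogate traces of the point-process trainers.

```python
def first_em_decrease(log_likelihoods, tol=1e-6):
    """Return the index of the first step that drops by more than `tol`, else None."""
    for i in range(1, len(log_likelihoods)):
        if log_likelihoods[i] - log_likelihoods[i - 1] < -tol:
            return i
    return None


ok_trace = [-10.0, -8.0, -8.0000001, -7.0]   # the dip is below the tolerance
bad_trace = [-10.0, -8.0, -8.5]              # a genuine decrease at step 2
print(first_em_decrease(ok_trace))   # prints: None
print(first_em_decrease(bad_trace))  # prints: 2
```

Returning the offending index rather than a bare boolean makes it easier to see which EM iteration misbehaved.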
End-to-end demo
examples/extras/em_dynamax_demo.py
fits a 2-state LG model on 300 samples and prints the parameter
estimates + EM log-likelihood trace.
Upstream references
Dynamax: https://github.com/probml/dynamax
License: MIT (GPL-2 compatible)
Active development (v1.0, 1,688 commits as of audit)
Maintained by: Kevin Murphy’s probml group