
Tier 0.3 — Harden PP_EM via multi-restart selection

2026-05-30 · nstat/extras/em/dynamax_bridge.py · branch feat/extras-em-harden-pp-em (stacked on feat/extras-em-predictive-loglik / Tier 0.2)

Why

Tier 0.2's predictive-LL diagnostic exposed a real PP_EM limitation. Under weak observability (few neurons or small loadings), a single-seed fit can collapse to a degenerate solution (A → 0, inflated C/Q) whose held-out predictive log-likelihood is worse than that of a constant-rate model. With strong observability the fit recovers, but the user can't tell which regime they're in without checking the diagnostic, and even then a single fit may land in a poor local optimum.
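For reference, the constant-rate baseline can be scored on the same kind of time split. The sketch below assumes Poisson-distributed binned counts of shape (time bins, neurons) and reports the total held-out log-likelihood in nats. The function names are hypothetical, and the bridge's own *_predictive_ll functions may normalize differently (for example per bin or per neuron), so compare only numbers computed in the same units.

```python
import math

import numpy as np


def poisson_loglik(counts, rates):
    """Total Poisson log-likelihood: sum over bins and neurons of y*log(r) - r - log(y!)."""
    counts = np.asarray(counts, dtype=float)
    rates = np.asarray(rates, dtype=float)
    log_factorial = np.vectorize(math.lgamma, otypes=[float])(counts + 1.0)  # log(y!)
    return float(np.sum(counts * np.log(rates) - rates - log_factorial))


def constant_rate_baseline_ll(spike_counts, holdout_fraction=0.2, eps=1e-8):
    """Held-out LL of a per-neuron constant-rate model fit on the leading segment."""
    y = np.asarray(spike_counts, dtype=float)  # assumed shape (n_bins, n_neurons)
    n_train = int(round(len(y) * (1.0 - holdout_fraction)))
    train, test = y[:n_train], y[n_train:]     # leading train / trailing test, split by time
    rate = train.mean(axis=0) + eps            # per-neuron MLE rate, floored to avoid log(0)
    return poisson_loglik(test, np.broadcast_to(rate, test.shape))
```

A fitted model whose held-out LL falls below this number on the same split is doing worse than ignoring the latent state entirely, which is the degenerate-fit symptom described above.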

What shipped

Both new entry points (the PP_EM fit and its hybrid counterpart) compose the Tier 0.1 (canonical gauge) and Tier 0.2 (true held-out predictive log-likelihood) building blocks:

  1. Split the observations by time into a train segment (the leading 1 - holdout_fraction) and a test segment (the trailing holdout_fraction).
  2. Fit PP_EM (or hybrid) n_restarts times on the train segment, each with a different seed.
  3. Score each fit on the test segment with the matching *_predictive_ll function.
  4. Return the fit with the highest held-out predictive LL, plus the full trace of seeds and LLs for transparency.
Usage:

```python
from nstat.extras.em.dynamax_bridge import fit_point_process_em_best_of

# spike_counts: the binned spike-count array you would pass to a single PP_EM fit
result = fit_point_process_em_best_of(spike_counts, state_dim=3, n_restarts=8)
result.best_result          # PointProcessEMResult: gauge-pinned, best of the n_restarts fits
result.best_predictive_ll   # held-out LL of the chosen fit
result.all_predictive_lls   # full diagnostic trace across seeds
```
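The selection loop in the steps above can be sketched independently of PP_EM. In this sketch `fit_fn` and `score_fn` are hypothetical hooks standing in for a PP_EM (or hybrid) fit and the matching *_predictive_ll scorer. The seed schedule (0 to n_restarts - 1), the open (0, 1) range for holdout_fraction, the minimum train length, and the `BestOfResult` container are assumptions, not the bridge's actual signature or defaults.

```python
from dataclasses import dataclass
from typing import Any, Callable, List

import numpy as np


@dataclass
class BestOfResult:
    best_result: Any                  # the winning fit object
    best_seed: int                    # seed that produced it
    best_predictive_ll: float         # its held-out predictive LL
    all_seeds: List[int]              # every seed tried, in order
    all_predictive_lls: List[float]   # held-out LL for each seed


def best_of_restarts(
    y,
    fit_fn: Callable[[np.ndarray, int], Any],
    score_fn: Callable[[Any, np.ndarray, np.ndarray], float],
    n_restarts: int = 8,
    holdout_fraction: float = 0.2,
    min_train: int = 2,
) -> BestOfResult:
    """Fit n_restarts times on a leading train segment; keep the best held-out scorer."""
    if n_restarts < 1:
        raise ValueError("n_restarts must be >= 1")
    if not 0.0 < holdout_fraction < 1.0:
        raise ValueError("holdout_fraction must lie in (0, 1)")
    y = np.asarray(y)
    n_train = int(round(len(y) * (1.0 - holdout_fraction)))  # time is axis 0
    if n_train < min_train or n_train >= len(y):
        raise ValueError("time split leaves the train or test segment too short")
    train, test = y[:n_train], y[n_train:]  # contiguous split by time, no shuffling

    seeds = list(range(n_restarts))
    fits, lls = [], []
    for seed in seeds:
        fit = fit_fn(train, seed)                      # one restart
        fits.append(fit)
        lls.append(float(score_fn(fit, train, test)))  # held-out predictive LL

    best = int(np.argmax(lls))  # first index wins on ties
    return BestOfResult(fits[best], seeds[best], lls[best], seeds, lls)
```

Passing the train segment to `score_fn` leaves room for a scorer that filters through the training data before predicting the test segment, which a state-space predictive log-likelihood generally needs in order to initialize the latent state.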

Why this is the right fix

Three factors made multi-restart the highest-value-per-line option:

Validation

| Check | Result |
|---|---|
| End-to-end smoke (decoder + classifier branches) | PASS |
| best_seed / best_predictive_ll consistent with argmax(all_predictive_lls) | PASS |
| best_predictive_ll ≥ median(all_predictive_lls) | PASS (by construction) |
| Input validation (n_restarts < 1, holdout_fraction out of range, mismatched hybrid lengths, train segment too short) | PASS |
| Hybrid smoke + shape contract | PASS |
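The two consistency rows can be expressed as plain checks on the returned trace. This is a sketch assuming parallel lists of seeds and held-out LLs; the `seeds` list and the function name are hypothetical. The median check is implied by the argmax check (a maximum is never below the median), which is why the table marks it as holding by construction.

```python
import numpy as np


def check_selection_invariants(best_seed, best_ll, seeds, lls):
    """Raise AssertionError unless the chosen fit is the argmax of the LL trace."""
    lls = np.asarray(lls, dtype=float)
    if len(lls) == 0 or len(seeds) != len(lls):
        raise ValueError("seeds and lls must be non-empty and the same length")
    i = int(np.argmax(lls))  # first index of the maximum held-out LL
    if seeds[i] != best_seed:
        raise AssertionError(f"best_seed={best_seed}, argmax seed={seeds[i]}")
    if best_ll != lls[i]:
        raise AssertionError(f"best_ll={best_ll}, max LL={lls[i]}")
    # Follows from the argmax check: the maximum of a set is never below its median.
    if best_ll < np.median(lls):
        raise AssertionError("best_ll is below the median LL")
    return True
```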

tests/extras/test_dynamax_bridge.py: 28 passed in the dynamax venv (was 24; +4 new multi-restart tests).

Deferred / out of scope

The deeper regularization options from the original Tier 0.3 plan (a data-driven initialization from the log empirical rate, and a ridge penalty on the A/Q M-step) are not shipped here. Multi-restart selection on the held-out diagnostic was the highest-value-per-line change and is what the Tier 0.2 finding most directly called for. Either regularization option can be added incrementally if specific fixtures still need it.
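The plan does not pin down either option further. As a sketch of one common form, assume a linear-Gaussian latent update x_t = A x_{t-1} + w_t with w_t ~ N(0, Q) and a log-link observation model with a per-neuron offset. The log-empirical-rate initialization then sets each offset to the log of that neuron's mean count per bin, and a ridge penalty (λ/2) tr(Q^{-1} A A^T) on the A update gives the closed form A = S10 (S00 + λI)^{-1}, where S10 = Σ E[x_t x_{t-1}^T] and S00 = Σ E[x_{t-1} x_{t-1}^T] are the E-step sufficient statistics. Function names are hypothetical, and a ridge on Q is not shown.

```python
import numpy as np


def log_empirical_rate_init(spike_counts, eps=1e-3):
    """Hypothetical init: per-neuron log of the mean count per bin (assumed log-link offset)."""
    y = np.asarray(spike_counts, dtype=float)  # assumed shape (n_bins, n_neurons)
    return np.log(y.mean(axis=0) + eps)        # eps keeps silent neurons finite


def ridge_m_step_A(S10, S00, lam):
    """Ridge-penalized M-step for A: A = S10 (S00 + lam * I)^{-1}.

    Maximizes the expected complete-data log-likelihood of the linear-Gaussian
    dynamics minus (lam / 2) * tr(Q^{-1} A A^T); lam = 0 gives the usual update.
    """
    d = S00.shape[0]
    # Solve (S00 + lam I)^T A^T = S10^T instead of forming an explicit inverse.
    return np.linalg.solve((S00 + lam * np.eye(d)).T, S10.T).T
```

Increasing lam shrinks A toward zero, so on its own this penalty pulls in the same direction as the A → 0 collapse; it would be paired with the data-driven initialization rather than used as a substitute for selection on the held-out diagnostic.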

Files changed