fp32-CRF diagnostic — 2026-05-28
The 2026-05-28 night-2 postmortem left one open question after the v0.6.0 NaN incidents:
CRF training on a 33×33 transition table in bf16 NaN'd twice. The fix was to disable CRF training entirely. The bf16 hypothesis remains unconfirmed and needs an fp32 follow-up.
Two divergences had been observed:
| Attempt | LR | crf_loss_weight | NaN at step |
|---|---|---|---|
| 1 | 1.5e-4 | 0.5 | ~950 |
| 2 | 1.0e-4 | 0.1 | ~1700 |
Both were post-warmup, after the LR reached its constant value. The transition table is
33×33 in Stage 3 (vs 21×21 in Stage 2) with masked -inf entries enforcing the
structural BIO grammar (no I-X after O, no I-X after I-Y for X≠Y, etc.). The
hypothesis: bf16's 7-bit mantissa loses too much precision against -inf-sentinel
arithmetic in the CRF's logsumexp forward pass.
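To make the masked table concrete, here is a minimal sketch of how the two grammar rules named above could be encoded as -inf entries. It assumes the 33 tags are one O tag plus B-/I- pairs for 16 entity types (2×16 + 1 = 33), which the source does not state; the type names and both helper functions are hypothetical, and any further rules covered by the "etc." are not modelled.

```python
import math


def build_bio_tags(entity_types):
    """Return ['O', 'B-x', 'I-x', ...] for the given entity types."""
    tags = ["O"]
    for name in entity_types:
        tags += [f"B-{name}", f"I-{name}"]
    return tags


def build_transition_mask(tags):
    """mask[prev][nxt] is 0.0 when prev -> nxt is legal BIO, else -inf."""
    mask = [[0.0] * len(tags) for _ in tags]
    for i, prev in enumerate(tags):
        for j, nxt in enumerate(tags):
            if nxt.startswith("I-"):
                entity = nxt[2:]
                # I-X may only follow B-X or I-X: this rules out
                # O -> I-X and any B-Y / I-Y -> I-X with X != Y.
                if prev not in (f"B-{entity}", f"I-{entity}"):
                    mask[i][j] = -math.inf
    return mask


# Placeholder type names; only the count (16) matters for the table size.
stage3_tags = build_bio_tags([f"type{k}" for k in range(16)])
stage3_mask = build_transition_mask(stage3_tags)
forbidden = sum(v == -math.inf for row in stage3_mask for v in row)
print(len(stage3_tags), forbidden)  # 33 tags, 16 * 31 = 496 masked entries
```

Under this assumed layout, 496 of the 1089 entries are -inf, so a large share of the logsumexp inputs are sentinels rather than learned scores.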
DeepSeek's turn 7 sign-off explicitly required this experiment as a precondition for v0.6.2 — independent of v0.6.2's corpus changes (different variable).
Implementation
crf_fp32: bool = False flag added to ModelConfig (corpus-python/src/mailwoman_train/config.py)
and threaded through MailwomanCoarseEncoder (model.py). When set, the CRF forward call
runs under torch.autocast(device_type=..., enabled=False) with emissions and the mask
upcast to fp32:
    if self.crf_fp32:
        device_type = logits.device.type
        # Turn autocast off so the CRF's logsumexp runs in fp32, not bf16.
        with torch.autocast(device_type=device_type, enabled=False):
            emissions_fp32 = logits.float()
            # Match the mask dtype to the upcast emissions.
            crf_mask = attention_mask.to(emissions_fp32.dtype)
            crf_loss = self.crf(
                emissions=emissions_fp32,
                tags=labels.clamp(min=0),
                mask=crf_mask,
                reduction=crf_reduction,
            )
Everything outside the CRF call continues in bf16 — encoder forward, embedding lookup,
CE loss, optimizer step. Only the CRF's logsumexp over the transition table runs in
fp32. The crf_loss is downcast back to ce_loss's dtype before the dual-loss sum so the
optimizer sees one consistent tensor.
Default is False — every existing config remains bit-identical to its prior runs.
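The flag and its default can be pictured with a stripped-down dataclass. This is a sketch only: the real ModelConfig has more fields, the 0.0 default for crf_loss_weight is an assumption, and crf_precision_warning is a hypothetical guard that is not part of the codebase described here.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelConfig:
    # Only the two fields discussed in this note; the real class has more.
    crf_loss_weight: float = 0.0  # assumed default
    crf_fp32: bool = False  # off by default, so existing configs are unchanged


def crf_precision_warning(cfg: ModelConfig) -> Optional[str]:
    """Hypothetical guard: flag CRF training that would run in bf16."""
    if cfg.crf_loss_weight > 0.0 and not cfg.crf_fp32:
        return "crf_loss_weight > 0 with crf_fp32=False (the setup that NaN'd)"
    return None


print(crf_precision_warning(ModelConfig(crf_loss_weight=0.5)))
print(crf_precision_warning(ModelConfig(crf_loss_weight=0.5, crf_fp32=True)))
```

A check like this would be one cheap way to keep a future config from re-enabling the CRF without the fp32 path.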
Diagnostic config
configs/v0_6_2-crf-fp32-diagnostic.yaml. Identical to v0.6.1-stage3-streets EXCEPT:
| Field | v0.6.1 | Diagnostic |
|---|---|---|
| crf_loss_weight | 0.0 | 0.5 |
| crf_fp32 | n/a | true |
| max_steps | 100000 | 3000 |
| eval_every_steps | 2000 | 500 |
3000 steps clear both observed NaN windows (steps ~950 and ~1700) with margin. Runtime is ~30 min on an A100.
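The table can also be read as a set of overrides applied to the v0.6.1 values. The sketch below uses plain dicts with a flat key layout, which is an assumption (the real YAML may nest these fields); apply_overrides and changed_fields are hypothetical helpers.

```python
# Values copied from the table above; the flat key layout is an assumption.
V0_6_1_FIELDS = {
    "crf_loss_weight": 0.0,
    "max_steps": 100000,
    "eval_every_steps": 2000,
}
DIAGNOSTIC_OVERRIDES = {
    "crf_loss_weight": 0.5,
    "crf_fp32": True,
    "max_steps": 3000,
    "eval_every_steps": 500,
}


def apply_overrides(base, overrides):
    """Return a new dict: base values with the overrides layered on top."""
    merged = dict(base)
    merged.update(overrides)
    return merged


def changed_fields(base, derived):
    """Names of fields that are new in `derived` or differ from `base`."""
    return sorted(k for k in derived if base.get(k) != derived[k])


diagnostic = apply_overrides(V0_6_1_FIELDS, DIAGNOSTIC_OVERRIDES)
print(changed_fields(V0_6_1_FIELDS, diagnostic))
# Both previously observed NaN steps fall inside the diagnostic run.
print(all(step < diagnostic["max_steps"] for step in (950, 1700)))
```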
Result: PASS
Training reached step 3000 with no NaN. Loss curve:
| Step | train_loss | val_loss | macro_f1 |
|---|---|---|---|
| 25 | 43.81 | — | — |
| 100 | 32.21 | — | — |
| 500 | (warmup) | — | — |
| 950 | 1.5647 | — | — |
| 1000 | 1.5375 | 4.3418 | 0.2235 |
| 1500 | 0.9044 | 3.8535 | 0.2589 |
| 1700 | 0.8934 | — | — |
| 2000 | 0.6809 | (next eval) | — |
| 2500 | 0.5927 | 3.6348 | 0.2979 |
| 2825 | 0.4913 | — | — |
Both previously-NaN'd steps (950 and 1700) cleared cleanly. macro_f1 climbs monotonically across the three recorded evals: 0.2235 → 0.2589 → 0.2979. val_loss drops over the same evals: 4.34 → 3.85 → 3.63.
The bf16 + masked -inf transition hypothesis is empirically confirmed. fp32 precision
inside the CRF's logsumexp forward eliminates the precision loss; the rest of the model
runs in bf16 for throughput as before.
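For a feel of how coarse bf16 is, the sketch below emulates bf16 and fp32 rounding in pure Python and accumulates a small increment in each. It is not the CRF forward pass and does not reproduce the NaN; it only illustrates how few significant bits a running score keeps in bf16. to_bf16, to_fp32, and accumulate are hypothetical helpers written for this illustration.

```python
import math
import struct


def to_fp32(x):
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x):
    """Round to the nearest bfloat16 value (ties to even), via fp32 bits."""
    if math.isnan(x) or math.isinf(x):
        return x
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bf16 keeps the top 16 bits of the fp32 pattern; round the rest away.
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def accumulate(increment, steps, rounder):
    """Add `increment` repeatedly, rounding the running total each time."""
    total = 0.0
    for _ in range(steps):
        total = rounder(total + increment)
    return total


bf16_total = accumulate(0.3, 1000, to_bf16)
fp32_total = accumulate(0.3, 1000, to_fp32)
print(bf16_total)  # stalls at 128.0: the bf16 spacing there is 1.0, so +0.3 rounds away
print(fp32_total)  # close to 300
```

Once the running total reaches 128, every further +0.3 is rounded away in bf16, while the fp32 total stays close to the true sum.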
Implications
- v0.6.3 can re-enable CRF. Set crf_loss_weight: 0.5 (or higher) + crf_fp32: true in the v0.6.3 config. The learned transition table that fell out of v0.6.0's failed attempts is now within reach.
- v0.6.2 stays CE-only. Per the NaN protocol's "one variable at a time" rule: v0.6.2 already changes corpus composition (synth-street weight + new synth-no-street shard). Adding CRF activation in the same release confounds attribution. Defer to v0.6.3.
- fp32-CRF throughput cost is minimal. The diagnostic ran at ~6 steps/s on an A100-SXM4-40GB, comparable to v0.6.0's CE-only throughput. The autocast boundary only affects the CRF forward (a small linear-in-seq-len operation), not the encoder.
- The NaN protocol's "diagnose ONE knob → retry" rule paid off here. The fp32 experiment was a 30-min run that closed a question that had been open since the postmortem. The cost of NOT running it was indefinite future uncertainty.
Reproducing
    # Push the changed model.py + config.py + new diagnostic yaml to the Modal volume:
    modal volume put mailwoman-training corpus-python/src/mailwoman_train/model.py \
        corpus-python/src/mailwoman_train/model.py --force
    modal volume put mailwoman-training corpus-python/src/mailwoman_train/config.py \
        corpus-python/src/mailwoman_train/config.py --force
    modal volume put mailwoman-training \
        corpus-python/src/mailwoman_train/configs/v0_6_2-crf-fp32-diagnostic.yaml \
        corpus-python/src/mailwoman_train/configs/v0_6_2-crf-fp32-diagnostic.yaml --force

    # Clear pyc cache (Modal's documented gotcha: labels.py / model.py changes need this):
    modal volume rm mailwoman-training corpus-python/src/mailwoman_train/__pycache__ -r

    # Launch the experiment (~30 min on A100):
    modal run -d scripts/modal/train_remote.py --config v0_6_2-crf-fp32-diagnostic.yaml --resume none

    # Watch:
    modal app logs <app-id>  # look for loss=nan in the step 900-2000 range
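Scanning the logs by eye works for one run; for repeated runs, a small filter over saved log lines may help. The sketch below assumes each training line carries step=<int> and loss=<value> tokens. Only the loss=nan form comes from this note, so the step= token and the overall line shape are assumptions to adapt to the real trainer output; find_nan_steps is a hypothetical helper.

```python
import math
import re

# Assumed line shape: "... step=950 ... loss=1.5647 ...".
LINE_RE = re.compile(r"step=(\d+).*?loss=([^\s,]+)")


def find_nan_steps(lines, first_step=900, last_step=2000):
    """Return the steps inside [first_step, last_step] whose loss is NaN."""
    hits = []
    for line in lines:
        match = LINE_RE.search(line)
        if match is None:
            continue
        step = int(match.group(1))
        try:
            loss = float(match.group(2))  # float("nan") parses as NaN
        except ValueError:
            continue  # not a number, skip the line
        if first_step <= step <= last_step and math.isnan(loss):
            hits.append(step)
    return hits


# Made-up sample lines in the assumed format, not real training output.
sample = [
    "step=950 loss=1.5647",
    "step=1700 loss=nan",
    "step=2500 loss=nan",
]
print(find_nan_steps(sample))  # [1700]; step 2500 is outside the window
```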
See also
- 2026-05-28 night-2 postmortem — the open question this diagnostic closes
- v0.6.2 training config — the mainline retrain (CE-only per NaN protocol; CRF deferred to v0.6.3 with these fp32 changes)
- Street-supplement architecture — the v0.6.2 retrain context