
fp32-CRF diagnostic — 2026-05-28

The 2026-05-28 night-2 postmortem left one open question after the v0.6.0 NaN incidents:

CRF training on a 33×33 transition table in bf16 NaN'd twice. Fix was to disable CRF training entirely. The bf16 hypothesis remains unconfirmed — needs a fp32 follow-up.

Two divergences had been observed:

| Attempt | LR | crf_loss_weight | NaN at step |
|---|---|---|---|
| 1 | 1.5e-4 | 0.5 | ~950 |
| 2 | 1.0e-4 | 0.1 | ~1700 |

Both were post-warmup, after the LR reached its constant value. The transition table is 33×33 in Stage 3 (vs 21×21 in Stage 2) with masked -inf entries enforcing the structural BIO grammar (no I-X after O, no I-X after I-Y for X≠Y, etc.). The hypothesis: bf16's 7-bit mantissa loses too much precision against -inf-sentinel arithmetic in the CRF's logsumexp forward pass.
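The grammar constraints described above can be sketched as a small mask builder. This is an illustration only, not the project's code: the tag ordering (O first, then B-X/I-X pairs), the helper names, and the placeholder type names are all assumptions. A 33-tag table is consistent with 16 entity types (1 + 2×16) and a 21-tag table with 10, though the actual label sets are not listed here. The sketch also assumes that "I-X after B-Y for X≠Y" is among the rules covered by "etc.".

```python
import math

NEG_INF = -math.inf


def bio_tags(types):
    """Return ['O', 'B-t0', 'I-t0', 'B-t1', 'I-t1', ...] for the given types."""
    tags = ["O"]
    for t in types:
        tags += [f"B-{t}", f"I-{t}"]
    return tags


def is_allowed(prev, cur):
    """I-X may only follow B-X or I-X; every other transition is legal."""
    if not cur.startswith("I-"):
        return True
    entity = cur[2:]
    return prev in (f"B-{entity}", f"I-{entity}")


def transition_mask(tags):
    """mask[i][j] is 0.0 if tags[i] -> tags[j] is legal, otherwise -inf."""
    return [[0.0 if is_allowed(p, c) else NEG_INF for c in tags] for p in tags]


# Hypothetical type names; only the count (16) matters for the table size.
tags = bio_tags([f"T{k}" for k in range(16)])
mask = transition_mask(tags)
blocked = sum(row.count(NEG_INF) for row in mask)
print(len(tags), blocked)  # 33 tags, 496 blocked transitions
```

Under these assumptions each of the 16 I-X columns has 31 blocked source tags, so 496 of the 1089 entries carry the -inf sentinel that the logsumexp has to handle.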

DeepSeek's turn 7 sign-off explicitly required this experiment as a precondition for v0.6.2 — independent of v0.6.2's corpus changes (different variable).

Implementation

A crf_fp32: bool = False flag was added to ModelConfig (corpus-python/src/mailwoman_train/config.py) and threaded through MailwomanCoarseEncoder (model.py). When it is set, the CRF forward call runs under torch.autocast(device_type=..., enabled=False), with the emissions and the mask upcast to fp32:

if self.crf_fp32:
    device_type = logits.device.type
    # Disable autocast so the CRF forward pass runs in full precision.
    with torch.autocast(device_type=device_type, enabled=False):
        emissions_fp32 = logits.float()
        crf_mask = attention_mask.to(emissions_fp32.dtype)
        crf_loss = self.crf(
            emissions=emissions_fp32,
            tags=labels.clamp(min=0),
            mask=crf_mask,
            reduction=crf_reduction,
        )

Everything outside the CRF call continues in bf16 — encoder forward, embedding lookup, CE loss, optimizer step. Only the CRF's logsumexp over the transition table runs in fp32. The crf_loss is downcast back to ce_loss's dtype before the dual-loss sum so the optimizer sees one consistent tensor.
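For context on what runs inside the fp32 region, here is a sketch of a linear-chain CRF log-partition computed with the forward recursion, using plain Python floats instead of tensors. It is not the project's CRF module: the function names, the omission of start/end scores, and the tiny 3-tag example are assumptions made for illustration. The brute-force sum over all tag paths is included only to check the recursion.

```python
import itertools
import math


def logsumexp(xs):
    """Stable log(sum(exp(x))) that tolerates -inf entries."""
    m = max(xs)
    if m == -math.inf:  # every term is masked out
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def crf_log_partition(emissions, transitions):
    """Forward recursion, O(seq_len * num_tags^2)."""
    alpha = list(emissions[0])
    for emit in emissions[1:]:
        alpha = [
            logsumexp([alpha[i] + transitions[i][j] for i in range(len(alpha))])
            + emit[j]
            for j in range(len(emit))
        ]
    return logsumexp(alpha)


def brute_force_log_partition(emissions, transitions):
    """Enumerate every tag path; only feasible for tiny inputs."""
    n_tags = len(emissions[0])
    scores = []
    for path in itertools.product(range(n_tags), repeat=len(emissions)):
        s = emissions[0][path[0]]
        for t in range(1, len(path)):
            s += transitions[path[t - 1]][path[t]] + emissions[t][path[t]]
        scores.append(s)
    return logsumexp(scores)


# Tags: 0 = O, 1 = B-X, 2 = I-X. The O -> I-X transition is forbidden.
transitions = [
    [0.1, 0.2, -math.inf],
    [0.0, -0.3, 0.5],
    [0.2, 0.1, 0.4],
]
emissions = [[1.0, 0.5, -1.0], [0.2, 0.3, 0.9], [0.0, 1.2, 0.4], [0.7, -0.2, 0.1]]

fast = crf_log_partition(emissions, transitions)
slow = brute_force_log_partition(emissions, transitions)
print(abs(fast - slow) < 1e-9)
```

Every step of the recursion adds a transition score (possibly -inf) to a running log-sum and then takes another logsumexp, which is the arithmetic the flag moves into fp32.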

Default is False — every existing config remains bit-identical to its prior runs.

Diagnostic config

configs/v0_6_2-crf-fp32-diagnostic.yaml. Identical to v0.6.1-stage3-streets EXCEPT:

| Field | v0.6.1 | Diagnostic |
|---|---|---|
| crf_loss_weight | 0.0 | 0.5 |
| crf_fp32 | n/a | true |
| max_steps | 100000 | 3000 |
| eval_every_steps | 2000 | 500 |

3000 steps clears both observed NaN windows (step 950 + step 1700) with margin. ~30 min on A100.
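The override set and the step margin can be written out as a quick check. This is a sketch only: the real config is a YAML file loaded into ModelConfig, so the dict representation and the `apply_overrides` helper are hypothetical, and the eval schedule assumes evaluation runs at every multiple of eval_every_steps.

```python
# Only the fields named in the comparison above; the real config has more.
BASE = {"crf_loss_weight": 0.0, "max_steps": 100_000, "eval_every_steps": 2000}
DIAGNOSTIC_OVERRIDES = {
    "crf_loss_weight": 0.5,
    "crf_fp32": True,
    "max_steps": 3000,
    "eval_every_steps": 500,
}
OBSERVED_NAN_STEPS = (950, 1700)


def apply_overrides(base, overrides):
    """Merge overrides onto base; crf_fp32 defaults to False when absent."""
    merged = {"crf_fp32": False, **base}
    merged.update(overrides)
    return merged


diag = apply_overrides(BASE, DIAGNOSTIC_OVERRIDES)

# Assumed schedule: evaluate at every multiple of eval_every_steps.
eval_steps = list(
    range(diag["eval_every_steps"], diag["max_steps"] + 1, diag["eval_every_steps"])
)
margin = diag["max_steps"] - max(OBSERVED_NAN_STEPS)

print(eval_steps)  # [500, 1000, 1500, 2000, 2500, 3000]
print(margin)      # 1300 steps past the later NaN point
```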

Result: PASS

Training reached step 3000 with no NaN. Loss curve:

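| Step | train_loss | val_loss | macro_f1 |
|---|---|---|---|
| 25 | 43.81 | | |
| 100 | 32.21 | | |
| 500 | (warmup) | | |
| **950** | 1.5647 | | |
| 1000 | 1.5375 | 4.3418 | 0.2235 |
| 1500 | 0.9044 | 3.8535 | 0.2589 |
| **1700** | 0.8934 | | |
| 2000 | 0.6809 | (next eval) | |
| 2500 | 0.5927 | 3.6348 | 0.2979 |
| 2825 | 0.4913 | | |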

Both previously-NaN'd steps (bold) cleared cleanly. macro_f1 climbing monotonically: 0.2235 → 0.2589 → 0.2979. val_loss dropping: 4.34 → 3.85 → 3.63.

The bf16 + masked -inf transition hypothesis is empirically confirmed. With the CRF's logsumexp forward pass in fp32, the run cleared both earlier NaN points and trained to step 3000 without diverging; the rest of the model runs in bf16 for throughput as before.
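To make the precision gap concrete, the sketch below emulates bf16 rounding in plain Python by keeping the top 16 bits of the float32 encoding with round-to-nearest-even. It illustrates how coarse bf16 spacing is in general; it does not reproduce the NaN itself, and the helper name `to_bf16` is an assumption rather than anything in the codebase.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value, returned as a Python float."""
    if math.isnan(x):
        return x
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round-to-nearest-even on the 16 bits that bf16 drops.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


print(to_bf16(257.0))      # 256.0: adjacent bf16 values are 2 apart in [256, 512)
print(to_bf16(43.81))      # 43.75: adjacent bf16 values are 0.25 apart in [32, 64)
print(to_bf16(-math.inf))  # -inf: the sentinel itself survives the cast
```

The sentinel is represented exactly in both formats, so the difference lies in how finely the finite scores around it can be resolved.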

Implications

  1. v0.6.3 can re-enable CRF. Set crf_loss_weight: 0.5 (or higher) + crf_fp32: true in the v0.6.3 config. The learned transition table that was dropped after v0.6.0's failed attempts is now within reach.

  2. v0.6.2 stays CE-only. Per the NaN protocol's "one variable at a time" rule: v0.6.2 already changes corpus composition (synth-street weight + new synth-no-street shard). Adding CRF activation in the same release confounds attribution. Defer to v0.6.3.

  3. fp32-CRF throughput cost is minimal. The diagnostic ran at ~6 steps/s on an A100-SXM4-40GB, comparable to v0.6.0's CE-only throughput. The autocast boundary only affects the CRF forward (a small linear-in-seq-len operation), not the encoder.

  4. The NaN protocol's "diagnose ONE knob → retry" rule paid off here. The fp32 experiment was a 30-min run that closed a question that had been open since the postmortem. The cost of NOT running it was indefinite future uncertainty.

Reproducing

# Push the changed model.py + config.py + new diagnostic yaml to the Modal volume:
modal volume put mailwoman-training corpus-python/src/mailwoman_train/model.py \
    corpus-python/src/mailwoman_train/model.py --force
modal volume put mailwoman-training corpus-python/src/mailwoman_train/config.py \
    corpus-python/src/mailwoman_train/config.py --force
modal volume put mailwoman-training \
    corpus-python/src/mailwoman_train/configs/v0_6_2-crf-fp32-diagnostic.yaml \
    corpus-python/src/mailwoman_train/configs/v0_6_2-crf-fp32-diagnostic.yaml --force

# Clear pyc cache (Modal's documented gotcha — labels.py / model.py changes need this):
modal volume rm mailwoman-training corpus-python/src/mailwoman_train/__pycache__ -r

# Launch the experiment (~30 min on A100):
modal run -d scripts/modal/train_remote.py --config v0_6_2-crf-fp32-diagnostic.yaml --resume none

# Watch:
modal app logs <app-id> # look for loss=nan in the step 900-2000 range
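
Scanning the logs by eye can be replaced with a small filter. The sketch below assumes each log line carries `step=<int>` and `loss=<value>` tokens. Only the `loss=nan` token is mentioned above, so the rest of the line format and the function name are assumptions to adjust against the real output.

```python
import math
import re

# Assumed line shape: "... step=<int> ... loss=<value> ..."
LINE_RE = re.compile(r"step=(\d+).*?loss=([^\s,]+)")


def first_nan_step(lines):
    """Return the first step whose loss is non-finite, or None if all are finite."""
    for line in lines:
        match = LINE_RE.search(line)
        if not match:
            continue
        try:
            loss = float(match.group(2))
        except ValueError:
            continue  # token after "loss=" was not a number
        if not math.isfinite(loss):
            return int(match.group(1))
    return None


sample = [
    "step=950 loss=1.5647",
    "step=1700 loss=0.8934",
]
print(first_nan_step(sample))                           # None
print(first_nan_step(sample + ["step=1701 loss=nan"]))  # 1701
```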

See also