Previously, decode stopped only on tokenizer.eos_token_id. DSV4 uses special
turn-end tokens (<|end_of_sentence|>, USER_TOKEN=128803) that indicate
the assistant turn is complete. Missing these caused decode to continue
past the model's natural stopping point, producing degenerate output.
Also increased diagnostic logging (every step for the first 20 steps) to
catch turn-end token emissions.
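The stop condition above can be sketched as a membership test against a set of stop ids. This is a minimal illustration, not the project's decode loop: `build_stop_ids`, `decode_until_stop`, and `next_token_fn` are hypothetical names, and the EOS id used in the demo is a placeholder. The only id taken from the text is USER_TOKEN=128803.

```python
from typing import Callable, Iterable, List, Set

USER_TOKEN = 128803  # turn-end id quoted in the text above


def build_stop_ids(eos_token_id: int, extra_stop_ids: Iterable[int]) -> Set[int]:
    """Combine the tokenizer EOS id with any additional turn-end ids."""
    return {eos_token_id, *extra_stop_ids}


def decode_until_stop(
    next_token_fn: Callable[[List[int]], int],
    stop_ids: Set[int],
    max_new_tokens: int,
    log_first_n: int = 20,
) -> List[int]:
    """Generate tokens until any stop id appears or the budget runs out.

    next_token_fn is a hypothetical stand-in for one model decode step.
    """
    generated: List[int] = []
    for step in range(max_new_tokens):
        token = next_token_fn(generated)
        if step < log_first_n:
            # Per-step diagnostics for the first few steps only.
            print(f"step={step} token={token} is_stop={token in stop_ids}")
        if token in stop_ids:
            break  # the stop token itself is not appended
        generated.append(token)
    return generated


# Placeholder EOS id for illustration; in practice read it from the tokenizer.
stop_ids = build_stop_ids(eos_token_id=1, extra_stop_ids=[USER_TOKEN])
script = iter([11, 12, USER_TOKEN, 13])
tokens = decode_until_stop(lambda _prefix: next(script), stop_ids, max_new_tokens=8)
```

With only the EOS id in the set, the scripted run would continue past USER_TOKEN and emit 13 as well, which is the failure mode described above.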
- fmha_multihead_launch.cu: PyTorch launch wrapper for fmha_6warp_multihead_kernel
  (c10::BFloat16 boundary, uint16_t bf16_t inside kernel, zero-cost casts)
- fmha_multihead_op.py: torch.utils.cpp_extension JIT loader + custom_op registration
  (dsv4::fmha_multihead_decode for torch.compile)
- production.py: fast path dispatch for T=1, n_segments==1, hd in {64,128,256}
  Falls through to CuTeDSL slow path for multi-segment/prefill
- test_p3_fast_decode.py: integration test (MHA/MQA/GQA, cosine >= 0.999998)
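The dispatch rule in the production.py entry can be read as a simple predicate. The sketch below is an assumption about its shape, not the actual code: the function names and the returned labels are hypothetical, and T is taken to mean the number of query tokens.

```python
FAST_PATH_HEAD_DIMS = frozenset({64, 128, 256})


def use_fast_decode_path(num_query_tokens: int, n_segments: int, head_dim: int) -> bool:
    """True only for single-token, single-segment decode with a supported head dim."""
    return (
        num_query_tokens == 1
        and n_segments == 1
        and head_dim in FAST_PATH_HEAD_DIMS
    )


def select_path(num_query_tokens: int, n_segments: int, head_dim: int) -> str:
    """Return a label for the path that would be taken (labels are illustrative)."""
    if use_fast_decode_path(num_query_tokens, n_segments, head_dim):
        return "fmha_multihead_decode"
    # Multi-segment and prefill shapes fall through to the slower general path.
    return "cutedsl_slow_path"
```

Keeping the predicate separate from the launch makes the fall-through cases easy to enumerate in a test.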
Architecture:
  Grid: dim3(1, n_h, batch_size), one CTA per (head, batch)
  MQA: k_head_stride=0, so all Q heads share the same K/V
  Single kernel launch, no cudaDeviceSynchronize on the hot path
  Normalized output for single-segment decode
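The grid and stride notes above can be illustrated with plain index arithmetic. This is a sketch under assumptions: a contiguous [batch, head, seq, head_dim] K/V layout, and hypothetical names (`KVStrides`, `grid_dims`, `kv_base_offset`). The kernel's real addressing may differ.

```python
from typing import NamedTuple, Tuple


class KVStrides(NamedTuple):
    batch_stride: int  # elements between consecutive batches
    head_stride: int   # elements between consecutive K/V heads; 0 for MQA


def grid_dims(n_h: int, batch_size: int) -> Tuple[int, int, int]:
    """Mirror dim3(1, n_h, batch_size): y indexes the head, z indexes the batch."""
    return (1, n_h, batch_size)


def kv_base_offset(head: int, batch: int, strides: KVStrides) -> int:
    """Element offset of the K/V block read by the CTA for (head, batch)."""
    return batch * strides.batch_stride + head * strides.head_stride


n_h, seq_len, head_dim = 8, 4, 64
per_head = seq_len * head_dim

# MHA: each query head has its own K/V head (assumed contiguous layout).
mha = KVStrides(batch_stride=n_h * per_head, head_stride=per_head)
# MQA: one K/V head per batch; head_stride=0 makes every query head alias it.
mqa = KVStrides(batch_stride=per_head, head_stride=0)

mha_offsets = [kv_base_offset(h, 1, mha) for h in range(n_h)]
mqa_offsets = [kv_base_offset(h, 1, mqa) for h in range(n_h)]
```

With a zero head stride the same kernel code serves MQA without a separate branch, since every head computes the same base offset.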
Keep epilogue_tma_store for the final output (proven path).
Only fix the multi-KV-tile O rescale, using paired atoms from
epilogue_tmem_copy_and_partition. The paired atoms share addressing,
making the TMEM->REGS->modify->TMEM cycle lossless.
Guarded by const_expr(n_kv_tiles > 1), so the single-tile path (n=128)
is completely unaffected (no regression risk).
The full correction epilogue (one-way TMEM->REGS->SMEM->GMEM) is deferred
until the MLIR compilation time issue is addressed.
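For reference, the multi-KV-tile O rescale follows the usual online-softmax recurrence. The sketch below is a scalar Python model of that arithmetic, not the kernel: `attend_tiled` is a hypothetical name, acc_scale is assumed to be exp(old_max - new_max), and the softmax scale factor is omitted for brevity.

```python
import math
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def attend_tiled(
    q: Vector,
    k_tiles: Sequence[Sequence[Vector]],
    v_tiles: Sequence[Sequence[Vector]],
) -> Tuple[List[float], float]:
    """Single-query attention over KV tiles; returns (unnormalized O, row_sum)."""
    o_acc = [0.0] * len(v_tiles[0][0])
    row_max = -math.inf
    row_sum = 0.0
    for kt, (k_tile, v_tile) in enumerate(zip(k_tiles, v_tiles)):
        scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in k_tile]
        new_max = max(row_max, max(scores))
        if kt > 0:
            # Earlier tiles were accumulated relative to the old running max,
            # so bring O and the running sum onto the new max before adding.
            acc_scale = math.exp(row_max - new_max)
            o_acc = [x * acc_scale for x in o_acc]
            row_sum *= acc_scale
        probs = [math.exp(s - new_max) for s in scores]
        row_sum += sum(probs)
        for p, v in zip(probs, v_tile):
            o_acc = [o + p * vi for o, vi in zip(o_acc, v)]
        row_max = new_max
    return o_acc, row_sum


q = [1.0, 0.5]
keys = [[0.1, 0.2], [0.3, -0.4], [2.0, 1.0], [-1.0, 0.0]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0]]

# Two tiles of two keys each versus one tile holding all four keys.
o_two, sum_two = attend_tiled(q, [keys[:2], keys[2:]], [values[:2], values[2:]])
o_one, sum_one = attend_tiled(q, [keys], [values])
out_two = [x / sum_two for x in o_two]
out_one = [x / sum_one for x in o_one]
```

The rescale branch never runs when there is a single tile, which matches the intent of guarding it on the tile count.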
Replace hand-constructed Ld32x32bOp/St32x32bOp TMEM round-trip with the
proven correction epilogue pattern from fused_swiglu.py:
1. O rescale (kt>0): TMEM→REGS (paired load), multiply by acc_scale,
   REGS→TMEM (paired store via retile_to_S). No layout mismatch.
2. Final O output: one-way TMEM→REGS→SMEM→GMEM using
   epilogue_tmem_copy_and_partition + epilogue_smem_copy_and_partition
   + TMA partition. Register-level normalization (divide by row_sum)
   or raw BF16 cast for the D5a path.
This fixes both D1.5 issues:
- Issue 1: TMEM round-trip corruption (hand-constructed atoms)
- Issue 2: O rescale for multi-KV-tile (kt>0)
Supports normalize=True (in-kernel) and normalize=False (D5a external).
Uses epilog_sync_bar + c_pipe for SMEM→GMEM, replacing epilogue_tma_store.
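The two normalize modes should agree once the external step is applied. The sketch below models that contract in plain Python. It assumes the raw path also hands back row_sum for the external divide, which the text does not state, and `finalize_output` and `normalize_external` are hypothetical names.

```python
from typing import List, Sequence, Tuple


def finalize_output(
    o_acc: Sequence[float], row_sum: float, normalize: bool
) -> Tuple[List[float], float]:
    """Model of the final epilogue step on one output row.

    normalize=True divides by row_sum in place of the kernel's register-level
    normalization; normalize=False returns the raw accumulator unchanged.
    """
    if normalize:
        return [x / row_sum for x in o_acc], row_sum
    return list(o_acc), row_sum


def normalize_external(o_raw: Sequence[float], row_sum: float) -> List[float]:
    """External normalization applied after a normalize=False run (assumed)."""
    return [x / row_sum for x in o_raw]


o_acc = [2.0, 4.0, 6.0]
row_sum = 4.0

in_kernel, _ = finalize_output(o_acc, row_sum, normalize=True)
raw, raw_sum = finalize_output(o_acc, row_sum, normalize=False)
external = normalize_external(raw, raw_sum)
```

A check of this form, comparing the two modes on the same inputs, is one way to confirm that the flag changes only where the divide happens.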