Commit Graph

10 Commits

Author SHA1 Message Date
d772885d7e single_shot_inference: proper mHC+RMSNorm+inverse RoPE pipeline
Major rewrite of single_shot_inference.py:
- Replace broken mHC (gentle normalization hack) with proper Sinkhorn-Knopp
- Add RMSNorm before each sub-block (attention + FFN)
- Add inverse RoPE on attention output (paper §2.3.3)
- Fix KV cache: RoPE applied before caching, K=V in DSV4 MQA
- Fix MoE: proper dense routing with e_bias, SwiGLU clamping
- Proper weight mapping: fn→W_stacked, base→S_pre/S_res/S_post, scale→alphas
- Add identity mHC fallback when weights missing
- No emergency normalization, no bandaids
2026-05-31 02:45:52 +00:00
4b9eed02e1 Cleanup C1-C7: delete dead CuTeDSL FMHA, test probes, scratch files
- Deleted fmha.py (CuTeDSL slow path), FmhaKernel, Python KV merge
- Deleted fmha_sm100.cuh, fmha_sm100_tc.cuh, fmha_sm100_launch.cu, fmha_epilogue_sm100.cuh
- Moved fmha_qk_verify.cuh to tests/unit/qk_verify_kernel.cuh
- Deleted decode_sparse.py, decode_swa.py, kernels/decode/
- Deleted 46 test_d*.py probes, test_smem_*, test_cotiled_*, test_tmem_*,
  test_smem_p_*, test_ultra_minimal, test_fmha_pv16, test_working_softmax_maybe
- Deleted root scratch: debug_linear.py, test_mapping.py, run_router_tests.py
- Moved archive/ to archived_plans/code_archive/
- Rewrote production.py: single fast path via 6-warp multi-tile kernel
- Added STATUS.md, audit_attention_live.md
- Moved NEXT_PRIORITIES*.md to archived_plans/
2026-05-30 21:08:12 +00:00
1e6adf5e01 P3: wire 6-warp multi-head FMHA decode fast path into production.py
- fmha_multihead_launch.cu: PyTorch launch wrapper for fmha_6warp_multihead_kernel
  (c10::BFloat16 boundary, uint16_t bf16_t inside kernel, zero-cost casts)
- fmha_multihead_op.py: torch.utils.cpp_extension JIT loader + custom_op registration
  (dsv4::fmha_multihead_decode for torch.compile)
- production.py: fast path dispatch for T=1, n_segments==1, hd in {64,128,256}
  Falls through to CuTeDSL slow path for multi-segment/prefill
- test_p3_fast_decode.py: integration test (MHA/MQA/GQA, cosine >= 0.999998)

Architecture:
  Grid: dim3(1, n_h, batch_size) — one CTA per (head, batch)
  MQA: k_head_stride=0 so all Q heads share same K/V
  Single kernel launch, zero cudaDeviceSynchronize on hot path
  Normalized output for single-segment decode
2026-05-30 08:12:23 +00:00
2dcfc0089f auto: pre-test commit 2026-05-28 15:49:47 +00:00
74dba6ab9d auto: pre-test commit 2026-05-28 04:40:20 +00:00
acf46c494c NVFP4-1.1: update approach doc and fp4_quant with CuTeDSL API fixes 2026-05-28 04:09:58 +00:00
80b6b79f9e NVFP4-1.1: FP4 quantization primitives for CuTeDSL kernels
- fp8_e4m3_from_float32: manual FP8 E4M3 cast (bias=7, exp 0-15 valid,
  NaN guard for exp=15/mant=7, mantissa overflow handling)
- fp8_e4m3_to_float32: dequantize FP8 E4M3 bit pattern back to Float32
- half_step_to_e2m1_idx: E2M1 step mapping (0-12 → 0-7)
- quantize_e2m1_nibble: per-element E2M1 quantize + sign + pack
- Verified 0/500 trial failures against Python reference
- Key fixes discovered during validation:
  1. FP8 E4M3 bias is 7, NOT 8
  2. Exponent range is 0-15 (exp=15/mant=7 is NaN; others valid)
  3. Subnormal formula: val = m * 2^(-9) = m/512 (NOT m/1024)
  4. Round-to-nearest-even (not round-half-up) for half_step and mantissa
  5. Mantissa overflow (round to 8) must increment exponent
2026-05-28 03:39:55 +00:00
064ececc9a Update docs: D1.5 TMEM round-trip fundamentally broken, Python KV merge is production path 2026-05-26 19:53:10 +00:00
43f0b5d1e8 D1.5: Fix O rescale with paired atoms (incremental approach)
Keep epilogue_tma_store for final output (proven path).
Only fix the multi-KV-tile O rescale using paired atoms from
epilogue_tmem_copy_and_partition. The paired atoms share addressing,
making the TMEM->REGS->modify->TMEM cycle lossless.

Guarded by const_expr(n_kv_tiles > 1) so single-tile path (n=128)
is completely unaffected — zero regression risk.

Full correction epilogue (one-way TMEM->REGS->SMEM->GMEM) deferred
until we can address the MLIR compilation time issue.
2026-05-26 19:34:26 +00:00
4bb0e063cc D1.5: Replace broken TMEM round-trip with correction epilogue (paired atoms)
Replace hand-constructed Ld32x32bOp/St32x32bOp TMEM round-trip with the
proven correction epilogue pattern from fused_swiglu.py:

1. O rescale (kt>0): TMEM→REGS (paired load), multiply by acc_scale,
   REGS→TMEM (paired store via retile_to_S). No layout mismatch.

2. Final O output: One-way TMEM→REGS→SMEM→GMEM using
   epilogue_tmem_copy_and_partition + epilogue_smem_copy_and_partition
   + TMA partition. Register-level normalization (divide by row_sum)
   or raw BF16 cast for D5a path.

This fixes both D1.5 issues:
- Issue 1: TMEM round-trip corruption (hand-constructed atoms)
- Issue 2: O rescale for multi-KV-tile (kt>0)

Supports normalize=True (in-kernel) and normalize=False (D5a external).
Uses epilog_sync_bar + c_pipe for SMEM→GMEM, replacing epilogue_tma_store.
2026-05-26 19:11:19 +00:00