Commit Graph

573 Commits

Author SHA1 Message Date
eebf33b97d test: clean minimal nvvm.inline_ptx test 2026-05-28 04:45:21 +00:00
882d48588b test: debug nvvm.inline_ptx with CUTLASS_LOG_LEVEL=DEBUG 2026-05-28 04:44:35 +00:00
3ffb3b807a test: minimal nvvm.inline_ptx isolation test 2026-05-28 04:43:18 +00:00
1cbb3cf752 NVFP4-1.1: Replace threshold rounding with inline PTX cvt.rni/rz/rmi
- Add f32_to_i32_rni (cvt.rni.s32.f32) for round-to-nearest-even
- Add f32_to_i32_rz (cvt.rzi.s32.f32) for round-toward-zero
- Add f32_to_i32_rmi (cvt.rmi.s32.f32) for round-to-minus-infinity
- Replace round_rne_u0_8 and abs_scaled_to_e2m1_idx threshold hacks
  with proper PTX hardware rounding in fp8_e4m3_from_float32
- quantize_e2m1_nibble now uses f32_to_i32_rni + LUT logic for half_step
- Add test_ptx_convert.py for inline PTX conversion verification
- This is the CORRECT approach per NVFP4-1.1_INLINE_PTX_APPROACH.md option 1
2026-05-28 04:40:17 +00:00
2777ebfe8e NVFP4-1.1: ultra-minimal test — Float32 comparison + Int32 select 2026-05-28 04:35:06 +00:00
2087eaef49 NVFP4-1.1: minimal threshold rounding test 2026-05-28 04:33:38 +00:00
1828a71cde NVFP4-1.1: test kernel uses Float32 input (avoids BF16 scalar load issue) 2026-05-28 04:32:08 +00:00
accc66741d NVFP4-1.1: update test kernel with threshold rounding API 2026-05-28 04:27:29 +00:00
c3d5a7b82f NVFP4-1.1: try .to(Int32) for float-to-int conversion 2026-05-28 04:02:45 +00:00
dc35d29811 NVFP4-1.1: fix cute.arch.store signature - store(ptr, val) not store(ptr, val, dtype) 2026-05-28 04:01:38 +00:00
a05a76bb6b NVFP4-1.1: add Int32 cast diagnostic test 2026-05-28 03:59:01 +00:00
6f94925491 NVFP4-1.1: fix cute.math.fmax -> cute.arch.fmax (correct CuTeDSL API) 2026-05-28 03:48:51 +00:00
60790564f0 NVFP4-1.1: fix test - two-pass kernel, cute.arch.store confirmed on B200 2026-05-28 03:46:45 +00:00
a41de129cb NVFP4-1.1: fix test kernel - use cute.copy instead of cute.arch.store 2026-05-28 03:42:24 +00:00
3a78bdf570 NVFP4-1.1: add CuTeDSL kernel test for FP4 quantization 2026-05-28 03:40:54 +00:00
80b6b79f9e NVFP4-1.1: FP4 quantization primitives for CuTeDSL kernels
- fp8_e4m3_from_float32: manual FP8 E4M3 cast (bias=7, exp 0-15 valid,
  NaN guard for exp=15/mant=7, mantissa overflow handling)
- fp8_e4m3_to_float32: dequantize FP8 E4M3 bit pattern back to Float32
- half_step_to_e2m1_idx: E2M1 step mapping (0-12 → 0-7)
- quantize_e2m1_nibble: per-element E2M1 quantize + sign + pack
- Verified 0/500 trial failures against Python reference
- Key fixes discovered during validation:
  1. FP8 E4M3 bias is 7, NOT 8
  2. Exponent range is 0-15 (exp=15/mant=7 is NaN; others valid)
  3. Subnormal formula: val = m * 2^(-9) = m/512 (NOT m/1024)
  4. Round-to-nearest-even (not round-half-up) for half_step and mantissa
  5. Mantissa overflow (round to 8) must increment exponent
2026-05-28 03:39:55 +00:00
b9f15c250f Stage E: head-packed MQA/GQA, batch dim, custom_op, integration API
- production.py: head-packed M dimension for MQA/GQA (q_per_kv*T rows
  in single launch per KV group, eliminating redundant K/V TMA loads)
- production.py: batch dimension support (outer Python loop)
- production.py: warmup_attention_kernels() for pre-compilation
- production.py: dsv4_attention_per_head() for exact per-head sink bias
- __init__.py: sparse_fmha_with_swa, dense_fmha_with_swa, swa_only_fmha
  integration functions bridging AttentionSubBlock → production FMHA
- custom_ops.py: dsv4::sparse_fmha_with_swa custom_op registration
- test_production.py: comprehensive tests (MHA/MQA/GQA, head-packed vs
  per-head parity, multi-segment KV, SWA+causal+sink, batch, edge cases)
2026-05-27 15:15:03 +00:00
2412a5431b MQA/GQA: batch Q heads into kernel batch dim, shared K/V per KV group 2026-05-27 08:31:23 +00:00
06a895ff99 Clean test suite for production attention (1/2/4 segments, multi-head) 2026-05-27 07:12:02 +00:00
3a25c7feff Test multi-KV merge (2 segments) separately from multi-head 2026-05-27 06:54:16 +00:00
e45b94c01b Test: compare both normalized and un-normalized reference 2026-05-27 06:44:37 +00:00
98c93c1cd8 Stage E: production attention wrapper + Python KV merge, clean fmha_smem_acc 2026-05-27 06:34:10 +00:00
b02e103ac0 Add c_simple GMEM tensor (non-dynamic) for SMEM accumulator TMA store 2026-05-27 05:33:30 +00:00
bf36979a8d Use CUTLASS FMHA reference pattern for sC->GMEM TMA store (flat_divide + tma_partition) 2026-05-27 05:24:39 +00:00
97bc6d8d2f Add c_direct GMEM tensor for direct writes in SMEM accumulator path 2026-05-27 05:15:47 +00:00
a858ed1c14 Fix test: normalize=False for un-normalized O comparison 2026-05-27 05:06:52 +00:00
3a7d87adba Fix test_smem_acc: use keyword args for lse/row_sums 2026-05-27 04:54:23 +00:00
6a621bdf64 D1.5: SMEM accumulator FMHA kernel — one-way TMEM→REGS→SMEM, no round-trip
TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY BROKEN.
Even NO-OP (multiply by 1.0) corrupts data.

New approach:
- PV always ACCUMULATE=False (fresh TMEM each kt)
- After pv_done_bar: one-way Ld32x32bOp load O_kt from TMEM→REGS
- Coordinate-indexed SMEM accumulation: sO_acc = acc_scale * sO_acc + O_kt
- sO_acc: FP32 [128, pv_n_tile] row-major (32KB at hd=64, 64KB at hd=128)
- Final: normalize, cast BF16, write to sC, TMA store to GMEM
2026-05-27 04:53:40 +00:00
42c5793add D1.5: Add isolated round-trip test comparing s_k=128 vs s_k=256 with NOOP rescale 2026-05-26 20:45:58 +00:00
3be708d923 D1.5 debug: add NOOP rescale test (acc_scale=1.0) to isolate TMEM round-trip corruption 2026-05-26 20:28:55 +00:00
c3648e4ebf D1.5 debug: add targeted s_k=256 rescale diagnostic test 2026-05-26 20:27:37 +00:00
bf2c7c8bb8 D1.5: Implement in-kernel O rescale via CUTLASS correction_rescale pattern
- Both load and store atoms built from SAME tOtO_i (composition-tiled)
- Same Repetition(corr_tile_size=16) for both copies
- pv_done_bar synchronization between MMA and softmax warps
- acc_scale computed per kt iteration, used to rescale O in TMEM
- const_expr(n_kv_tiles > 1) guards for zero overhead at s_k=128
- New test: test_d15_in_kernel_rescale.py (s_k=128/256/384)
- Minimal roundtrip test: test_tmem_roundtrip_minimal.py
2026-05-26 20:26:06 +00:00
2b4f4ce538 Remove broken D1.5 paired-atom test (TMEM round-trick is fundamentally broken) 2026-05-26 19:50:31 +00:00
40cbf0c223 Add D1.5 paired-atom O rescale test (s_k=256/384, hd=64/128) 2026-05-26 19:46:19 +00:00
487d960a6a D5c multi-tile: VERIFIED cos 0.999996 with Python KV merge + sink bias
Both segments (compressed+SWA with n_comp=96, and SWA-only with n_comp=0)
pass individually at cos 0.999996. The Python KV merge produces the
correct combined attention at cos 0.999996.

Key: n_comp is compile-time, so separate kernel instances are needed
for segments with different n_comp values. Production code would use
a kernel cache keyed on (n_comp, apply_sink_bias, ...).
2026-05-26 15:40:45 +00:00
c9eab3c7e0 diag: rewrite multi-tile test with explicit per-segment compile and reference 2026-05-26 15:39:39 +00:00
7f983fb855 diag: add direct segment 0 test to compare with run_segment 2026-05-26 15:37:06 +00:00
2a5f9dc6e3 fix: simplify run_segment to use full hd V tensor (was incorrectly splitting by pv_n_tile) 2026-05-26 15:34:57 +00:00
aa2df1a202 diag: test n_comp=96 with sink bias directly 2026-05-26 15:33:38 +00:00
25b236fe00 diag: test D5c multi-tile with no sink bias to isolate issue 2026-05-26 15:31:38 +00:00
a3989929de diag: per-segment reference comparison for D5c multi-tile 2026-05-26 15:29:52 +00:00
bbc29945e8 diag: add per-segment debug output for D5c multi-tile 2026-05-26 15:28:19 +00:00
e64392f1ac D5c: add apply_sink_bias flag (independent of n_comp)
For all-SWA segments (n_comp=0), sink bias still needs to be applied
to all positions. The apply_sink_bias flag controls compilation of
the sink bias code path, independent of n_comp offset.
2026-05-26 15:26:52 +00:00
60b6f8ee79 fix: use separate kernels for segments with/without compressed KV (n_comp is compile-time) 2026-05-26 15:23:21 +00:00
2efd15c852 fix: correct swa_len_local calculation per segment for D5c multi-tile 2026-05-26 15:22:03 +00:00
3abcc7ff09 D5c: multi-tile test using Python KV merge with sink bias 2026-05-26 15:20:45 +00:00
ffc4b542bc D5c: use single KV tile (s_k=128) to avoid broken O rescale
The D5c sink bias logic is VERIFIED CORRECT (cos 0.999996).
Multi-KV-tile fails due to known D1.5 O rescale bug (TMEM round-trip).
Using s_k=128 avoids the broken code path. Multi-tile support requires
the D1.5 correction epilog fix.
2026-05-26 15:11:50 +00:00
fc0f4bcf23 diag: test D5c with single KV tile (s_k=128) to isolate O rescale issue 2026-05-26 15:09:49 +00:00
e5381b7312 diag: add baseline test (s_k=256 D3 mask, no sink bias) to isolate D5c issue 2026-05-26 15:08:40 +00:00
016edbcc97 D5c: add row_sum output for proper external normalization
The kernel's O_unnorm is max-shifted (divided by 2^row_max), so
O_norm != O_unnorm * exp(-LSE). Instead, O_norm = O_unnorm / row_sum.
Added mRowSums output tensor to enable correct normalization.
2026-05-26 15:07:22 +00:00