Commit Graph

90 Commits

Author SHA1 Message Date
4968ce064d even more stuff 2026-05-21 05:55:22 +00:00
15c987244f v28 attempt: PV MMA (128,64) - cosine 0.004, debugging 2026-05-21 05:41:44 +00:00
97656a5cd1 Stage B: two MMAs + identity softmax — crash fixed, softmax output still wrong
Key fixes:
- PipelineUmmaAsync consumer group: 32*4=128 threads (not 4 warps)
- TMEM offsets computed from find_tmem_tensor_col_offset (not hardcoded)
- P fragment from p_tmem_s.outer + make_fragment_A (matching fmha.py)
- V SMEM aliasing via recast_ptr

Status:
- Stage A: cosine 0.999999
- Stage B: runs without crash, identity softmax cosine -0.02
- Diagnostics: TMEM layout inspection, bisection results
2026-05-20 20:26:25 +00:00
a5b48be7d5 stuff 2026-05-20 07:15:01 +00:00
67d5e26080 Fix warmup compilation + add sparse topk metadata kernels
Bug #2 fix: warmup_compilation and warmup_fused_swiglu_compilation now
use valid FP4 data by quantizing random BF16 through quantize_to_nvfp4.
Random uint8 bytes as FP4 bit patterns cause cudaErrorIllegalInstruction
in Blackwell MMA hardware. Re-enabled warmup calls in runner.py.

Bug #1 kernel: sparse_topk_metadata.cu with:
  - build_c128a_topk_metadata: position-based compressed KV slot lookup
    via block table for C128A (compress_ratio=128) decode tokens
  - compute_c4a_global_topk: local topk index -> global slot ID mapping
    via block table for C4A (compress_ratio=4) decode tokens
  - Both tested: correct block table lookups, proper padding

Bug #3 kernel: C4A uses compute_c4a_global_topk (same .cu file)
  - Replaces vLLM Triton kernel with our own CUDA kernel

Deleted stale STATUS.md, FUSED_EPILOGUE_STATUS.md, FUSED_EPILOGUE_PLAN.md, CURRENT_BUG.md
2026-05-20 06:43:43 +00:00
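The block-table lookups in the commit above can be illustrated with a small sketch. It assumes the common paged-KV layout (a block table mapping logical blocks to physical blocks) and that the compressed index is token position // compress_ratio. Neither detail is spelled out in the commit, and `compressed_slot` and `block_size` are hypothetical names.

```python
def compressed_slot(token_pos, compress_ratio, block_table, block_size):
    """Map a token position to a global compressed-KV slot ID (assumed layout)."""
    cpos = token_pos // compress_ratio               # index of the compressed entry
    logical_block, offset = divmod(cpos, block_size)
    physical_block = block_table[logical_block]      # logical -> physical block
    return physical_block * block_size + offset


# Illustrative values only: C128A (compress_ratio=128), 4 slots per block.
block_table = [7, 2, 5]
slot = compressed_slot(token_pos=1300, compress_ratio=128,
                       block_table=block_table, block_size=4)
```

With these values the token falls in compressed entry 10, which is offset 2 of logical block 2, so the lookup goes through physical block 5.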
bbba289bd8 feat: GPU-native SWA + sparse decode attention kernels (CuTeDSL)
- native_swa_decode.py: BlackwellSWADecodeKernel
  - CTA mapping: 1 CTA per (decode_token, q_head_group)
  - Online softmax with KV tile streaming (16 tokens/tile)
  - Pre-dequantized bf16 KV (fp8 dequant on host - MLIR cvt_fpext
    requires 32-bit aligned vector, no scalar fp8->bf16 support)
  - Cosine 0.9999+ vs PyTorch batched SDPA reference
  - Fallback _fallback_batched_sdp when CuTeDSL unavailable

- native_sparse_decode.py: BlackwellSparseDecodeKernel
  - Combined SWA + compressed KV in single attention pass
  - Supports CSA (cr=4) and HCA (cr=128) layers
  - Sink weight merge on host side
  - Cosine 0.9999+ vs combined SDPA reference

- fp8_bf16.py: Documents MLIR limitation (cvt_fpext requires
  vector<4xf8>, no scalar support). Pre-dequant is the workaround.

- vLLM wiring (attention.py):
  - SWA-only layers: native_swa_decode_attention
  - CSA/HCA layers: native_sparse_decode_attention with topk + attn_sink
  - csa_attention.py updated to use native kernels

- Tests: test_decode_pipeline.py, test_sparse_decode.py both passing
2026-05-20 05:46:15 +00:00
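Online softmax with KV tile streaming, as named in the commit above, keeps a running max, a running sum, and a running weighted sum of V so that no tile has to be revisited. The following is a plain-Python sketch for a single query, not the CuTeDSL kernel; the function names are hypothetical and only the tile size of 16 comes from the commit.

```python
import math


def online_softmax_attention(q, keys, values, tile=16):
    """Single-query attention, streaming K/V in tiles with running max and sum."""
    scale = 1.0 / math.sqrt(len(q))
    m = float("-inf")              # running max of the scores seen so far
    l = 0.0                        # running sum of exp(score - m)
    acc = [0.0] * len(values[0])   # running weighted sum of V rows
    for start in range(0, len(keys), tile):
        k_tile = keys[start:start + tile]
        v_tile = values[start:start + tile]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_tile]
        m_new = max(m, max(scores))
        correction = math.exp(m - m_new)   # rescales old state; exp(-inf) is 0.0
        l *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_tile):
            p = math.exp(s - m_new)
            l += p
            acc = [a + p * x for a, x in zip(acc, v)]
        m = m_new
    return [a / l for a in acc]


def reference_attention(q, keys, values):
    """Two-pass softmax over all keys at once, used only as a check."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    mx = max(scores)
    w = [math.exp(s - mx) for s in scores]
    total = sum(w)
    return [sum(wi * v[d] for wi, v in zip(w, values)) / total
            for d in range(len(values[0]))]


# Deterministic toy data: 40 tokens, so the tiles are 16, 16 and 8 tokens long.
keys = [[math.sin(i + j) for j in range(4)] for i in range(40)]
values = [[math.cos(i * j + 1) for j in range(3)] for i in range(40)]
q = [0.3, -1.2, 0.7, 2.0]
out = online_softmax_attention(q, keys, values, tile=16)
ref = reference_attention(q, keys, values)
```

The streamed result should agree with the two-pass reference up to floating-point rounding.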
04eca7c6da Custom CUDA kernel for de-interleave plus NVFP4 quantize 2026-05-20 04:39:47 +00:00
061d5692a9 Remove debug print statements from pipeline 2026-05-20 04:20:46 +00:00
aa8563c626 Fused SwiGLU epilogue with granularity-8 weight interleave
- Fix interleave_l1_weights: remove //2 bug (g=granularity_bf16 for N-axis)
- Apply L1 weight+SF interleave in runner._ensure_stacked() and moe_pipeline
- De-interleave L1 GEMM output before gate/up split
- Fused SwiGLU kernel: epi_tile=(128,8) for subtile-level pairing
  - Even subtiles = gate: SiLU in FP32 registers, save to register buffer
  - Odd subtiles = up: silu(gate)*up from buffer
  - Both branches produce same BF16 tensor type (CuTeDSL constraint)
- run_nvfp4_moe_fused() pipeline: fused L1 + PyTorch L2
- Runner: fused_swiglu=True option for CuTeDSLMoERunner
- Layertest: both fused and non-fused paths PASS (cosine 0.988)
- README.md updated with current status and lessons learned
2026-05-20 04:13:52 +00:00
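The granularity-8 interleave and the de-interleave before the gate/up split can be sketched on a flat list of column indices. This assumes the layout is alternating chunks of 8 gate columns and 8 up columns, which matches the "even subtiles = gate, odd subtiles = up" description but is not confirmed by the commit. The function names here are illustrative, not the ones in bridge.py.

```python
def interleave_gate_up(gate, up, g=8):
    """Interleave in chunks of g: [gate[0:g], up[0:g], gate[g:2g], up[g:2g], ...]."""
    assert len(gate) == len(up) and len(gate) % g == 0
    out = []
    for i in range(0, len(gate), g):
        out += gate[i:i + g] + up[i:i + g]
    return out


def deinterleave_gate_up(mixed, g=8):
    """Inverse of interleave_gate_up: recover the separate gate and up halves."""
    assert len(mixed) % (2 * g) == 0
    gate, up = [], []
    for i in range(0, len(mixed), 2 * g):
        gate += mixed[i:i + g]
        up += mixed[i + g:i + 2 * g]
    return gate, up


gate_cols = list(range(16))          # stand-ins for gate column indices
up_cols = list(range(100, 116))      # stand-ins for up column indices
mixed = interleave_gate_up(gate_cols, up_cols)
```

Because the interleave is a pure permutation, de-interleaving the GEMM output recovers the same columns as a plain concatenated layout would produce.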
6c04155167 wip: Step 2 gate/up pairing — SiLU validated, runtime conditionals blocked by CuTeDSL
SiLU in registers: PASS (0.034% error, Step 1 stable)
Gate/up subtile detection: blocked by CuTeDSL type system

CuTeDSL compiles the kernel for ALL subtile iterations at once.
Runtime conditionals (if is_gate_subtile) that affect:
- Register tensor assignment → DSLRuntimeError (type structure mismatch)
- TMA store skipping → corrupted output
- Mask blending → wrong results

Path forward: use const_expr debug flag for the BF16 side output,
or process gate/up in a separate post-GEMM kernel.
2026-05-20 03:26:20 +00:00
9f0c1b8c5d wip: Step 1 SiLU validation complete, Step 2 gate/up pairing planning
Step 1 VALIDATED:
- cute.exp works on register tensors in the epilogue
- SiLU (x / (1+exp(-x))) produces correct results
- Relative error vs PyTorch: 0.034%, max abs: 0.0625 (BF16 precision)

Step 2 (gate/up pairing) approach:
- Register-level pairing requires understanding acc_vec layout from tiled_copy_r2s
- DeepGEMM pattern: (values[0], values[2]) pairs for tcgen05.ld
- CuTeDSL retile may produce different layout than direct PTX loads
- SMEM-level SiLU is a valid intermediate: avoids GMEM round-trip while
  working in logical (M, N) coordinate space
- Non-interleaved weights + SMEM SiLU is simplest starting point
2026-05-20 03:16:34 +00:00
b84f2f7bf9 fix: cutlass.Float32 not cutlass.float32_t in fused epilogue
Step 1 SiLU validation: PASS
- cute.exp works on register tensors
- SiLU (x / (1+exp(-x))) in registers matches PyTorch reference
- Relative error: 0.034%, Max abs error: 0.0625 (BF16 precision limit)
2026-05-20 03:12:23 +00:00
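The formula above, x / (1+exp(-x)), is the same function as x * sigmoid(x). A scalar sketch in plain Python shows the equivalence and the SwiGLU product it feeds into. It is not the register-level CuTeDSL code, and it does not guard against overflow of exp(-x) for very negative inputs.

```python
import math


def silu_div(x):
    """SiLU written as x / (1 + exp(-x))."""
    return x / (1.0 + math.exp(-x))


def silu_sigmoid(x):
    """SiLU written as x * sigmoid(x)."""
    return x * (1.0 / (1.0 + math.exp(-x)))


def swiglu(gate, up):
    """SwiGLU for one element pair: silu(gate) * up."""
    return silu_div(gate) * up


samples = [-4.0, -1.0, -0.25, 0.0, 0.5, 2.0, 6.0]
max_diff = max(abs(silu_div(x) - silu_sigmoid(x)) for x in samples)
```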
08992b818d wip: add run_fused_swiglu_grouped_gemm bridge + step1 test 2026-05-20 03:10:56 +00:00
9c43c69a4c wip: fused SwiGLU Stage 1 - SiLU in registers (full acc_vec)
Stage 1 of the fused epilogue: applies SiLU (x * sigmoid(x)) to the
full accumulator register tensor before writing BF16 to C.

This validates that cute.exp and element-wise FP32 operations work
on CuTe register tensors in the epilogue. The gate/up pairing is
not yet implemented (Stage 2).

The fused_swiglu flag is const_expr(0) by default, so the standard
epilogue path is unchanged unless the flag is enabled.
2026-05-20 03:07:02 +00:00
2f053f674e wip: fused SwiGLU kernel scaffold + bridge interleave + plan
- fused_swiglu_grouped_mm.py: copy-paste of torch_scaled_grouped_mm.py with
  class rename and fused_swiglu/swiglu_limit params added
- bridge.py: added interleave_l1_weights, deinterleave_l1_weights,
  warmup_fused_swiglu_compilation
- Pure-PyTorch interleave invariant passes (A@cat vs deinterleave(A@interleave))
- Standalone GEMM interleave test fails due to kernel-internal N-tiling
  layout (expected, skipping per plan)
- FUSED_EPILOGUE_PLAN.md updated with register layout, amax shuffle plan,
  4-step implementation strategy
2026-05-20 03:04:38 +00:00
4f178d6e9c chore: remove unused _expert_id_range after bincount migration 2026-05-20 02:17:44 +00:00
84a2f6d441 perf: replace expert counting O(n*E) comparison with torch.bincount O(n)
Bug #5 fix: (sorted_ids.unsqueeze(1) == expert_id_range.unsqueeze(0)).sum(dim=0)
materializes a (num_slots × num_experts) bool tensor every forward — 48K × 384 = 18M
elements. torch.bincount(sorted_ids, minlength=num_experts) gives the same result
in O(n) with no intermediate allocation. ~200× less work.

Also removes the now-unused _expert_id_range buffer.
2026-05-20 02:17:23 +00:00
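The difference between the two counting strategies in the commit above can be shown without tensors. This is a plain-Python stand-in: the second function does what torch.bincount(sorted_ids, minlength=num_experts) computes, and the function names are illustrative.

```python
def count_by_comparison(sorted_ids, num_experts):
    """O(n*E): compare every slot against every expert ID, then sum per expert."""
    return [sum(1 for s in sorted_ids if s == e) for e in range(num_experts)]


def count_by_bincount(sorted_ids, num_experts):
    """O(n): one pass over the slots, one counter per expert."""
    counts = [0] * num_experts
    for s in sorted_ids:
        counts[s] += 1
    return counts


sorted_ids = [0, 0, 2, 3, 3, 3]
slow = count_by_comparison(sorted_ids, num_experts=5)
fast = count_by_bincount(sorted_ids, num_experts=5)
```

Both return one count per expert, including zeros for experts that received no tokens, which is why minlength matters in the tensor version.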
4882d8553c fix: zero out x_norm for underflow blocks before division in NVFP4 quantization
Bug #4 fix: When a block has amax > 0 but amax/6 underflows to 0 in
FP8 (amax < 6*2^-9 ≈ 0.0117), the block scale is 0, but the division
x / clamp(0, 1e-8) inflates x into nonzero FP4 buckets (up to ±6.0).
This produces semantically wrong FP4 even though dequant gives 0 (6*0=0).

Root cause: we only detected truly-zero blocks (amax == 0) but not
underflow blocks (0 < amax < FP8_threshold). The fix:

1. Detect both zero and underflow blocks: block_amax < 6 * 2^-9
2. Zero out x_reshaped for these blocks BEFORE division
3. Force FP8 scale to 0 for these blocks

This ensures x_scaled = 0 → FP4 nibbles = 0 → dequant = 0.
Verified: bug scenario now produces nibble=0, scale=0.
Checkpoint byte match remains 100%.
2026-05-20 02:16:49 +00:00
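The three-step fix above can be sketched for a single block. This simplified version works on Python floats, omits the global scale and the actual FP8/FP4 rounding, and uses a hypothetical function name. Only the 6 * 2^-9 threshold and the "zero before dividing" order come from the commit.

```python
FP8_UNDERFLOW_AMAX = 6 * 2 ** -9   # 0.01171875, the threshold quoted in the commit


def scale_block(block):
    """Return (scale, scaled_values) for one block, before any FP4/FP8 rounding."""
    amax = max(abs(v) for v in block)
    if amax < FP8_UNDERFLOW_AMAX:
        # Covers both truly-zero blocks and underflow blocks: force exact zeros
        # instead of dividing by a clamped tiny scale.
        return 0.0, [0.0] * len(block)
    scale = amax / 6.0               # 6.0 is the largest FP4 (E2M1) magnitude
    return scale, [v / scale for v in block]


zero_scale, zero_vals = scale_block([0.0] * 16)
tiny_scale, tiny_vals = scale_block([0.005, -0.01] + [0.0] * 14)
ok_scale, ok_vals = scale_block([3.0, -1.5])
```

The second call is the bug scenario: amax is nonzero but below the threshold, so without the early return the values would be divided by a clamped scale and land in nonzero FP4 buckets.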
e653712598 fix: detect zero blocks in NVFP4 quantization, force FP4+FP8 to exact zero
Bug #3 fix: The clamp(min=1e-8) on block_amax prevented NaN from 0/0
but allowed truly-zero blocks to get a nonzero FP8 scale (5e-12 from
underflow). While the kernel produces 0 * 0 = 0 (no NaN), the nonzero
scale is semantically wrong and could interact badly with future kernels.

Fix: detect zero blocks explicitly (block_amax == 0), clamp only for
safe division, then force FP8 scale to exact zero for zero blocks via
torch.where. The FP4 nibbles are already zero (0 / anything = 0).

Verified: checkpoint byte match remains 100%, zero blocks produce
exact-zero dequantization, no NaN propagation.

Applies to all three quantization functions:
- quantize_to_nvfp4 (activation with computed gs)
- quantize_activation_nvfp4 (activation with pre-computed gs)
- quantize_weight_to_nvfp4 (weight quantization)
2026-05-20 02:14:50 +00:00
1857bdedc3 chore: deprecate prepare_weights_from_dequantized and prepare_weights_direct
Verified that our NVFP4 packing convention (odd<<4|even, round-half-to-even)
matches the DeepSeek-V4 checkpoint exactly: 100% byte-identical round-trip
across all tested experts. The dequantize->requantize path is lossless in
practice but wasteful. Marked both prepare_weights_from_dequantized and
prepare_weights_direct as deprecated in favor of prepare_weights_from_stacked
which loads checkpoint FP4 bytes directly via .view().

Also added test_fp4_roundtrip.py for future reference.
2026-05-20 02:11:40 +00:00
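The packing convention named above (odd<<4|even) puts the odd-indexed 4-bit code in the high nibble and the even-indexed code in the low nibble of each byte. A minimal sketch on integer codes follows; the round-half-to-even step that produces the codes is not shown, and the function names are illustrative.

```python
def pack_fp4(nibbles):
    """Pack 4-bit codes two per byte: odd index -> high nibble, even index -> low."""
    assert len(nibbles) % 2 == 0 and all(0 <= n < 16 for n in nibbles)
    return bytes((nibbles[i + 1] << 4) | nibbles[i]
                 for i in range(0, len(nibbles), 2))


def unpack_fp4(packed):
    """Inverse of pack_fp4: low nibble first, then high nibble."""
    out = []
    for b in packed:
        out += [b & 0x0F, b >> 4]
    return out


codes = [0x1, 0xA, 0xF, 0x0]
packed = pack_fp4(codes)
```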
ef398006a7 fix: correct scale factor dimensions in warmup (K_sf = ceil_div(K_packed,8) not ceil_div(K_packed,16))
K_packed = original_K // 2. The scale factor dimension is
K_sf = ceil_div(original_K, 16) = ceil_div(K_packed * 2, 16) = ceil_div(K_packed, 8).
The previous code used ceil_div(K_packed, 16) which was wrong.
2026-05-20 02:08:26 +00:00
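The dimension arithmetic above can be checked directly. The value of original_K below is only an illustrative choice, and the identity assumes original_K is even so that K_packed * 2 == original_K.

```python
def ceil_div(a, b):
    """Integer ceiling division for positive b."""
    return -(-a // b)


original_k = 7168                 # illustrative value, not taken from the commit
k_packed = original_k // 2        # two FP4 values per byte

k_sf_expected = ceil_div(original_k, 16)   # one scale factor per 16 elements
k_sf_fixed = ceil_div(k_packed, 8)         # corrected formula
k_sf_old = ceil_div(k_packed, 16)          # previous (wrong) formula
```

With these numbers the corrected formula matches the expected count, while the old one yields half as many scale factors.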
8f1a20562f fix: root-cause JIT memory corruption myth, add eager warmup, remove _needs_token_refill
Bug #1 fix: The _needs_token_refill workaround was a band-aid over a
misdiagnosis. cute.compile does NOT corrupt GPU memory (verified on B200).
The original corruption was from a different bug (likely OOB write or
weight loading issue).

Changes:
- bridge.py: Add warmup_compilation() for eager JIT before runtime buffers
  exist. Pre-allocate workspace per cache entry (no torch.full in hot path).
  Cache stores {compiled, workspace, workspace_size} instead of just compiled.
  CuTe tensor wrappers re-created per call (cheap metadata, avoids stale refs).
- runner.py: Remove _needs_token_refill hack. Add eager warmup call in
  _ensure_stacked() for both L1 and L2 GEMM shapes.
- nvfp4_linear.py: Add eager warmup in finalize_weights() for single GEMM.

The warmup approach ensures cute.compile runs exactly once per shape during
model init, before any forward pass. This is deterministic and eliminates
any possible interaction between JIT and runtime GPU memory.
2026-05-20 02:08:01 +00:00
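The cache described above (compile once per shape during init, with a pre-allocated workspace stored next to the compiled kernel) can be sketched as follows. This is a hedged illustration: `warmup`, `fake_compile`, and the bytearray workspace are stand-ins, and the real code would call cute.compile and allocate a GPU buffer instead.

```python
_compile_cache = {}


def warmup(shape, compile_fn, workspace_size):
    """Compile once per shape and pre-allocate its workspace; later calls hit the cache."""
    entry = _compile_cache.get(shape)
    if entry is None:
        entry = {
            "compiled": compile_fn(shape),
            "workspace": bytearray(workspace_size),   # stand-in for a GPU buffer
            "workspace_size": workspace_size,
        }
        _compile_cache[shape] = entry
    return entry


compile_calls = []


def fake_compile(shape):
    """Hypothetical compiler stub that records how often it is invoked."""
    compile_calls.append(shape)
    return "kernel-%dx%d" % shape


first = warmup((128, 256), fake_compile, workspace_size=64)
second = warmup((128, 256), fake_compile, workspace_size=64)
```

The second call returns the cached entry, so nothing is compiled or allocated on the hot path.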
6ec0afc318 fix: handle 3D swa_indices and correct kv_bf16 expand dims 2026-05-20 01:36:27 +00:00
aa593361e7 feat: add native CuTeDSL SWA decode attention kernel stub + batched SDPA fallback 2026-05-20 01:28:05 +00:00
3599b44c0f fix: replace _allocate_buffers with _ensure_buffer_size for dynamic sizing 2026-05-20 00:02:10 +00:00
1d5e70adfb fix: dynamic buffer sizing in nvfp4_linear for varying token counts 2026-05-19 23:59:55 +00:00
0023fee706 Add blackwell_attention module and comprehensive test 2026-05-19 15:30:29 +00:00
9d067add90 Fix device reference in full_attention_reference 2026-05-19 08:01:31 +00:00
3e3e998578 Fix attention: manual causal mask for batched single-query 2026-05-19 08:01:08 +00:00
1e675ccc9a Fix causal mask shape for SDPA: (1,1,T,T) broadcast 2026-05-19 08:00:39 +00:00
57615029a4 Fix KV expand for SDPA: (T,HD) → (T*NH, T, HD) 2026-05-19 08:00:08 +00:00
dd3a12bbda Fix full_attention_reference: broadcast KV to all heads+positions 2026-05-19 07:59:28 +00:00
910015c47e Fix kv shape: expand to (T, NH, HD) before reshape 2026-05-19 07:58:42 +00:00
3de75c4e37 Add CSA/HCA attention kernel (PyTorch SDPA, Blackwell-safe)
Replaces vLLM's broken FlashMLA sparse attention which doesn't work on
SM100 (Blackwell). Uses torch.nn.functional.scaled_dot_product_attention
which works on all GPUs.

Architecture:
- CSA (C4A): Batched sparse gather + SDPA on top-k positions
- HCA (C128A): Same with compressed KV + per-layer indexer
- SWA: Sliding window attention
- Full reference: standard SDPA for testing without compression

Also adds test_csa_attention_b200.py to verify the full attention path.
2026-05-19 07:58:10 +00:00
0a7769972f Fix garbled shared_expert_pipeline.py: imports/class were merged 2026-05-19 07:18:10 +00:00
05cdde1676 Fix wo_a: scatter each group's data at correct offset in padded buffer
The grouped GEMM expects each group's tokens at their own offset range:
- Group 0: rows [0, padded_T)
- Group 1: rows [padded_T, 2*padded_T)
- etc.

Previously we wrote all groups' data contiguously starting at row 0,
so group 1+ would read zeros from the padding area. Now we scatter
each group's quantized activation at the correct offset.

Also:
- Size buffer for total_max_rows = padded_max * n_groups
- Use assemble_scales_2d_side for multi-group scale assembly
- Extract output per-group at correct offsets
2026-05-19 02:45:57 +00:00
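The offset scheme above (group g occupies rows [g*padded_T, (g+1)*padded_T)) can be sketched with a flat list standing in for the padded activation buffer. The function name is hypothetical and the real code scatters quantized tensors rather than list items.

```python
def scatter_groups(group_rows, padded_t, fill=0):
    """Place group g's rows at [g*padded_t, g*padded_t + len(rows)) in one buffer."""
    buf = [fill] * (padded_t * len(group_rows))   # total_max_rows = padded_t * n_groups
    for g, rows in enumerate(group_rows):
        assert len(rows) <= padded_t
        start = g * padded_t
        buf[start:start + len(rows)] = rows
    return buf


buf = scatter_groups([["a0", "a1"], ["b0"]], padded_t=4)
```

Writing the groups contiguously from row 0 would instead put "b0" at row 2, inside group 0's range, which is the bug the commit describes.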
5f5b997fc3 Fix wo_a: permute to groups-first layout for grouped GEMM
The grouped GEMM expects mat_a to be laid out contiguously per group:
[all tokens for group0, all tokens for group1, ...]
A simple reshape of (T, G, D) → (T*G, D) gives interleaved layout
which is wrong. Fix: permute to (G, T, D) before flattening.
Same fix for output: permute (G, T, R) → (T, G, R).
2026-05-19 02:41:32 +00:00
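The layout difference above is easy to see on a tiny (T, G, D) example built from nested lists. This is a plain-Python stand-in for the tensor reshape and permute, with each "row" replaced by a label.

```python
def flatten_interleaved(x):
    """(T, G, D) -> (T*G, D) by plain reshape: rows ordered t0g0, t0g1, ..., t1g0, ..."""
    return [row for t in x for row in t]


def flatten_groups_first(x):
    """Permute to (G, T, D), then flatten: rows ordered g0t0, g0t1, ..., g1t0, ..."""
    n_groups = len(x[0])
    return [t[g] for g in range(n_groups) for t in x]


# T=2 tokens, G=3 groups; each row is a label instead of a D-vector.
x = [["t0g0", "t0g1", "t0g2"],
     ["t1g0", "t1g1", "t1g2"]]
interleaved = flatten_interleaved(x)
groups_first = flatten_groups_first(x)
```

Only the groups-first order keeps all of a group's tokens contiguous, which is what the grouped GEMM expects for mat_a.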
882d4996ff Replace DeepGEMM fp8_einsum with CuTeDSL NVFP4 for wo_a (o_proj)
The B200 container crashes in DeepGEMM's fp8_einsum (t.dim() == N assertion
in layout.hpp:39) when processing wo_a (o-projection first half) in the
attention layer. The crash is caused by scale tensor dimension mismatch
for the SM100 recipe (1, 1, 128).

Instead of fighting DeepGEMM, replace the entire wo_a path with our own
CuTeDSL NVFP4 kernel:

1. inverse_rope_bf16() — Python implementation of inverse RoPE
   (replaces fused_inv_rope_fp8_quant CUDA kernel)
2. CuTeDSLNvfp4WoA — NVFP4 grouped linear for wo_a using
   ScaledGroupedGemm with n_local_groups=8 groups
3. wo_a weight quantized to NVFP4 instead of FP8 (native NVFP4,
   no conversion to another quantization)

Changes:
- cutedsl/inverse_rope.py: BF16 inverse RoPE (conjugate rotation)
- cutedsl/wo_a_grouped_linear.py: CuTeDSL NVFP4 grouped GEMM for wo_a
- vllm/patches/deepseek_v4_attention.py: Use NVFP4 path when runner
  is initialized, keep DeepGEMM fallback
- vllm/patches/deepseek_v4.py: Init NVFP4 runner instead of FP8 quant
- tests/test_wo_a.py: Unit test for inverse RoPE + wo_a GEMM
2026-05-19 02:36:30 +00:00
bab1f75f29 Fix gs None error in legacy _ensure_stacked path 2026-05-19 02:17:53 +00:00
48fa64dfda Eliminate weight copies: pass stacked checkpoint tensors directly
Memory optimization for MoE weight processing:

Before (3-4 copies of weights in memory):
1. Original checkpoint weights in layer.w13_weight (copy 1)
2. Per-expert permuted copies (copy 2)
3. torch.stack() in runner._ensure_stacked (copy 3)
4. make_b_k_major re-stride (copy 4)
5. Scales: permute then assemble_scales_3d_side un-permutes (wasted)

After (1-2 copies):
1. View checkpoint as fp4 (NO copy — byte-preserving view)
2. Pass (E, N, K) stacked tensor directly to runner
3. Runner permutes to (E, K, N) contiguous (copy 1), frees stacked ref
4. make_b_k_major re-strides (copy 2), frees (E, K, N) ref
5. Scales: already (N, K_sf) from checkpoint, call assembly directly
6. Free layer.w13_weight etc. immediately after extracting views

Also: assemble_scales_3d_side transposes (K_sf, N)→(N, K_sf) internally,
but checkpoint scales are ALREADY (N, K_sf). Skip the double-transpose
by calling assemble_raw_scales_2d3d_3d_side directly.
2026-05-19 02:16:43 +00:00
35fab6cff3 Replace autograd.Function with torch.library.custom_op for Dynamo compat
Dynamo (torch.compile fullgraph) cannot trace through CuTeDSL internals
(cute.compile, JIT, etc.). The autograd.Function approach was unreliable
with fullgraph mode — Dynamo would still try to trace through it.

Fix: torch.library.custom_op makes Dynamo treat our GEMM as an opaque
black box. No reimplementing the kernel — just route through the existing
runner via a registry pattern:
  - Runners registered in global dict with integer IDs
  - Custom op takes (tensors, runner_id, shape_hint) -> tensor
  - Dynamo calls fake impl for shape inference, never touches the runner
  - At execution time, real impl looks up runner and calls _run_impl

Changes:
  - New: cutedsl/custom_ops.py (custom op definitions + registry)
  - New: tests/test_custom_op.py (local unit tests, no GPU needed)
  - Removed: _Nvfp4LinearApply, _MoEApply (autograd.Function classes)
  - Updated: nvfp4_linear.py, runner.py, cutedsl.py, nvfp4_cutedsl.py
    to use custom ops instead of autograd.Function
  - Updated: cutedsl_quant_method.py to use custom op + registry
2026-05-19 01:54:48 +00:00
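The registry pattern above can be sketched without PyTorch: runners live in a global dict keyed by integer ID, the op receives only the ID, and a separate shape-only function plays the role of the fake impl. All names here are hypothetical and the torch.library.custom_op registration itself is omitted.

```python
_RUNNERS = {}


def register_runner(runner):
    """Store a runner and return the integer ID that the op will carry."""
    runner_id = len(_RUNNERS)
    _RUNNERS[runner_id] = runner
    return runner_id


def run_by_id(x, runner_id):
    """Real implementation: look the runner up at execution time and call it."""
    return _RUNNERS[runner_id](x)


def fake_shape(in_shape, out_features):
    """Shape-only counterpart: infers the output shape without touching a runner."""
    return in_shape[:-1] + (out_features,)


rid = register_runner(lambda x: [2 * v for v in x])   # toy runner
result = run_by_id([1, 2, 3], rid)
out_shape = fake_shape((4, 8), out_features=16)
```

Passing an integer instead of the runner object keeps the op's arguments to plain tensors and scalars, which is what lets a tracer treat it as opaque.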
b007937a68 Fix garbled imports in cutedsl/runner.py 2026-05-18 22:22:52 +00:00
a7ed8faec6 Proper NVFP4 integration: use ModelOptNvFp4Config + FusedMoE framework
Major refactor to eliminate all post-load hacks:
- deepseek_v4.py: use upstream model with NVFP4 weight mapper only
  (gate_proj→w1, up_proj→w3, down_proj→w2, .self_attn→.attn, .mlp→.ffn)
- Add CuTeDSLMoEExperts as a FusedMoEExpertsModular subclass
  that wraps our CuTeDSL runner as a proper vLLM MoE backend
- Register CUTEDSL backend in the NVFP4 oracle
- Use ModelOptNvFp4Config for quantization dispatch (not DeepseekV4FP8Config)
- ModelOptNvFp4LinearMethod handles NVFP4 attention/shared expert projections
- Remove nvfp4_cutedsl.py, cutedsl_quant_method.py, utils.py from Dockerfile
- CuTeDSL runner moved to cutedsl/runner.py for clean imports
- cos_sin_cache float32 fix in deepseek_v4_attention.py

No more monkey-patching, no _convert_nvfp4_post_load, no CuTeDSLNvfp4Method.
2026-05-18 22:19:23 +00:00
48386e34ad Fix torch.compile: use custom autograd Function instead of @torch.compiler.disable
torch.compile fullgraph mode can't handle @torch.compiler.disable (skips
the function and refuses to compile). Custom autograd Functions are treated
as opaque ops by torch.compile — they execute eagerly without the compiler
trying to trace into CuTeDSL internals (JIT, Path.cwd, etc).
2026-05-18 21:38:28 +00:00
85e1cd3b69 Fix torch.compile crash: @torch.compiler.disable on all CuTeDSL run()
CuTeDSL internals (Path.cwd, threading, JIT) are incompatible with
torch.dynamo tracing. Marking run() as compiler-disabled makes the
runners opaque to torch.compile — they execute eagerly while the
rest of the model gets compiled.
2026-05-18 21:07:35 +00:00
a94011ec92 Fix torch.compile crash: remove threading.Lock from LUT cache path
The _NVFP4_STEP_LUT_LOCK caused 'Unsupported context manager' under
torch.compile/cudagraph. LUT is now pre-populated during warmup so
the fast path (cache hit) never hits a lock.

Also removed all init/warmup debug prints from CuTeDSL kernels.
2026-05-18 20:54:55 +00:00
450793311c Wire CuTeDSL kernels into vLLM: replace all BF16 dequant with native NVFP4
- CuTeDSLNvfp4Method: custom quant method that creates CuTeDSL runners
  during process_weights_after_loading, then swaps to CuTeDSLNvfp4LinearMethod
  for forward dispatch
- Attention projections (fused_wqa_wkv, wq_b, wo_b) now route through
  CuTeDSLNvfp4Linear (cosine 0.992-0.996 vs BF16 reference)
- Shared expert now uses CuTeDSLSharedExpertRunner (cosine 0.992 vs BF16)
  with monkey-patched forward for fused L1+SiLU+L2 pipeline
- Deleted all BF16 dequant code (_dequant_nvfp4_to_bf16, _post_quant_fix,
  input_scale fixes)
- Deleted _post_quant_fix hook from utils.py
- Fixed SwiGLU clamp: gate clamped BEFORE SiLU (matching SiluAndMulWithClamp)
- Cleaned up all debug prints
- Updated Dockerfile with new kernel files
2026-05-18 20:27:42 +00:00
6ce6a47be9 Add NVFP4 linear runner + attention projection test
- CuTeDSLNvfp4Linear: generic single-GEMM runner for any NVFP4 projection
- test_attention.py: tests q_a_proj, q_b_proj, kv_proj, o_b_proj vs BF16
- Same pad+swizzle pattern as shared expert, but no SiLU/fusion
2026-05-18 20:14:03 +00:00
70f50a1ec6 Fix scale assembly: use correctly-sized temp buffer for swizzle 2026-05-18 20:09:50 +00:00
97bdd604e9 Fix scale assembly: reshape swizzled output to 2D 2026-05-18 20:09:19 +00:00