Commit Graph

13 Commits

Author SHA1 Message Date
0b6ca0df80 P5 integration + B3 q_a_norm fused + gsa scalar fix
P5: Wire up fused mHC pre_block + RMSNorm + NVFP4 quantize kernel
- Replaces: pre_block bmm + rmsnorm (4+ launches) + quantize (2 launches)
- With: 2 kernel launches (mhc_rmsnorm_amax_gsa + mhc_rmsnorm_quantize_nvfp4)
- Both attn and ffn mHC paths now use P5 fused kernel
- Savings: ~5 launches/site × 2 sites × 61 layers = 610 launches/token

B3: Fused rmsnorm+quant for q_a_norm → q_b path
- q_a output → rmsnorm_quantize_nvfp4 → QuantizedActivation → q_b.run_from_quantized
- Eliminates BF16 round-trip between q_a_norm and q_b GEMM
- Replaces ~6 kernel launches per layer (rmsnorm 4+, quantize 2) with 2 fused launches

gsa scalar fix in Nvfp4Linear.run_from_quantized:
- CuTeDSL NVFP4 GEMM expects global_scale_a as per-expert scalar (shape (1,))
- Per-row gsa from fused kernels must be reduced to scalar (max) for M>1
- For M=1 decode: already scalar, no reduction needed
- Fixes potential correctness issue at prefill (M>1) when using fused paths
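
A minimal sketch of the reduction described above, written in plain Python rather than against the real tensor API. The function name `reduce_gsa_to_scalar` and the list-of-floats representation are assumptions for illustration; the actual code lives in Nvfp4Linear.run_from_quantized and operates on GPU tensors.

```python
def reduce_gsa_to_scalar(gsa_per_row):
    """Collapse per-row activation global scales into one scalar.

    gsa_per_row: list of floats, one entry per row of the activation (length M).
    Hypothetical helper; it only illustrates the max-reduction rule.
    """
    if not gsa_per_row:
        raise ValueError("expected at least one gsa value")
    if len(gsa_per_row) == 1:
        # M=1 decode: the value is already a scalar, nothing to reduce.
        return gsa_per_row[0]
    # M>1 prefill: reduce with max to obtain the single scalar the GEMM takes.
    return max(gsa_per_row)


decode_gsa = reduce_gsa_to_scalar([0.5])              # M=1
prefill_gsa = reduce_gsa_to_scalar([0.1, 0.4, 0.2])   # M=3
```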

Cleanup: Remove --ab-compare flag and A/B comparison code (replaced by P5)
2026-06-02 21:20:34 +00:00
0d1cd1e216 P4: Add QuantizedActivation + Nvfp4Linear.run_from_quantized
- QuantizedActivation: carries (x_fp4, x_sf, gsa) for skip-quantize path
- Nvfp4Linear.run_from_quantized(): runs GEMM with pre-quantized input
- Enables fused RMSNorm+quantize to feed directly into all downstream
  linears (q_a, kv, o_proj, etc.) without re-quantizing
2026-06-02 16:37:38 +00:00
d8e17d70c1 P0+P1+P2: Enable fused SwiGLU (MoE+SE), fix SE _run_l1_fused, remove per-call gsa fill_
P0: Enable fused SwiGLU for MoE (set_fused_swiglu(True))
  - Saves 240+ unfused BF16 kernel launches per token
  - SiLU + clamp in kernel registers instead of separate launches

P1: Fix shared expert _run_l1_fused + enable fused SwiGLU
  - Fixed: _l1_sf_view -> _l1_scale_b, _l1_gs_view -> _l1_gsb
  - Fixed: expert_offsets dtype int64 -> int32
  - Added proper padded buffer + scale assembly (matching unfused path)
  - Added runtime gsa support (quantize_nvfp4_gpu_fused)

P2: Remove per-call gsa_buf.fill_() in Nvfp4Linear
  - fill_() was H2D transfer every forward pass (~5µs × 244 calls = ~1.2ms/token)
  - _gsa_buf now initialized with _activation_global_scale (not zeros)
  - After warmup_gsa, buffer already has correct value — no fill needed
2026-06-02 07:57:39 +00:00
61d5e7ba53 revert: P2 gsa fill elimination — revert to proven path for e2e stability
The fill_() is a CPU→GPU scalar write (tiny cost). The optimization
was marginal and the output quality regression (CJK tokens) needs
investigation separately. P2 can re-land after the regression is
confirmed to be sampling-related (not gsa-related).

P0/P1 (fused SwiGLU) still disabled — kernel arg-binding bug unfixed.
2026-06-02 07:32:10 +00:00
040b2eb6e7 perf: P0/P1/P2 — fused SwiGLU for MoE+SE, eliminate per-call gsa fill
P0: Enable fused SwiGLU for all MoE instances (moe._fused_swiglu = True).
    Eliminates ~8 BF16 kernel launches per MoE per token (gate/up split,
    SiLU, clamp, elementwise multiply → single fused kernel launch).

P1: Enable fused SwiGLU for shared expert (SE):
    - Added set_fused_swiglu() method to Nvfp4SharedExpert
    - Added _run_l1_fused() using run_fused_swiglu_grouped_gemm (1-group)
    - Interleave L1 weights at finalize time for fused kernel compatibility
    - Fused kernel handles SwiGLU + clamp in registers, outputs BF16

P2: Eliminate per-call _gsa_buf.fill_() in Nvfp4Linear:
    - _activation_global_scale is set once at warmup, never changes after
    - Skip redundant fill_() via _gsa_buf_initialized flag
    - Saves 244 CPU→GPU scalar fills per token (4 linears × 61 layers)

P3: Deferred (in-kernel RoPE fusion — kernel-side change, not single_shot)
2026-06-02 06:59:25 +00:00
7e3fb5f4d0 fix: add missing import for quantize_nvfp4_gpu in linear.py fixed-gsa path
2026-06-02 04:28:29 +00:00
c8faf20a99 P0 COMPLETE: Eliminate ALL .item() CPU-GPU syncs from NVFP4 activation path
Fused kernels (zero CPU sync, single kernel launch per projection):
- fused_amax_quantize.cu: amax→gsa→quantize in one pass. Replaces two-step
  compute_amax_gsa_gpu + quantize_nvfp4_gpu (had .item() sync).
- fused_deinterleave_amax_quantize.cu: Same for MoE fused_swiglu L2 path.
  Deinterleave + amax + quantize in one pass. Replaces compute_amax_gsa_gpu
  + deinterleave_quantize_nvfp4_cuda (had .item() sync).

All kernel loaders use dsv4/kernels/cuda/loader.py (compile-once cache).
Previously, kernels were JIT-compiled on every call via torch.utils.cpp_extension.load
(~100ms/call, ~500 calls/token). Each kernel now compiles once and the cached module is reused.
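
The compile-once pattern can be sketched roughly as follows. This is not the contents of dsv4/kernels/cuda/loader.py; `load_cached`, `_MODULE_CACHE`, and the `build_fn` callback are hypothetical names, and in a real loader `build_fn` would presumably wrap the extension build step.

```python
import threading

_MODULE_CACHE = {}              # kernel name -> built module (hypothetical cache)
_CACHE_LOCK = threading.Lock()  # guards against concurrent first-time builds


def load_cached(name, build_fn):
    """Return the module for `name`, calling `build_fn` only on the first request."""
    with _CACHE_LOCK:
        if name not in _MODULE_CACHE:
            # Expensive step (compilation) runs at most once per kernel name.
            _MODULE_CACHE[name] = build_fn()
        return _MODULE_CACHE[name]


# Usage: a stand-in build function that records how often it runs.
build_calls = []

def fake_build():
    build_calls.append(1)
    return object()

first = load_cached("fused_amax_quantize", fake_build)
second = load_cached("fused_amax_quantize", fake_build)
```

Keying the cache on a kernel name keeps repeated calls on the hot path down to a dictionary lookup.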

Updated layers:
- linear.py Nvfp4Linear._run_impl: fused kernel, gsa via GPU buffer
- moe.py Nvfp4MoE._run_impl: fused for L1 and L2 (both fused_swiglu and
  non-fused paths)
- shared_expert.py: fused for L1 and L2
- quantize.py: All functions use module loader cache
- sampler.py: Uses module loader cache
- indexer/score_topk.py: Uses module loader cache

P2: Vectorized KVCache.append_swa — index_copy_ instead of Python loop.
2 kernel launches instead of 2T. No .item() in comp_pos either.

P3: Pre-allocated comp_kv buffers — O(1) append instead of O(N) torch.cat.
max_comp=32768 per layer (32MB). No more quadratic memory growth.
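
The difference between concatenating on every append and writing into a pre-allocated buffer can be illustrated with a small plain-Python sketch. The class name, the list-of-lists storage, and the `view` method are assumptions for illustration only; the commit's buffers are tensors.

```python
class PreallocatedBuffer:
    """Fixed-capacity row buffer: append writes in place instead of reallocating."""

    def __init__(self, capacity, width):
        self.width = width
        self.data = [[0.0] * width for _ in range(capacity)]  # allocated once
        self.length = 0                                       # rows currently valid

    def append(self, row):
        if len(row) != self.width:
            raise ValueError("row width mismatch")
        if self.length >= len(self.data):
            raise IndexError("buffer full")
        # O(1) per append: overwrite the next free slot, no copy of earlier rows.
        self.data[self.length][:] = row
        self.length += 1

    def view(self):
        # Only the filled prefix is meaningful.
        return self.data[: self.length]


buf = PreallocatedBuffer(capacity=4, width=2)
buf.append([1.0, 2.0])
buf.append([3.0, 4.0])
```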

~486 .item() syncs per decoded token → ~0 (only argmax + token decode remain).
2026-06-01 21:05:03 +00:00
360f76b970 Performance audit fixes: eliminate CPU-GPU syncs
PERFORMANCE_AUDIT.md validation results:
  1. Nvfp4Linear .item() sync (610/step) → FIXED: compute_amax_gsa_gpu kernel
  2. MoE .item() sync (183/step) → FIXED: same kernel
  3. SharedExpert .item() sync (122/step) → FIXED: same kernel
  4. FMHA V clone → FIXED: V=K, transpose creates copy implicitly
  5. torch.cuda.synchronize in moe_forward → FIXED: conditional on VERBOSE
  6. RoPE 8x duplication → INVALIDATED: necessary for per-GPU HBM access
  7. mHC BF16 bmm → INVALIDATED: 28K FLOPs, not a bottleneck
  8. Router .float() cast → INVALIDATED: needed for FP32 topk, ~1μs

New files:
  - dsv4/kernels/cuda/amax_gsa.cu: GPU-only amax→gsa kernel
  - dsv4/ops/quantize.py: compute_amax_gsa_gpu() wrapper

Net effect: ~915 fewer CPU-GPU syncs per decode step
Remaining syncs: ~10 per layer (quantize kernel parameter) + diagnostics
2026-06-01 20:40:19 +00:00
2b1fca6dae CRITICAL FIX: runtime activation global scale to prevent E4M3 overflow
The checkpoint's input_scale was designed for training-time FP8 quantization,
not NVFP4 activation quantization. Using it as gsa causes x/gsa to exceed
the E4M3 block scale maximum (448), leading to systematic magnitude loss
in every projection. This accumulates over 61 layers, compressing the
logit range and producing garbage tokens.

Fix: compute gsa at runtime from actual activation magnitude:
  gsa = max(|x|) / (6.0 * 448.0)
This ensures |x|/gsa ≤ 2688 = 6.0 × 448.0 (the FP4 E2M1 maximum, 6.0, times the E4M3 block scale maximum, 448).
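
A small plain-Python sketch of the formula above. The function name `runtime_gsa` and the zero-input fallback are assumptions that do not come from the commit; the production path computes this on the GPU.

```python
FP4_MAX = 6.0     # largest magnitude representable in FP4 E2M1
E4M3_MAX = 448.0  # largest finite magnitude representable in FP8 E4M3


def runtime_gsa(x):
    """Activation global scale from the actual magnitude of x (list of floats)."""
    amax = max(abs(v) for v in x)
    if amax == 0.0:
        # Assumption: avoid a zero scale for all-zero input; not from the commit.
        return 1.0
    return amax / (FP4_MAX * E4M3_MAX)


x = [0.25, -3.5, 1344.0]
gsa = runtime_gsa(x)                       # 1344 / 2688
scaled_max = max(abs(v) for v in x) / gsa  # largest |x| / gsa
```

With this choice the largest scaled activation lands on the 6.0 × 448.0 bound, up to floating-point rounding.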

Applied to: Nvfp4Linear, Nvfp4GroupedLinear, Nvfp4MoE, Nvfp4SharedExpert, Router gate
2026-06-01 14:21:16 +00:00
e671780008 fix: transpose checkpoint weights before make_b_k_major in Nvfp4Linear/SharedExpert
Critical bug: checkpoint weights are (N_packed, K_packed) N-major format,
but make_b_k_major expects (E, K_packed, N_packed) input. Without the
permute, the K and N dimensions are swapped, producing garbage output
with wrong dimensions (e.g., q_a output was 3584 instead of 1536).

Also fix scale assembly: checkpoint scales are (N, K_sf) which should
use assemble_raw_scales_2d3d_3d_side (no transpose), not
assemble_scales_3d_side (which incorrectly transposes K_sf↔N).
2026-06-01 00:30:37 +00:00
e8a7a9256f fix: convert uint8 checkpoint weights to float4_e2m1fn_x2 for CuTeDSL GEMM
The CuTeDSL kernel expects float4_e2m1fn_x2 dtype for FP4 weight tensors,
but checkpoint weights from safetensors are loaded as uint8. The uint8 and
float4_e2m1fn_x2 have the same byte representation, so .view() is safe.
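
To make the shared byte representation concrete, the sketch below decodes one packed byte into two FP4 E2M1 values in plain Python. The magnitude table is the standard E2M1 value set; placing the first element in the low nibble is an assumption here, and the helper names are hypothetical.

```python
# Magnitudes for the 3 low bits of an E2M1 code; bit 3 is the sign.
E2M1_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(nibble):
    """Decode a 4-bit E2M1 code (0..15) to a Python float."""
    magnitude = E2M1_LUT[nibble & 0x7]
    return -magnitude if nibble & 0x8 else magnitude


def unpack_fp4_pair(byte):
    """Split one uint8 into two FP4 values (assumed order: low nibble first)."""
    return decode_e2m1(byte & 0xF), decode_e2m1(byte >> 4)


pair_a = unpack_fp4_pair(0x72)  # low nibble 0x2, high nibble 0x7
pair_b = unpack_fp4_pair(0xF1)  # low nibble 0x1, high nibble 0xF
```

Because the packed dtype stores exactly these bytes, reinterpreting the uint8 tensor changes only how the data is labeled, not the data itself.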

Fixed in:
- Nvfp4Linear.finalize_weights()
- Nvfp4SharedExpert.finalize_weights()
- Nvfp4MoE._ensure_stacked() (both stacked and legacy paths)
2026-06-01 00:18:34 +00:00
172448514c fix: fold weight_scale_2 into global_scale_b for NVFP4 GEMM
Critical bug fix: weight_scale_2 (the second-level NVFP4 scale) was
being dropped entirely in the production pipeline. The dequant formula
is lut[w] * weight_scale * weight_scale_2, so weight_scale_2 must be
folded into the GEMM's global_scale_b parameter.
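
The folding works because weight_scale_2 is a single factor that can be pulled out of the dot product. A plain-Python sketch under simplifying assumptions: one block scale for the whole row, unsigned codes 0..7 only, and hypothetical function names.

```python
import math

E2M1_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # unsigned codes only


def dot_explicit(x, codes, weight_scale, weight_scale_2):
    """Dequantize each weight as lut[w] * weight_scale * weight_scale_2, then dot."""
    return sum(
        xi * (E2M1_LUT[w] * weight_scale * weight_scale_2)
        for xi, w in zip(x, codes)
    )


def dot_folded(x, codes, weight_scale, global_scale_b):
    """Accumulate with the block scale only; apply the folded global scale once."""
    acc = sum(xi * (E2M1_LUT[w] * weight_scale) for xi, w in zip(x, codes))
    return acc * global_scale_b


x = [1.0, 2.0]
codes = [2, 5]  # lut values 1.0 and 3.0
ws, ws2 = 0.5, 0.25
explicit = dot_explicit(x, codes, ws, ws2)
folded = dot_folded(x, codes, ws, global_scale_b=ws2)
dropped = dot_folded(x, codes, ws, global_scale_b=1.0)  # the bug: ws2 omitted
```

The `dropped` value shows the failure mode from the commit message: omitting weight_scale_2 leaves every output off by a constant factor.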

Fixes in:
- Nvfp4Linear: ws2 field, folded in finalize_weights()
- Nvfp4MoE: l1_ws2/l2_ws2 lists, folded in _ensure_stacked()
- Nvfp4SharedExpert: l1_ws2/l2_ws2 lists, folded in finalize_weights()
- single_shot_inference.py: pass weight_scale_2 through all loading paths
- Also fix missing o_a_prod key fallback in attention output
2026-06-01 00:10:50 +00:00
3fb3c925af Restructure: cutedsl/ -> dsv4/ with proper layering
- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4/layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
2026-05-21 17:30:44 +00:00