nvfp4-megamoe-kernel/dsv4/kernels/router/dense_router_decode.py
biondizzle bab748763e Rewrite NVFP4 fused router kernel: MoE-style epilogue replaces broken SMEM merge
CRITICAL REWRITE of nvfp4_fused_router_kernel.py:
- REMOVED: Raw pointer SMEM merge (storage.merge_scores.data_ptr()[idx] = val)
  This crashed the CuTeDSL MLIR optimizer. Never use raw pointer indexing
  inside CuTeDSL kernels.
- REMOVED: Per-thread top-k accumulation + 128-thread SMEM merge. Too complex
  for MLIR, caused SIGABRT during compilation.
- ADDED: MoE-style epilogue (TMEM→regs→activation→SMEM→TMA store→GMEM)
  using paired copy atoms from CUTLASS (epilogue_tmem_copy_and_partition +
  epilogue_smem_copy_and_partition). Structurally identical to the proven
  FusedSwiGLUScaledGroupedGemmKernel epilogue. This SHOULD compile.
- Activation: sqrt(softplus(logit)) in registers (replaces SwiGLU)
- Output: FP32 activated scores written to GMEM via TMA store
- Top-k handled by activation_topk CUDA kernel in Python wrapper
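
For reference, the activation named above, sqrt(softplus(logit)), can be sketched in plain PyTorch. This is a host-side illustration only, not the CuTeDSL epilogue code; the function name `sqrt_softplus_reference` is hypothetical and does not appear in the repository.

```python
import torch
import torch.nn.functional as F


def sqrt_softplus_reference(logits: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference: sqrt(softplus(x)), computed in FP32."""
    x = logits.float()
    # F.softplus evaluates log(1 + exp(x)) and falls back to the identity
    # for large x (default threshold 20), which avoids overflow in exp.
    return torch.sqrt(F.softplus(x))


logits = torch.tensor([[-4.0, 0.0, 4.0, 30.0]])
scores = sqrt_softplus_reference(logits)
```

Because softplus is strictly positive, the scores are always positive and increase monotonically with the logit, so the square root is well defined for every input.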

Other changes:
- _activation_topk.py: Added run_fused_activation_topk_pre_activated() for
  top-k + renorm on pre-activated scores (PyTorch reference, not CUDA kernel)
- dense_router_dispatch_nvfp4_fused: Updated to match new kernel API
- Kernel now uses standard _compute_stages() for SMEM budget calculation
- Kernel now uses compute_epilogue_tile_shape() for epi_tile (not hardcoded)
- C pipeline (PipelineTmaStore) added for SMEM→GMEM overlap
2026-06-01 09:59:34 +00:00
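
The commit message mentions a PyTorch reference for top-k + renorm on pre-activated scores. A minimal sketch of what such a reference might look like follows. It assumes the DeepSeek-V3-style convention that `e_bias` affects only which experts are selected, while the returned weights come from the unbiased scores, are normalized to sum to 1, and are then multiplied by `routed_scaling_factor`. The function name `topk_renorm_reference` is hypothetical, and the actual `run_fused_activation_topk_pre_activated()` may differ.

```python
import torch


def topk_renorm_reference(
    scores: torch.Tensor,          # [N, num_experts] activated scores
    e_bias: torch.Tensor,          # [num_experts]
    routed_scaling_factor: float,
    top_k: int,
):
    """Hypothetical reference for top-k + renormalization on activated scores."""
    scores = scores.float()
    # Assumption: the bias only influences which experts are selected.
    biased = scores + e_bias.float()
    ids = torch.topk(biased, k=top_k, dim=-1).indices
    # Assumption: the weights are gathered from the unbiased scores.
    weights = torch.gather(scores, dim=-1, index=ids)
    # Renormalize over the selected experts, then apply the scaling factor.
    weights = weights / weights.sum(dim=-1, keepdim=True).clamp_min(1e-20)
    return weights * routed_scaling_factor, ids.to(torch.int32)


scores = torch.tensor([[0.2, 0.9, 0.5, 0.4]])
e_bias = torch.tensor([0.0, 0.0, 0.0, 0.3])
weights, ids = topk_renorm_reference(
    scores, e_bias, routed_scaling_factor=2.0, top_k=2
)
```

In this example the bias promotes expert 3 over expert 2 during selection, yet the weight for expert 3 is still derived from its unbiased score of 0.4.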

102 lines
4.2 KiB
Python

"""DSV4 Dense Router — NVFP4 GEMM + sqrt(softplus) + bias + top-k.

Production paths (in priority order):

1. NVFP4 fused router kernel (nvfp4_fused_router_kernel.py):
   Single-kernel blockscaled GEMM + fused sqrt(softplus) epilogue.
   No intermediate BF16 logits buffer; the epilogue writes FP32 activated
   scores to GMEM. Pure NVFP4 + Blackwell tensor cores.
2. NVFP4 GEMM + activation_topk (2-kernel path):
   Nvfp4Linear (Blackwell tensor cores) + fused activation_topk CUDA kernel.
3. BF16 cuBLAS fallback: when NVFP4 scales are not available in the
   checkpoint, dense_router_dispatch uses torch.nn.functional.linear
   (cuBLAS, SM100 tensor cores) instead.
"""
from __future__ import annotations
from typing import Tuple, Optional
import torch

def dense_router_dispatch(
    hidden_states: torch.Tensor,  # [N, hidden_size] BF16
    W_gate: torch.Tensor,         # [hidden_size, num_experts] BF16
    e_bias: torch.Tensor,         # [num_experts] FP32
    routed_scaling_factor: float,
    top_k: int,
    out_weights: torch.Tensor,    # [N, top_k] FP32, pre-allocated
    out_ids: torch.Tensor,        # [N, top_k] int32, pre-allocated
):
    """Dispatch the dense router (BF16 cuBLAS fallback).

    GEMM via torch.nn.functional.linear (cuBLAS, SM100 tensor cores),
    then fused activation + top-k via the CUDA kernel, which fills the
    pre-allocated out_weights and out_ids.
    """
    # W_gate is stored as [hidden_size, num_experts], while F.linear expects
    # [out_features, in_features], hence the transpose. Both operands are
    # upcast to FP32 here, so the GEMM itself runs on FP32 inputs.
    logits = torch.nn.functional.linear(hidden_states.float(), W_gate.T.float())
    from dsv4.kernels.router._activation_topk import run_fused_activation_topk
    run_fused_activation_topk(
        logits, e_bias, routed_scaling_factor, top_k,
        out_weights, out_ids,
    )

def dense_router_dispatch_nvfp4(
    hidden_states: torch.Tensor,  # [N, hidden_size] BF16
    gate_lin,                     # Nvfp4Linear instance
    e_bias: torch.Tensor,         # [num_experts] FP32
    routed_scaling_factor: float,
    top_k: int,
    out_weights: torch.Tensor,    # [N, top_k] FP32, pre-allocated
    out_ids: torch.Tensor,        # [N, top_k] int32, pre-allocated
):
    """Dispatch the dense router (NVFP4 production GEMM, 2-kernel path).

    NVFP4 GEMM via Nvfp4Linear (Blackwell SM100 tensor cores),
    then fused activation + top-k via the CUDA kernel.
    """
    logits = gate_lin(hidden_states).float()  # (N, E) FP32
    from dsv4.kernels.router._activation_topk import run_fused_activation_topk
    run_fused_activation_topk(
        logits, e_bias, routed_scaling_factor, top_k,
        out_weights, out_ids,
    )

def dense_router_dispatch_nvfp4_fused(
    hidden_states: torch.Tensor,      # [N, hidden_size] BF16
    gate_weight: torch.Tensor,        # [K_packed, E_packed] uint8 NVFP4 weight
    gate_weight_scale: torch.Tensor,  # [K_sf, E_sf] FP8 E4M3 weight scale
    gate_ws2: torch.Tensor,           # weight_scale_2 (scalar or per-output)
    gate_input_scale: torch.Tensor,   # input_scale (activation global scale base)
    e_bias: torch.Tensor,             # [num_experts] FP32
    routed_scaling_factor: float,
    top_k: int,
    out_weights: torch.Tensor,        # [N, top_k] FP32, pre-allocated
    out_ids: torch.Tensor,            # [N, top_k] int32, pre-allocated
):
    """Dispatch the dense router (NVFP4 fused single-kernel path).

    Phase 1: CuTeDSL NVFP4 blockscaled GEMM + sqrt(softplus) epilogue.
        The input activations (hidden_states) are quantized to NVFP4, the
        GEMM runs on Blackwell tensor cores, and sqrt(softplus) is fused in
        the epilogue (TMEM→regs→activation→SMEM→GMEM).
        Writes FP32 activated scores to GMEM. No intermediate BF16 logits.
    Phase 2: top-k + renorm on activated scores.
    """
    from dsv4.kernels.router.nvfp4_fused_router_kernel import run_nvfp4_fused_router

    # run_nvfp4_fused_router is passed Python floats for both global scales;
    # a non-scalar scale tensor is collapsed to its mean.
    if gate_input_scale.numel() == 1:
        gsa = gate_input_scale.float().item()
    else:
        gsa = gate_input_scale.float().mean().item()
    if gate_ws2.numel() == 1:
        gsb_val = gate_ws2.float().item()
    else:
        gsb_val = gate_ws2.float().mean().item()

    result_w, result_ids = run_nvfp4_fused_router(
        hidden_states=hidden_states,
        mat_b=gate_weight,
        scale_b=gate_weight_scale,
        gsa=gsa,
        gsb_val=gsb_val,
        e_bias=e_bias,
        routed_scaling_factor=routed_scaling_factor,
        top_k=top_k,
    )

    # Copy only the first N rows into the caller's pre-allocated buffers.
    N = hidden_states.shape[0]
    out_weights[:N].copy_(result_w[:N])
    out_ids[:N].copy_(result_ids[:N])
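

The transpose in the fallback path (`W_gate.T` passed to `torch.nn.functional.linear`) is easy to misread, so here is a small CPU-only check of the layout. It is a sketch with made-up sizes, not part of the module: `F.linear(x, W)` computes `x @ W.T` for `W` of shape `[out_features, in_features]`, so passing `W_gate.T` for a gate stored as `[hidden_size, num_experts]` yields `x @ W_gate`.

```python
import torch

torch.manual_seed(0)
N, hidden_size, num_experts = 4, 16, 8  # illustrative sizes only

hidden_states = torch.randn(N, hidden_size, dtype=torch.bfloat16)
W_gate = torch.randn(hidden_size, num_experts, dtype=torch.bfloat16)

# Same call as in dense_router_dispatch: upcast to FP32, transpose the gate.
logits = torch.nn.functional.linear(hidden_states.float(), W_gate.T.float())

# Plain matmul against the gate in its stored [hidden_size, num_experts] layout.
expected = hidden_states.float() @ W_gate.float()
```

Both expressions produce an FP32 tensor of shape `[N, num_experts]`, which is the logits layout the 2-kernel path also passes to `run_fused_activation_topk`.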