Files
nvfp4-megamoe-kernel/dsv4/ops/gemm_runner.py

544 lines
22 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""NVFP4 GEMM runner: warmup, compile, and execute grouped/fused GEMMs."""
import math
import torch
import cutlass
import cutlass.cute as cute
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
from dsv4.kernels.gemm.grouped import ScaledGroupedGemmKernel
from dsv4.kernels.gemm.fused_swiglu import FusedSwiGLUScaledGroupedGemmKernel
from dsv4.ops.quantize import (
quantize_activation_nvfp4,
quantize_weight_to_nvfp4,
quantize_to_nvfp4,
deinterleave_quantize_nvfp4_cuda,
SF_VEC_SIZE,
)
from dsv4.ops.layouts import (
interleave_l1_weights,
deinterleave_l1_weights,
assemble_scales_2d_side,
assemble_scales_3d_side,
make_b_k_major,
compute_expert_offsets,
ceil_div,
round_up,
)
# Cache of compiled kernels + pre-allocated workspace, keyed by the cache_key
# tuple built inside warmup_compilation() / run_nvfp4_grouped_gemm().
# Each entry: {'compiled': callable, 'workspace': Tensor, 'workspace_size': int}
#
# Key design decisions (Bug #1 fix):
# - cute.compile does NOT corrupt GPU memory (verified 2026-05-20 on B200).
#   The original _needs_token_refill hack was a misdiagnosis. The real bug
#   was elsewhere (likely OOB write or weight loading).
# - Workspace is allocated once per cache entry (during warmup_compilation(),
#   or on the first lazy compile in run_nvfp4_grouped_gemm()) and reused on
#   subsequent calls. No torch.full() on a cache hit.
# - CuTe tensor wrappers (from_dlpack + mark_layout_dynamic) are cheap
#   metadata wrappers. We re-create them per call from real tensors.
#   Caching them would hold stale references to tensors that get freed.
_compiled_kernel_cache = {}
def warmup_compilation(num_experts, K_packed, N_packed, device,
                       mma_tiler_mn=(128, 128), cluster_shape_mn=(1, 1),
                       use_2cta=False):
    """Eagerly JIT-compile the grouped GEMM kernel for a specific shape.

    Call this BEFORE model weights are loaded so that cute.compile runs
    once per shape. The compiled kernel and its workspace are stored in
    ``_compiled_kernel_cache`` under the same key layout that
    ``run_nvfp4_grouped_gemm`` looks up:
    ``(num_experts, device_str, mma_tiler_mn, cluster_shape_mn, K, N, use_2cta)``.

    The dummy operands are produced by quantizing random BF16 through
    ``quantize_to_nvfp4``. Per the original author's notes, zero-filled or
    random-byte FP4/FP8 inputs produced NaN/Inf and
    cudaErrorIllegalInstruction, so only quantized random data is used.
    The dummy tensors are released after compilation; only the workspace
    stays alive (inside the cache entry).

    Args:
        num_experts: number of experts (local, after expert parallelism).
            Used for the workspace size and the cache key.
        K_packed: K dimension in packed float4 elements (K_original // 2).
        N_packed: N dimension in packed float4 elements (N_original // 2).
        device: 'cuda:X', 'cuda', or a torch.device. An index-less CUDA
            device is resolved to the current device so the cache key
            matches ``str(tensor.device)`` used on the forward path.
        mma_tiler_mn: GEMM tile shape (default (128, 128)). When warming the
            2-CTA variant, pass the already-doubled "effective" shape that
            run_nvfp4_grouped_gemm derives.
        cluster_shape_mn: cluster shape (default (1, 1)); same remark as
            for mma_tiler_mn.
        use_2cta: compile the 2-CTA instruction variant. Defaults to False,
            which corresponds to the ``use_2cta=False`` entries looked up by
            run_nvfp4_grouped_gemm.

    Returns:
        None. The result is stored in ``_compiled_kernel_cache``.
    """
    import cuda.bindings.driver as cuda

    # 'cuda' and 'cuda:0' must map to the same cache key: tensors always
    # report an explicit index, so resolve an index-less CUDA device here.
    dev = torch.device(device)
    if dev.type == "cuda" and dev.index is None:
        dev = torch.device("cuda", torch.cuda.current_device())

    use_2cta = bool(use_2cta)
    # Trailing use_2cta keeps this key identical in layout to the one built
    # by run_nvfp4_grouped_gemm; without it the warmed entry is never found.
    cache_key = (num_experts, str(dev), mma_tiler_mn, cluster_shape_mn,
                 K_packed, N_packed, use_2cta)
    if cache_key in _compiled_kernel_cache:
        return  # Already compiled

    # Generate VALID FP4 data by quantizing random BF16 through our pipeline
    # (mathematically consistent FP4 values + FP8 scales).
    _warmup_a_bf16 = torch.randn(128, K_packed * 2, dtype=torch.bfloat16, device=dev) * 0.1
    mat_a, scale_a, _ = quantize_to_nvfp4(_warmup_a_bf16)
    del _warmup_a_bf16
    # 1 expert: per the original note, the kernel compiles the same
    # regardless of expert count.
    _warmup_b_bf16 = torch.randn(1, K_packed * 2, N_packed * 2, dtype=torch.bfloat16, device=dev) * 0.1
    mat_b, scale_b, _ = quantize_to_nvfp4(_warmup_b_bf16)
    del _warmup_b_bf16

    out = torch.zeros(128, N_packed, dtype=torch.bfloat16, device=dev)
    expert_offsets = torch.full((1,), 128, dtype=torch.int32, device=dev)
    global_scale_a = torch.ones(1, dtype=torch.float32, device=dev)
    global_scale_b = torch.ones(1, dtype=torch.float32, device=dev)

    kernel = ScaledGroupedGemmKernel(
        scenario="2Dx3D",
        sf_vec_size=SF_VEC_SIZE,
        accumulate_on_output=False,
        separate_tensormap_init=True,
        consistent_token_padding=False,
        mma_tiler_mnk=(*mma_tiler_mn, 256),
        cluster_shape_mnk=(*cluster_shape_mn, 1),
        use_2cta_instrs=use_2cta,
    )

    def to_cute(t):
        # from_dlpack checks torch.cuda.current_device() against the tensor
        # device. Inside CUDA graph capture on non-default GPUs these may
        # not match, so current_device is temporarily patched to report the
        # tensor's device index. try/finally guarantees the global is
        # restored even if from_dlpack raises.
        original_current_device = torch.cuda.current_device
        if t.is_cuda and t.device.index != original_current_device():
            torch.cuda.current_device = lambda: t.device.index
        try:
            ct = cutlass_torch.from_dlpack(t)
        finally:
            torch.cuda.current_device = original_current_device
        return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))

    a_c = to_cute(mat_a)
    b_c = to_cute(mat_b)
    sfa_c = to_cute(scale_a)
    sfb_c = to_cute(scale_b)
    c_c = to_cute(out)
    offs_c = to_cute(expert_offsets)

    # Workspace is sized for the real expert count, not the 1-expert dummy.
    workspace_size = kernel.get_workspace_size(num_experts)
    workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=dev)
    ws_c = to_cute(workspace)
    gsa_c = to_cute(global_scale_a)
    gsb_c = to_cute(global_scale_b)

    cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1]
    max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size)
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    compiled = cute.compile(
        kernel, a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c,
        max_active_clusters, stream,
        global_scale_a=gsa_c, global_scale_b=gsb_c,
    )
    # Warmup run (required by CuTeDSL)
    compiled(
        a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, stream,
        global_scale_a=gsa_c, global_scale_b=gsb_c,
    )
    torch.cuda.synchronize()

    # Cache compiled kernel + pre-allocated workspace
    # (NOT CuTe wrappers - they hold refs to dummy tensors)
    _compiled_kernel_cache[cache_key] = {
        'compiled': compiled,
        'workspace': workspace,  # pre-allocated, reused
        'workspace_size': workspace_size,
    }

    # Drop dummy tensors and their wrappers before empty_cache() so their
    # memory can actually be released (workspace is kept in the cache).
    del mat_a, mat_b, scale_a, scale_b, out, expert_offsets
    del global_scale_a, global_scale_b
    del a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, gsa_c, gsb_c
    torch.cuda.empty_cache()
def run_nvfp4_grouped_gemm(
    mat_a,                 # (tokens_sum, K_packed) float4_e2m1fn_x2
    mat_b,                 # (experts, K_packed, N_packed) float4_e2m1fn_x2, K-major
    scale_a,               # assembled 2D side (padded + swizzled)
    scale_b,               # assembled 3D side (padded + swizzled)
    expert_offsets,        # (experts,) int32 cumulative token offsets
    global_scale_a=None,   # (experts,) float32
    global_scale_b=None,   # (experts,) float32
    mma_tiler_mn=(128, 128),
    cluster_shape_mn=(1, 1),
    out=None,              # pre-allocated output buffer for CUDA graph capture
):
    """Run the CuTeDSL NVFP4 scaled grouped GEMM.

    2Dx3D: A(tokens, K) x B(experts, K, N) -> C(tokens, N)

    The kernel is compiled once per cache key (either via
    warmup_compilation() during init or lazily on the first call) and then
    reused from ``_compiled_kernel_cache`` together with its workspace, so
    a cache hit performs no torch.full(). This function itself issues no
    torch.cuda.synchronize() or .item().

    Args:
        mat_a, mat_b, scale_a, scale_b, expert_offsets: GEMM operands; see
            the inline shape comments in the signature.
        global_scale_a, global_scale_b: optional per-expert scales; passed
            to the kernel as None when omitted.
        mma_tiler_mn: base GEMM tile shape. M is doubled when the 2-CTA
            path is selected.
        cluster_shape_mn: base cluster shape. M is doubled when the 2-CTA
            path is selected.
        out: optional BF16 output buffer. It is zeroed in place; when None
            a new (tokens_sum, N) zero tensor is allocated on mat_a.device.

    Returns:
        The output tensor (``out`` if it was supplied).
    """
    import cuda.bindings.driver as cuda

    num_experts = mat_b.shape[0]
    n_dim = mat_b.shape[2]
    tokens_sum = mat_a.shape[0]
    if out is None:
        out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=mat_a.device)
    else:
        out.zero_()

    # NVFP4-3: use 2-CTA UMMA for M>=256 (1.7-1.9x throughput at prefill).
    use_2cta = tokens_sum >= 256 and cluster_shape_mn[0] % 2 == 0
    if use_2cta:
        effective_mma_tiler_mn = (mma_tiler_mn[0] * 2, mma_tiler_mn[1])
        effective_cluster_shape_mn = (cluster_shape_mn[0] * 2, cluster_shape_mn[1])
    else:
        effective_mma_tiler_mn = mma_tiler_mn
        effective_cluster_shape_mn = cluster_shape_mn

    cache_key = (num_experts, str(mat_a.device), effective_mma_tiler_mn, effective_cluster_shape_mn,
                 mat_a.shape[1], mat_b.shape[2], use_2cta)

    def to_cute(t):
        # from_dlpack checks torch.cuda.current_device() against the tensor
        # device. Inside CUDA graph capture on non-default GPUs these may
        # not match, so current_device is temporarily patched to report the
        # tensor's device index. try/finally guarantees the global is
        # restored even if from_dlpack raises.
        original_current_device = torch.cuda.current_device
        if t.is_cuda and t.device.index != original_current_device():
            torch.cuda.current_device = lambda: t.device.index
        try:
            ct = cutlass_torch.from_dlpack(t)
        finally:
            torch.cuda.current_device = original_current_device
        return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))

    # CuTe wrappers are re-created from the current tensors on every call:
    # this is cheap (metadata only) and avoids stale references to tensors
    # from previous calls that may have been freed.
    a_c = to_cute(mat_a)
    b_c = to_cute(mat_b)
    sfa_c = to_cute(scale_a)
    sfb_c = to_cute(scale_b)
    c_c = to_cute(out)
    offs_c = to_cute(expert_offsets)
    gsa_c = to_cute(global_scale_a) if global_scale_a is not None else None
    gsb_c = to_cute(global_scale_b) if global_scale_b is not None else None
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    entry = _compiled_kernel_cache.get(cache_key)
    if entry is None:
        # Lazy compilation - safety net if warmup_compilation wasn't called.
        # This should NOT happen in production (warmup is called during init).
        kernel = ScaledGroupedGemmKernel(
            scenario="2Dx3D",
            sf_vec_size=SF_VEC_SIZE,
            accumulate_on_output=False,
            separate_tensormap_init=True,
            consistent_token_padding=False,
            mma_tiler_mnk=(*effective_mma_tiler_mn, 256),
            cluster_shape_mnk=(*effective_cluster_shape_mn, 1),
            use_2cta_instrs=use_2cta,
        )
        workspace_size = kernel.get_workspace_size(num_experts)
        workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=mat_a.device)
        new_ws_c = to_cute(workspace)
        # Occupancy must be queried for the cluster shape the kernel is
        # actually built with (the effective, possibly doubled, shape).
        cluster_size = effective_cluster_shape_mn[0] * effective_cluster_shape_mn[1]
        max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size)
        compiled = cute.compile(
            kernel, a_c, b_c, sfa_c, sfb_c, c_c, offs_c, new_ws_c,
            max_active_clusters, stream,
            global_scale_a=gsa_c, global_scale_b=gsb_c,
        )
        # First launch right after compilation, mirroring the warmup run
        # done in warmup_compilation().
        compiled(
            a_c, b_c, sfa_c, sfb_c, c_c, offs_c, new_ws_c, stream,
            global_scale_a=gsa_c, global_scale_b=gsb_c,
        )
        entry = {
            'compiled': compiled,
            'workspace': workspace,
            'workspace_size': workspace_size,
        }
        _compiled_kernel_cache[cache_key] = entry

    # --- Invoke cached kernel ---
    ws_c = to_cute(entry['workspace'])
    entry['compiled'](
        a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, stream,
        global_scale_a=gsa_c, global_scale_b=gsb_c,
    )
    return out
# ── Fused SwiGLU GEMM (Stage 1: SiLU in registers, BF16 output) ──────
# Cache for the fused SwiGLU kernels, separate from _compiled_kernel_cache.
# Keyed by the ('fused', ...) cache_key tuple built in the functions below;
# each entry is {'compiled': callable, 'workspace': Tensor, 'workspace_size': int}.
_fused_kernel_cache = {}
def warmup_fused_swiglu_compilation(num_experts, K_packed, N_packed, device,
                                    swiglu_limit=0.0,
                                    mma_tiler_mn=(128, 128),
                                    cluster_shape_mn=(1, 1)):
    """Eagerly JIT-compile the fused SwiGLU grouped GEMM kernel.

    Must be called during model initialization; it is the fused counterpart
    of warmup_compilation(). The compiled kernel and its workspace are
    stored in ``_fused_kernel_cache`` under the same key layout that
    ``run_fused_swiglu_grouped_gemm`` looks up:
    ``('fused', num_experts, device_str, mma_tiler_mn, cluster_shape_mn,
    K, N, swiglu_limit, use_2cta)``.

    Args:
        num_experts: number of (local) experts; used for the workspace size
            and the cache key.
        K_packed: K dimension in packed float4 elements (K_original // 2).
        N_packed: N dimension in packed float4 elements (N_original // 2).
        device: 'cuda:X', 'cuda', or a torch.device. An index-less CUDA
            device is resolved to the current device so the cache key
            matches ``str(tensor.device)`` used on the forward path.
        swiglu_limit: forwarded to the kernel and part of the cache key.
        mma_tiler_mn: GEMM tile shape. To warm the 2-CTA variant, pass the
            already-doubled "effective" shape derived on the forward path.
        cluster_shape_mn: cluster shape; same remark as for mma_tiler_mn.

    Returns:
        None. The result is stored in ``_fused_kernel_cache``.
    """
    import cuda.bindings.driver as cuda
    from dsv4.kernels.gemm.fused_swiglu import FusedSwiGLUScaledGroupedGemmKernel

    # 'cuda' and 'cuda:0' must map to the same cache key: tensors always
    # report an explicit index, so resolve an index-less CUDA device here.
    dev = torch.device(device)
    if dev.type == "cuda" and dev.index is None:
        dev = torch.device("cuda", torch.cuda.current_device())

    # NVFP4-3: Enable 2-CTA UMMA when MMA tile M >= 256 and cluster M is even.
    # Per the original notes this helps at prefill shapes and is not
    # beneficial at decode (M < 256).
    use_2cta = mma_tiler_mn[0] >= 256 and cluster_shape_mn[0] % 2 == 0

    # Trailing use_2cta keeps this key identical in layout to the one built
    # by run_fused_swiglu_grouped_gemm; without it the warmed entry is
    # never found on the forward path.
    cache_key = ('fused', num_experts, str(dev), mma_tiler_mn, cluster_shape_mn,
                 K_packed, N_packed, swiglu_limit, use_2cta)
    if cache_key in _fused_kernel_cache:
        return

    # Generate VALID FP4 data by quantizing random BF16 (same as warmup_compilation)
    _warmup_a_bf16 = torch.randn(128, K_packed * 2, dtype=torch.bfloat16, device=dev) * 0.1
    mat_a, scale_a, _ = quantize_to_nvfp4(_warmup_a_bf16)
    del _warmup_a_bf16
    # 1 expert: per the original note, the kernel compiles the same
    # regardless of expert count.
    _warmup_b_bf16 = torch.randn(1, K_packed * 2, N_packed * 2, dtype=torch.bfloat16, device=dev) * 0.1
    mat_b, scale_b, _ = quantize_to_nvfp4(_warmup_b_bf16)
    del _warmup_b_bf16

    # BF16 output (Stage 1: we still write BF16).
    # NOTE(review): the original comment says the fused kernel writes the
    # intermediate (N/2) result, yet the buffer is N_packed wide - confirm.
    out = torch.zeros(128, N_packed, dtype=torch.bfloat16, device=dev)
    expert_offsets = torch.full((1,), 128, dtype=torch.int32, device=dev)
    global_scale_a = torch.ones(1, dtype=torch.float32, device=dev)
    global_scale_b = torch.ones(1, dtype=torch.float32, device=dev)

    kernel = FusedSwiGLUScaledGroupedGemmKernel(
        scenario="2Dx3D",
        sf_vec_size=SF_VEC_SIZE,
        accumulate_on_output=False,
        separate_tensormap_init=True,
        consistent_token_padding=False,
        mma_tiler_mnk=(*mma_tiler_mn, 256),
        cluster_shape_mnk=(*cluster_shape_mn, 1),
        fused_swiglu=True,
        swiglu_limit=swiglu_limit,
        use_2cta_instrs=use_2cta,
    )

    def to_cute(t):
        # from_dlpack checks torch.cuda.current_device() against the tensor
        # device. Inside CUDA graph capture on non-default GPUs these may
        # not match, so current_device is temporarily patched to report the
        # tensor's device index. try/finally guarantees the global is
        # restored even if from_dlpack raises.
        original_current_device = torch.cuda.current_device
        if t.is_cuda and t.device.index != original_current_device():
            torch.cuda.current_device = lambda: t.device.index
        try:
            ct = cutlass_torch.from_dlpack(t)
        finally:
            torch.cuda.current_device = original_current_device
        return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))

    a_c = to_cute(mat_a)
    b_c = to_cute(mat_b)
    sfa_c = to_cute(scale_a)
    sfb_c = to_cute(scale_b)
    c_c = to_cute(out)
    offs_c = to_cute(expert_offsets)

    # Workspace is sized for the real expert count, not the 1-expert dummy.
    workspace_size = kernel.get_workspace_size(num_experts)
    workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=dev)
    ws_c = to_cute(workspace)
    gsa_c = to_cute(global_scale_a)
    gsb_c = to_cute(global_scale_b)

    cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1]
    max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size)
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    compiled = cute.compile(
        kernel, a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c,
        max_active_clusters, stream,
        global_scale_a=gsa_c, global_scale_b=gsb_c,
    )
    # Warmup run right after compilation, then wait for it to finish.
    compiled(
        a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, stream,
        global_scale_a=gsa_c, global_scale_b=gsb_c,
    )
    torch.cuda.synchronize()

    _fused_kernel_cache[cache_key] = {
        'compiled': compiled,
        'workspace': workspace,
        'workspace_size': workspace_size,
    }

    # Drop dummy tensors and their wrappers before empty_cache() so their
    # memory can actually be released (workspace is kept in the cache).
    del mat_a, mat_b, scale_a, scale_b, out, expert_offsets
    del global_scale_a, global_scale_b
    del a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, gsa_c, gsb_c
    torch.cuda.empty_cache()
def run_fused_swiglu_grouped_gemm(
    mat_a,                 # (tokens_sum, K_packed) float4_e2m1fn_x2
    mat_b,                 # (experts, K_packed, N_packed) float4_e2m1fn_x2, K-major
    scale_a,               # assembled 2D side (padded + swizzled)
    scale_b,               # assembled 3D side (padded + swizzled)
    expert_offsets,        # (experts,) int32 cumulative token offsets
    global_scale_a=None,   # (experts,) float32
    global_scale_b=None,   # (experts,) float32
    swiglu_limit=0.0,
    mma_tiler_mn=(128, 128),
    cluster_shape_mn=(1, 1),
    out=None,              # pre-allocated output buffer for CUDA graph capture
):
    """Run the fused SwiGLU NVFP4 scaled grouped GEMM.

    Stage 1: SiLU is applied to the full accumulator in registers,
    then written as BF16 to C. Gate/up pairing is not yet implemented.

    The kernel is compiled once per cache key (lazily here, or ahead of
    time by warmup_fused_swiglu_compilation()) and reused from
    ``_fused_kernel_cache`` together with its workspace.

    Args:
        mat_a, mat_b, scale_a, scale_b, expert_offsets: GEMM operands; see
            the inline shape comments in the signature.
        global_scale_a, global_scale_b: optional per-expert scales; passed
            to the kernel as None when omitted.
        swiglu_limit: forwarded to the kernel and part of the cache key.
        mma_tiler_mn: base GEMM tile shape. M is doubled when the 2-CTA
            path is selected.
        cluster_shape_mn: base cluster shape. M is doubled when the 2-CTA
            path is selected.
        out: optional BF16 output buffer. It is zeroed in place; when None
            a new (tokens_sum, N) zero tensor is allocated on mat_a.device.

    Returns:
        The output tensor (``out`` if it was supplied).
    """
    import cuda.bindings.driver as cuda
    from dsv4.kernels.gemm.fused_swiglu import FusedSwiGLUScaledGroupedGemmKernel

    num_experts = mat_b.shape[0]
    n_dim = mat_b.shape[2]
    tokens_sum = mat_a.shape[0]
    if out is None:
        out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=mat_a.device)
    else:
        out.zero_()

    # NVFP4-3: use 2-CTA UMMA for M>=256 (1.7-1.9x throughput at prefill).
    # At decode (M<256), 1-CTA is correct (2-CTA wastes hardware).
    use_2cta = tokens_sum >= 256 and cluster_shape_mn[0] % 2 == 0
    if use_2cta:
        effective_mma_tiler_mn = (mma_tiler_mn[0] * 2, mma_tiler_mn[1])
        effective_cluster_shape_mn = (cluster_shape_mn[0] * 2, cluster_shape_mn[1])
    else:
        effective_mma_tiler_mn = mma_tiler_mn
        effective_cluster_shape_mn = cluster_shape_mn

    cache_key = ('fused', num_experts, str(mat_a.device), effective_mma_tiler_mn, effective_cluster_shape_mn,
                 mat_a.shape[1], mat_b.shape[2], swiglu_limit, use_2cta)

    def to_cute(t):
        # from_dlpack checks torch.cuda.current_device() against the tensor
        # device. Inside CUDA graph capture on non-default GPUs these may
        # not match, so current_device is temporarily patched to report the
        # tensor's device index. try/finally guarantees the global is
        # restored even if from_dlpack raises.
        original_current_device = torch.cuda.current_device
        if t.is_cuda and t.device.index != original_current_device():
            torch.cuda.current_device = lambda: t.device.index
        try:
            ct = cutlass_torch.from_dlpack(t)
        finally:
            torch.cuda.current_device = original_current_device
        return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))

    # CuTe wrappers are re-created from the current tensors on every call
    # (metadata only), so no stale tensor references are kept in the cache.
    a_c = to_cute(mat_a)
    b_c = to_cute(mat_b)
    sfa_c = to_cute(scale_a)
    sfb_c = to_cute(scale_b)
    c_c = to_cute(out)
    offs_c = to_cute(expert_offsets)
    gsa_c = to_cute(global_scale_a) if global_scale_a is not None else None
    gsb_c = to_cute(global_scale_b) if global_scale_b is not None else None
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    entry = _fused_kernel_cache.get(cache_key)
    if entry is None:
        # Lazy compilation
        kernel = FusedSwiGLUScaledGroupedGemmKernel(
            scenario="2Dx3D",
            sf_vec_size=SF_VEC_SIZE,
            accumulate_on_output=False,
            separate_tensormap_init=True,
            consistent_token_padding=False,
            mma_tiler_mnk=(*effective_mma_tiler_mn, 256),
            cluster_shape_mnk=(*effective_cluster_shape_mn, 1),
            use_2cta_instrs=use_2cta,
            fused_swiglu=True,
            swiglu_limit=swiglu_limit,
        )
        workspace_size = kernel.get_workspace_size(num_experts)
        workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=mat_a.device)
        new_ws_c = to_cute(workspace)
        # Occupancy must be queried for the cluster shape the kernel is
        # actually built with (the effective, possibly doubled, shape).
        cluster_size = effective_cluster_shape_mn[0] * effective_cluster_shape_mn[1]
        max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size)
        compiled = cute.compile(
            kernel, a_c, b_c, sfa_c, sfb_c, c_c, offs_c, new_ws_c,
            max_active_clusters, stream,
            global_scale_a=gsa_c, global_scale_b=gsb_c,
        )
        # First launch right after compilation, mirroring the warmup run
        # done in warmup_fused_swiglu_compilation().
        compiled(
            a_c, b_c, sfa_c, sfb_c, c_c, offs_c, new_ws_c, stream,
            global_scale_a=gsa_c, global_scale_b=gsb_c,
        )
        entry = {
            'compiled': compiled,
            'workspace': workspace,
            'workspace_size': workspace_size,
        }
        _fused_kernel_cache[cache_key] = entry

    # --- Invoke cached kernel ---
    ws_c = to_cute(entry['workspace'])
    entry['compiled'](
        a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, stream,
        global_scale_a=gsa_c, global_scale_b=gsb_c,
    )
    return out