Files
nvfp4-megamoe-kernel/cutedsl/bridge.py
biondizzle ecc7b83334 fix: compile CuTeDSL kernel with actual tensor shapes, not dummy 256x256
The compiled kernel's TMA descriptors are sized based on compilation
shapes. Using dummy 256x256 shapes caused wrong memory access patterns
for the real 3584x6144 data. Now uses actual K_packed and N_packed
from the runtime tensors.
2026-05-16 19:58:13 +00:00

433 lines
16 KiB
Python

"""
Bridge layer for the CuTeDSL NVFP4 MoE kernel.
Handles tensor layout conversion from our pipeline's format to what
the ScaledGroupedGemmKernel expects:
- BF16 → NVFP4 quantization (float4_e2m1fn_x2)
- Scale factor assembly (padding + swizzle)
- B tensor K-major stride conversion
- Expert offset computation
CUDA-graph-compatible: no .item() calls, no torch.cuda.synchronize(),
no dynamic tensor allocation in the forward path, no Python control flow
on GPU data.
"""
import math
import threading
# Cached LUT for E2M1 quantization (created once per device, cudagraph-safe).
# Maps a device to the int8 lookup tensor built by _get_step_to_idx_lut().
_NVFP4_STEP_LUT_CACHE = {}
# Held by _get_step_to_idx_lut() around every cache lookup and insertion.
_NVFP4_STEP_LUT_LOCK = threading.Lock()
def _get_step_to_idx_lut(device):
    """Return the half-step → E2M1 magnitude-code lookup table for ``device``.

    The table has 13 int8 entries; the quantizers index it with
    ``round(2 * |value|)`` clamped to 0..12. It is built on first use and
    then cached per device, so later calls perform no CPU→CUDA copy (which
    matters during cudagraph capture).

    Args:
        device: device the lookup tensor must live on; also the cache key.

    Returns:
        The cached int8 lookup tensor for ``device``.
    """
    with _NVFP4_STEP_LUT_LOCK:
        lut = _NVFP4_STEP_LUT_CACHE.get(device)
        if lut is None:
            lut = torch.as_tensor(
                [0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 6, 7, 7],
                dtype=torch.int8,
                device=device,
            )
            _NVFP4_STEP_LUT_CACHE[device] = lut
        return lut
import torch
import cutlass
import cutlass.cute as cute
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
ScaledGroupedGemmKernel,
pad_and_swizzle_single,
assemble_raw_scales_2d3d_2d_side,
assemble_raw_scales_2d3d_3d_side,
cat_byte_reinterpretable_tensors,
stack_byte_reinterpretable_tensors,
)
# ── Constants ──────────────────────────────────────────────────────────
# Magnitudes considered by the E2M1 quantizers, in ascending order. The list
# position of a magnitude is the code stored in the low three bits of a nibble
# (the quantizers add 8 to that code for negative values).
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
SF_VEC_SIZE = 16 # NVFP4 block size: number of elements sharing one block scale
def ceil_div(a, b):
    """Return ``a`` divided by ``b`` rounded up, using floor division only."""
    numerator = a + b - 1
    return numerator // b
def round_up(a, b):
    """Round ``a`` up to the nearest multiple of ``b``."""
    multiples = ceil_div(a, b)
    return multiples * b
# ── Quantization ──────────────────────────────────────────────────────
def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE):
    """Quantize a BF16 tensor to NVFP4 using a data-derived global scale.

    The global scale is recomputed from the tensor's absolute maximum on
    every call. The original author marks this function as intended for
    offline weight preparation only and recommends
    quantize_activation_nvfp4() (fixed, pre-computed global scale) for the
    forward pass.

    The last dimension is zero-padded up to a multiple of ``block_size`` for
    the block-scale computation; the padded tail is dropped again before
    packing, so the packed output always has ``D // 2`` bytes per row.

    Args:
        x_bf16: (..., D) BF16 tensor. D must be even.
        block_size: number of consecutive elements sharing one block scale.

    Returns:
        x_fp4: (..., D//2) float4_e2m1fn_x2 — two 4-bit codes per byte, the
            even-indexed element in the low nibble.
        x_sf: (..., ceil(D / block_size)) float8_e4m3fn block scales.
        global_scale: 0-dim float32 tensor, ``amax / (6 * 448)``.

    Raises:
        ValueError: if D is odd (two values cannot be packed per byte).
    """
    last_dim = x_bf16.shape[-1]
    if last_dim % 2 != 0:
        raise ValueError(
            f"last dimension must be even to pack two FP4 values per byte, got {last_dim}"
        )
    lead_shape = x_bf16.shape[:-1]
    n_blocks = ceil_div(last_dim, block_size)
    padded_dim = n_blocks * block_size

    x_f32 = x_bf16.float()
    amax = x_f32.abs().max().clamp(min=1e-8).float()
    # 6.0 is the largest E2M1 magnitude and 448.0 the largest float8_e4m3fn
    # value, so after normalization every block scale fits in float8_e4m3fn.
    global_scale = amax / (6.0 * 448.0)
    x_norm = x_f32 / global_scale

    if padded_dim != last_dim:
        x_norm = torch.nn.functional.pad(x_norm, (0, padded_dim - last_dim))
    x_blocks = x_norm.reshape(*lead_shape, n_blocks, block_size)

    # One scale per block, chosen so the block maximum maps to magnitude 6.
    block_amax = x_blocks.abs().amax(dim=-1).clamp(min=1e-8)
    block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)

    # Divide by the float8-rounded scale (what a consumer of x_sf will see);
    # the clamp guards against a scale that rounded to zero.
    x_scaled = x_blocks / block_scale.float().unsqueeze(-1).clamp(min=1e-8)

    # Round to half-steps (0, 0.5, ..., 6.0) and map them to E2M1 codes via
    # the cached per-device LUT.
    # NOTE(review): rounding to half-steps first is not exact nearest-value
    # rounding — e.g. 2.6 becomes half-step 5 and maps to 2.0 although 3.0 is
    # nearer. quantize_weight_to_nvfp4() searches the exact nearest
    # magnitude. Confirm this approximation is intended.
    half_steps = (x_scaled.abs().clamp(max=6.0) * 2.0).round().clamp(0, 12).to(torch.long)
    indices = _get_step_to_idx_lut(x_bf16.device)[half_steps]
    nibbles = torch.where(x_scaled < 0, indices + 8, indices).to(torch.uint8)

    # Drop the padded tail before packing so the result really has D // 2
    # bytes per row (previously the final reshape failed whenever D was not
    # a multiple of block_size).
    nibbles = nibbles.reshape(*lead_shape, padded_dim)[..., :last_dim]
    packed = (nibbles[..., 1::2] << 4) | nibbles[..., ::2]
    x_fp4 = packed.view(torch.float4_e2m1fn_x2)

    return x_fp4, block_scale, global_scale
def quantize_activation_nvfp4(x_bf16, global_scale, block_size=SF_VEC_SIZE):
    """Quantize a BF16 activation tensor to NVFP4 with a fixed global scale.

    Unlike quantize_to_nvfp4(), the global scale is supplied by the caller
    instead of being derived from the data, so the result does not depend on
    a per-call reduction over the whole tensor. The original author intends
    the scale to be computed once during warm-up and cached. Every step is
    a tensor operation; only static shapes are inspected on the Python side.

    The last dimension is zero-padded up to a multiple of ``block_size`` for
    the block-scale computation; the padded tail is dropped again before
    packing, so the packed output always has ``D // 2`` bytes per row.

    Args:
        x_bf16: (..., D) BF16 tensor. D must be even.
        global_scale: pre-computed scale the input is divided by (float32
            scalar per the original documentation).
        block_size: number of consecutive elements sharing one block scale.

    Returns:
        x_fp4: (..., D//2) float4_e2m1fn_x2 — two 4-bit codes per byte, the
            even-indexed element in the low nibble.
        x_sf: (..., ceil(D / block_size)) float8_e4m3fn block scales.

    Raises:
        ValueError: if D is odd (two values cannot be packed per byte).
    """
    last_dim = x_bf16.shape[-1]
    if last_dim % 2 != 0:
        raise ValueError(
            f"last dimension must be even to pack two FP4 values per byte, got {last_dim}"
        )
    lead_shape = x_bf16.shape[:-1]
    n_blocks = ceil_div(last_dim, block_size)
    padded_dim = n_blocks * block_size

    x_norm = x_bf16.float() / global_scale
    if padded_dim != last_dim:
        x_norm = torch.nn.functional.pad(x_norm, (0, padded_dim - last_dim))
    x_blocks = x_norm.reshape(*lead_shape, n_blocks, block_size)

    # One scale per block, chosen so the block maximum maps to magnitude 6.
    block_amax = x_blocks.abs().amax(dim=-1).clamp(min=1e-8)
    block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)

    # Divide by the float8-rounded scale (what a consumer of x_sf will see);
    # the clamp guards against a scale that rounded to zero.
    x_scaled = x_blocks / block_scale.float().unsqueeze(-1).clamp(min=1e-8)

    # Round to half-steps (0, 0.5, ..., 6.0) and map them to E2M1 codes via
    # the cached per-device LUT.
    # NOTE(review): rounding to half-steps first is not exact nearest-value
    # rounding — e.g. 2.6 becomes half-step 5 and maps to 2.0 although 3.0 is
    # nearer. Confirm this approximation is intended.
    half_steps = (x_scaled.abs().clamp(max=6.0) * 2.0).round().clamp(0, 12).to(torch.long)
    indices = _get_step_to_idx_lut(x_bf16.device)[half_steps]
    nibbles = torch.where(x_scaled < 0, indices + 8, indices).to(torch.uint8)

    # Drop the padded tail before packing so the result really has D // 2
    # bytes per row (previously the final reshape failed whenever D was not
    # a multiple of block_size).
    nibbles = nibbles.reshape(*lead_shape, padded_dim)[..., :last_dim]
    packed = (nibbles[..., 1::2] << 4) | nibbles[..., ::2]
    x_fp4 = packed.view(torch.float4_e2m1fn_x2)

    return x_fp4, block_scale
def quantize_weight_to_nvfp4(w_bf16, block_size=SF_VEC_SIZE):
    """Quantize a BF16 weight matrix to NVFP4 with block scales along K.

    The weight is (K, N) where K is the input dimension; consecutive groups
    of ``block_size`` rows share one scale per column, and pairs of
    consecutive rows are packed into one byte. K is zero-padded up to a
    multiple of ``block_size`` for the scale computation and the padded rows
    are dropped again before packing.

    Args:
        w_bf16: (K, N) BF16 weight matrix. K must be even.
        block_size: number of consecutive K elements sharing one block scale.

    Returns:
        w_fp4: (K//2, N) float4_e2m1fn_x2 — K is the packed dim, the
            even-indexed row in the low nibble.
        w_sf: (ceil(K / block_size), N) float8_e4m3fn block scales along K.
        global_scale: 0-dim float32 tensor, ``amax / (6 * 448)``.

    Raises:
        ValueError: if K is odd (two values cannot be packed per byte).
    """
    K, N = w_bf16.shape
    if K % 2 != 0:
        raise ValueError(
            f"K must be even to pack two FP4 values per byte, got {K}"
        )
    k_blocks = ceil_div(K, block_size)
    padded_k = k_blocks * block_size

    w_f32 = w_bf16.float()
    amax = w_f32.abs().max().clamp(min=1e-8).float()
    # 6.0 is the largest E2M1 magnitude and 448.0 the largest float8_e4m3fn
    # value, so after normalization every block scale fits in float8_e4m3fn.
    global_scale = amax / (6.0 * 448.0)
    w_norm = w_f32 / global_scale

    if padded_k != K:
        # F.pad lists the last dimension first: no padding on N, pad the end of K.
        w_norm = torch.nn.functional.pad(w_norm, (0, 0, 0, padded_k - K))
    w_blocks = w_norm.reshape(k_blocks, block_size, N)

    w_block_amax = w_blocks.abs().amax(dim=1).clamp(min=1e-8)
    w_sf = (w_block_amax / 6.0).to(torch.float8_e4m3fn)
    # Divide by the float8-rounded scale; the clamp guards against a scale
    # that rounded to zero.
    w_scaled = w_blocks / w_sf.float().unsqueeze(1).clamp(min=1e-8)

    # Nearest E2M1 magnitude via the midpoints between adjacent magnitudes.
    # With the default right=False a value exactly on a midpoint goes to the
    # lower magnitude, matching the first-minimum rule of the previous
    # argmin-over-distances search, without materializing a tensor eight
    # times the size of the weight.
    magnitudes = torch.tensor(E2M1_MAGNITUDES, dtype=torch.float32, device=w_bf16.device)
    boundaries = (magnitudes[:-1] + magnitudes[1:]) / 2
    indices = torch.bucketize(w_scaled.abs(), boundaries)
    nibbles = torch.where(w_scaled < 0, indices + 8, indices).to(torch.uint8)

    # Drop the padded rows before packing so the result really has K // 2
    # rows (previously the final reshape failed whenever K was not a
    # multiple of block_size).
    nibbles = nibbles.reshape(padded_k, N)[:K]
    packed = (nibbles[1::2] << 4) | nibbles[0::2]
    w_fp4 = packed.view(torch.float4_e2m1fn_x2)

    return w_fp4, w_sf, global_scale
# ── Scale Factor Assembly ─────────────────────────────────────────────
def assemble_scales_2d_side(raw_scales):
    """Assemble the activation (2D-side) scale factors for the 2Dx3D scenario.

    Thin wrapper that forwards to assemble_raw_scales_2d3d_2d_side().

    Args:
        raw_scales: list of (M_e, K_sf) float8_e4m3fn tensors, one per expert.

    Returns:
        The assembled and swizzled scale tensor produced by the helper.
    """
    assembled = assemble_raw_scales_2d3d_2d_side(raw_scales)
    return assembled
def assemble_scales_3d_side(raw_scales):
    """Assemble the weight (3D-side) scale factors for the 2Dx3D scenario.

    Each per-expert scale tensor is transposed and made contiguous before
    being handed to assemble_raw_scales_2d3d_3d_side(), because the kernel
    expects N as the non-K dimension.

    Args:
        raw_scales: list of (K_sf, N) float8_e4m3fn tensors, one per expert.

    Returns:
        The assembled and swizzled scale tensor produced by the helper.
    """
    n_major_scales = []
    for expert_sf in raw_scales:
        # (K_sf, N) -> (N, K_sf), materialized so the helper sees dense rows.
        n_major_scales.append(expert_sf.T.contiguous())
    return assemble_raw_scales_2d3d_3d_side(n_major_scales)
# ── Tensor Layout Conversion ──────────────────────────────────────────
def make_b_k_major(b_tensor):
    """Return ``b_tensor`` with the same shape but K as the contiguous axis.

    A densely stacked (experts, K, N) tensor has strides (K*N, N, 1), i.e. N
    is contiguous. The data is copied into (experts, N, K) order and viewed
    back as (experts, K, N), which yields strides (K*N, 1, K).

    Args:
        b_tensor: (experts, K_packed, N_packed) float4_e2m1fn_x2, N-major.

    Returns:
        Tensor with the same shape and values, K-major strides.
    """
    swap_k_and_n = (0, 2, 1)
    k_contiguous = b_tensor.permute(swap_k_and_n).contiguous()
    return k_contiguous.permute(swap_k_and_n)
def compute_expert_offsets(tokens_per_expert, num_experts, device="cuda"):
    """Compute cumulative (inclusive) token offsets for the grouped GEMM.

    Entry ``e`` is the total number of tokens routed to experts ``0..e``.
    If ``tokens_per_expert`` has fewer than ``num_experts`` entries, the
    missing experts are treated as receiving no tokens (the running total is
    repeated); extra entries are ignored.

    The tensor is built from a Python list, so this performs a host-to-device
    copy when ``device`` is a GPU.

    Args:
        tokens_per_expert: list of int, one per expert.
        num_experts: number of offsets to produce.
        device: device for the returned tensor.

    Returns:
        offs: (num_experts,) int32 — cumulative sum.
    """
    from itertools import accumulate

    count = max(num_experts, 0)
    # Single pass instead of re-summing a growing prefix for every expert.
    running = list(accumulate(tokens_per_expert[:count]))
    if len(running) < count:
        total = running[-1] if running else 0
        running.extend([total] * (count - len(running)))
    return torch.tensor(running, dtype=torch.int32, device=device)
# ── Compiled Kernel Cache ─────────────────────────────────────────────
# Maps a configuration key to the (compiled, kernel, max_active_clusters)
# tuple stored by _get_compiled_kernel(), so each configuration is compiled
# only once per process.
_compiled_kernel_cache = {}
def _to_cute(t):
    """Wrap a torch tensor as a CuTe tensor whose layout is marked dynamic."""
    ct = cutlass_torch.from_dlpack(t)
    return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))


def _get_compiled_kernel(num_experts, device, mma_tiler_mn, cluster_shape_mn,
                         k_packed=None, n_packed=None):
    """Get or compile the CuTeDSL grouped GEMM kernel (cached by configuration).

    The kernel is compiled against dummy tensors whose K/N extents match the
    runtime tensors (per the original author, the TMA descriptors are sized
    from the compilation shapes), so the shapes are part of the cache key.

    On a cache miss this compiles the kernel, runs it once on the dummy
    tensors and calls torch.cuda.synchronize(). Presumably the first call
    for a configuration must therefore happen during warm-up, outside
    CUDA-graph capture. The dummy B operand occupies
    ``num_experts * k_packed * n_packed`` bytes while compiling.

    Previously the function read ``K_packed``/``N_packed`` before assigning
    them and derived them from ``mat_a``/``mat_b``, which are not in scope,
    so every call failed; the dimensions are now passed in by the caller.

    Args:
        num_experts: number of experts (leading extent of the B operand).
        device: device on which the dummy tensors are allocated.
        mma_tiler_mn: (M, N) MMA tile shape; a K extent of 256 is appended.
        cluster_shape_mn: (M, N) cluster shape; a trailing 1 is appended.
        k_packed: size of dim 1 of the runtime A operand (packed K).
        n_packed: size of dim 2 of the runtime B operand.

    Returns:
        Tuple ``(compiled, kernel, max_active_clusters)``.

    Raises:
        ValueError: if ``k_packed`` or ``n_packed`` is not supplied.
    """
    if k_packed is None or n_packed is None:
        raise ValueError(
            "_get_compiled_kernel requires k_packed and n_packed: the kernel "
            "must be compiled against the runtime tensor shapes"
        )

    # Normalize to tuples so list arguments are hashable in the cache key.
    mma_tiler_mn = tuple(mma_tiler_mn)
    cluster_shape_mn = tuple(cluster_shape_mn)
    cache_key = (num_experts, str(device), mma_tiler_mn, cluster_shape_mn, k_packed, n_packed)
    cached = _compiled_kernel_cache.get(cache_key)
    if cached is not None:
        return cached

    kernel = ScaledGroupedGemmKernel(
        scenario="2Dx3D",
        sf_vec_size=SF_VEC_SIZE,
        accumulate_on_output=False,
        separate_tensormap_init=True,
        consistent_token_padding=False,
        mma_tiler_mnk=(*mma_tiler_mn, 256),
        cluster_shape_mnk=(*cluster_shape_mn, 1),
    )

    import cuda.bindings.driver as cuda

    cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1]
    max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size)
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # Dummy operands to compile against. The FP4 tensors are allocated as
    # uint8 and reinterpreted, so their shapes are already in packed units.
    tokens = 1
    dummy_a = torch.zeros(
        tokens, k_packed, dtype=torch.uint8, device=device
    ).view(torch.float4_e2m1fn_x2)
    dummy_b = torch.zeros(
        num_experts, k_packed, n_packed, dtype=torch.uint8, device=device
    ).view(torch.float4_e2m1fn_x2)
    dummy_sfa = torch.zeros(1, 1, dtype=torch.float16, device=device).to(torch.float8_e4m3fn)
    dummy_sfb = torch.zeros(1, 1, dtype=torch.float16, device=device).to(torch.float8_e4m3fn)
    dummy_c = torch.zeros(tokens, n_packed, dtype=torch.bfloat16, device=device)
    dummy_offs = torch.zeros(num_experts, dtype=torch.int32, device=device)
    ws_size = kernel.get_workspace_size(num_experts)
    # NOTE(review): the workspace is filled with 0xFF as in the original —
    # presumably a sentinel the kernel expects; confirm.
    dummy_ws = torch.full((ws_size,), 255, dtype=torch.uint8, device=device)
    dummy_gsa = torch.ones(num_experts, dtype=torch.float32, device=device)
    dummy_gsb = torch.ones(num_experts, dtype=torch.float32, device=device)

    compiled = cute.compile(
        kernel,
        _to_cute(dummy_a), _to_cute(dummy_b),
        _to_cute(dummy_sfa), _to_cute(dummy_sfb),
        _to_cute(dummy_c), _to_cute(dummy_offs),
        _to_cute(dummy_ws),
        max_active_clusters, stream,
        global_scale_a=_to_cute(dummy_gsa),
        global_scale_b=_to_cute(dummy_gsb),
    )

    # Warm up the compiled kernel with the dummy data.
    compiled(
        _to_cute(dummy_a), _to_cute(dummy_b),
        _to_cute(dummy_sfa), _to_cute(dummy_sfb),
        _to_cute(dummy_c), _to_cute(dummy_offs),
        _to_cute(dummy_ws),
        stream,
        global_scale_a=_to_cute(dummy_gsa),
        global_scale_b=_to_cute(dummy_gsb),
    )
    torch.cuda.synchronize()

    # Release the dummy operands before returning.
    del dummy_a, dummy_b, dummy_sfa, dummy_sfb, dummy_c, dummy_offs, dummy_ws, dummy_gsa, dummy_gsb

    entry = (compiled, kernel, max_active_clusters)
    _compiled_kernel_cache[cache_key] = entry
    return entry


# ── Kernel Launch ─────────────────────────────────────────────────────
def run_nvfp4_grouped_gemm(
    mat_a,            # (tokens_sum, K_packed) float4_e2m1fn_x2
    mat_b,            # (experts, K_packed, N_packed) float4_e2m1fn_x2, K-major
    scale_a,          # assembled 2D side (padded + swizzled)
    scale_b,          # assembled 3D side (padded + swizzled)
    expert_offsets,   # (experts,) int32 cumulative token offsets
    global_scale_a=None,  # (experts,) float32
    global_scale_b=None,  # (experts,) float32
    mma_tiler_mn=(128, 128),
    cluster_shape_mn=(1, 1),
):
    """Run the CuTeDSL NVFP4 scaled grouped GEMM.

    2Dx3D: A(tokens, K) x B(experts, K, N) -> C(tokens, N)

    The compiled kernel is looked up by (experts, device, tiler, cluster,
    K_packed, N) and compiled on the first call for a configuration; later
    calls reuse the cached kernel and perform no compilation and no
    synchronization. The output and workspace tensors are allocated on
    every call.

    Args:
        mat_a: (tokens_sum, K_packed) float4_e2m1fn_x2 activations.
        mat_b: (experts, K_packed, N) float4_e2m1fn_x2 weights, K-major.
        scale_a: assembled 2D-side scale factors (padded + swizzled).
        scale_b: assembled 3D-side scale factors (padded + swizzled).
        expert_offsets: (experts,) int32 cumulative token offsets.
        global_scale_a: optional (experts,) float32 global scales for A.
        global_scale_b: optional (experts,) float32 global scales for B.
        mma_tiler_mn: (M, N) MMA tile shape passed to the kernel.
        cluster_shape_mn: (M, N) cluster shape passed to the kernel.

    Returns:
        (tokens_sum, N) bfloat16 output tensor, zero-initialized before the
        kernel is launched on the current CUDA stream.
    """
    num_experts = mat_b.shape[0]
    # N is logical, not packed — float4_e2m1fn_x2 packs along K, not N.
    n_dim = mat_b.shape[2]
    tokens_sum = mat_a.shape[0]
    k_packed = mat_a.shape[1]
    device = mat_a.device

    out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=device)

    # The kernel must be compiled against the real K/N extents, so they are
    # forwarded and take part in the cache key.
    compiled, kernel, _max_active_clusters = _get_compiled_kernel(
        num_experts, device, mma_tiler_mn, cluster_shape_mn,
        k_packed=k_packed, n_packed=n_dim,
    )

    workspace_size = kernel.get_workspace_size(num_experts)
    workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=device)

    # NOTE(review): the kernel is compiled with both global-scale tensors
    # present; passing None here may not match the compiled signature —
    # confirm before relying on the defaults.
    gsa_c = _to_cute(global_scale_a) if global_scale_a is not None else None
    gsb_c = _to_cute(global_scale_b) if global_scale_b is not None else None

    import cuda.bindings.driver as cuda

    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    compiled(
        _to_cute(mat_a), _to_cute(mat_b),
        _to_cute(scale_a), _to_cute(scale_b),
        _to_cute(out), _to_cute(expert_offsets),
        _to_cute(workspace),
        stream,
        global_scale_a=gsa_c, global_scale_b=gsb_c,
    )
    # NOTE: No torch.cuda.synchronize() here — cudagraph capture forbids it.
    # The caller is responsible for any needed synchronization.
    return out