Files
nvfp4-megamoe-kernel/cutedsl/bridge.py

372 lines
13 KiB
Python
Raw Normal View History

"""
Bridge layer for the CuTeDSL NVFP4 MoE kernel.
Handles tensor layout conversion from our pipeline's format to what
the ScaledGroupedGemmKernel expects:
- BF16 → NVFP4 quantization (float4_e2m1fn_x2)
- Scale factor assembly (padding + swizzle)
- B tensor K-major stride conversion
- Expert offset computation
"""
import math
import torch
import cutlass
import cutlass.cute as cute
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
ScaledGroupedGemmKernel,
pad_and_swizzle_single,
assemble_raw_scales_2d3d_2d_side,
assemble_raw_scales_2d3d_3d_side,
cat_byte_reinterpretable_tensors,
stack_byte_reinterpretable_tensors,
)
# ── Constants ──────────────────────────────────────────────────────────
# Non-negative E2M1 magnitudes in code-index order. The quantizers below put
# this index in the low 3 bits of a nibble and set bit 3 (+8) for negatives.
E2M1_MAGNITUDES: list = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
# Compiled grouped-GEMM kernels, keyed by the tuple that
# run_nvfp4_grouped_gemm builds for each launch configuration.
_compiled_kernel_cache: dict = dict()
# Per-device E2M1 half-step -> code-index LUTs (see _get_step_to_idx_lut).
# Filled once so no host->device copy happens during cudagraph capture.
_NVFP4_STEP_LUT_CACHE: dict = dict()
def _get_step_to_idx_lut(device):
    """Return the cached E2M1 half-step -> code-index lookup table for ``device``.

    Entry ``s`` (0..12) corresponds to the magnitude ``s * 0.5`` and holds an
    index into E2M1_MAGNITUDES. Half-steps that are not themselves E2M1
    values (5, 7, 9, 10, 11) are mapped to a neighbouring code.

    The int8 table is built once per device and stored in
    _NVFP4_STEP_LUT_CACHE; later calls are a plain dict lookup. Building the
    table copies a host list to ``device``, so call this during warmup
    (before torch.compile / cudagraph capture) for every device you will use.
    No locking is done.
    """
    lut = _NVFP4_STEP_LUT_CACHE.get(device)
    if lut is None:
        # First request for this device: materialize the table on it.
        lut = torch.as_tensor(
            [0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 6, 7, 7],
            dtype=torch.int8,
            device=device,
        )
        _NVFP4_STEP_LUT_CACHE[device] = lut
    return lut
SF_VEC_SIZE: int = 16  # elements covered by one FP8 block scale (NVFP4 block size)
def ceil_div(a, b):
    """Return ``a / b`` rounded up (for integer ``a`` and positive ``b``)."""
    biased_numerator = a + b - 1
    return biased_numerator // b
def round_up(a, b):
    """Return the smallest multiple of ``b`` that is >= ``a`` (positive ``b``)."""
    return b * ceil_div(a, b)
# ── Quantization ──────────────────────────────────────────────────────
def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE):
    """Quantize a BF16 tensor to NVFP4 with a data-derived global scale.

    The global scale is amax / (6 * 448), so after dividing by it every block
    scale (block_amax / 6) fits in float8_e4m3fn and every block-scaled
    element fits in E2M1 (|x| <= 6). Each element is rounded to the nearest
    E2M1 value, with ties going to the even code.

    Args:
        x_bf16: (..., D) BF16 tensor. Blocks run along the last dimension; if
            D is not a multiple of ``block_size`` the last block is
            zero-padded.
        block_size: elements per FP8 block scale. Must be even so that
            elements pair up into bytes.

    Returns:
        x_fp4: (..., ceil(D/2)) float4_e2m1fn_x2. The low nibble holds the
            even-indexed element and the high nibble the odd-indexed one.
        x_sf: (..., ceil(D/block_size)) float8_e4m3fn block scales.
        global_scale: float32 0-dim tensor, amax / (6 * 448).
    """
    x_f32 = x_bf16.float()
    amax = x_f32.abs().max().clamp(min=1e-8).float()
    global_scale = amax / (6.0 * 448.0)
    x_norm = x_f32 / global_scale

    last_dim = x_norm.shape[-1]
    n_blocks = ceil_div(last_dim, block_size)
    padded_dim = n_blocks * block_size
    if padded_dim != last_dim:
        x_norm = torch.nn.functional.pad(x_norm, (0, padded_dim - last_dim))
    x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size)

    block_amax = x_reshaped.abs().amax(dim=-1).clamp(min=1e-8)
    block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)  # (..., n_blocks)

    # Scale by the FP8-rounded block scale that is actually stored.
    block_sf_expanded = block_scale.float().unsqueeze(-1)
    x_scaled = x_reshaped / block_sf_expanded.clamp(min=1e-8)
    signs = torch.sign(x_scaled)
    abs_scaled = x_scaled.abs().clamp(max=6.0)

    # Nearest E2M1. The grid is non-uniform: spacing 0.5 on [0, 2), 1 on
    # [2, 4) and 2 on [4, 6]. Rounding on a single 0.5 grid would send e.g.
    # 2.7 -> 2.0 and 3.6 -> 3.0, so round on the grid that applies to each
    # value. torch.round is half-to-even, which yields the even E2M1 code at
    # every midpoint.
    nearest = torch.where(
        abs_scaled < 2.0,
        (abs_scaled * 2.0).round() * 0.5,
        torch.where(
            abs_scaled < 4.0,
            abs_scaled.round(),
            (abs_scaled * 0.5).round() * 2.0,
        ),
    )
    # `nearest` lies on the E2M1 grid, so nearest*2 is an exact half-step index.
    half_steps = (nearest * 2.0).clamp(0, 12).to(torch.int8)
    step_to_idx = _get_step_to_idx_lut(x_bf16.device)
    indices = step_to_idx[half_steps.long()]
    nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)

    # Pack pairs: low nibble = even element, high nibble = odd element.
    even = nibbles[..., ::2]
    odd = nibbles[..., 1::2]
    packed = ((odd << 4) | even).reshape(*x_bf16.shape[:-1], padded_dim // 2)
    # Drop bytes that hold only zero padding so the packed width is ceil(D/2).
    packed = packed[..., : ceil_div(last_dim, 2)].contiguous()
    x_fp4 = packed.view(torch.float4_e2m1fn_x2)

    return x_fp4, block_scale, global_scale
def quantize_activation_nvfp4(x_bf16, global_scale, block_size=SF_VEC_SIZE):
    """Quantize a BF16 activation tensor to NVFP4 using a caller-supplied global scale.

    Unlike quantize_to_nvfp4(), the global scale is not derived from the
    input's amax; it is passed in, pre-computed by the caller. The function
    issues only tensor ops (no .item() or synchronize). Each element is
    rounded to the nearest E2M1 value, with ties going to the even code.

    Args:
        x_bf16: (..., D) BF16 tensor. Blocks run along the last dimension; a
            trailing partial block is zero-padded.
        global_scale: float32 scalar (e.g. a 0-dim tensor) that the input is
            divided by before block scaling.
        block_size: elements per FP8 block scale. Must be even so that
            elements pair up into bytes.

    Returns:
        x_fp4: (..., ceil(D/2)) float4_e2m1fn_x2. The low nibble holds the
            even-indexed element and the high nibble the odd-indexed one.
        x_sf: (..., ceil(D/block_size)) float8_e4m3fn block scales. Scales
            are saturated at 448 (the float8_e4m3fn maximum) when
            ``global_scale`` is too small for the data.
    """
    fp8_e4m3_max = 448.0

    x_f32 = x_bf16.float()
    x_norm = x_f32 / global_scale

    last_dim = x_norm.shape[-1]
    n_blocks = ceil_div(last_dim, block_size)
    padded_dim = n_blocks * block_size
    if padded_dim != last_dim:
        x_norm = torch.nn.functional.pad(x_norm, (0, padded_dim - last_dim))
    x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size)

    block_amax = x_reshaped.abs().amax(dim=-1).clamp(min=1e-8)
    # An external global_scale does not guarantee block_amax / 6 <= 448.
    # Saturate explicitly instead of relying on the float8 cast's
    # out-of-range behaviour.
    block_scale = (block_amax / 6.0).clamp(max=fp8_e4m3_max).to(torch.float8_e4m3fn)

    # Scale by the FP8-rounded block scale that is actually stored.
    block_sf_expanded = block_scale.float().unsqueeze(-1)
    x_scaled = x_reshaped / block_sf_expanded.clamp(min=1e-8)
    signs = torch.sign(x_scaled)
    abs_scaled = x_scaled.abs().clamp(max=6.0)

    # Nearest E2M1 on the non-uniform grid: spacing 0.5 on [0, 2), 1 on
    # [2, 4) and 2 on [4, 6]. torch.round is half-to-even, which yields the
    # even E2M1 code at every midpoint.
    nearest = torch.where(
        abs_scaled < 2.0,
        (abs_scaled * 2.0).round() * 0.5,
        torch.where(
            abs_scaled < 4.0,
            abs_scaled.round(),
            (abs_scaled * 0.5).round() * 2.0,
        ),
    )
    # `nearest` lies on the E2M1 grid, so nearest*2 is an exact half-step index.
    half_steps = (nearest * 2.0).clamp(0, 12).to(torch.int8)
    step_to_idx = _get_step_to_idx_lut(x_bf16.device)
    indices = step_to_idx[half_steps.long()]
    nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)

    # Pack pairs: low nibble = even element, high nibble = odd element.
    even = nibbles[..., ::2]
    odd = nibbles[..., 1::2]
    packed = ((odd << 4) | even).reshape(*x_bf16.shape[:-1], padded_dim // 2)
    # Drop bytes that hold only zero padding so the packed width is ceil(D/2).
    packed = packed[..., : ceil_div(last_dim, 2)].contiguous()
    x_fp4 = packed.view(torch.float4_e2m1fn_x2)

    return x_fp4, block_scale
def quantize_weight_to_nvfp4(w_bf16, block_size=SF_VEC_SIZE):
    """Quantize a BF16 (K, N) weight matrix to NVFP4, packing and blocking along K.

    The global scale is amax / (6 * 448), so every block scale fits in
    float8_e4m3fn and every block-scaled element fits in E2M1. Each element
    is rounded to the nearest E2M1 value, with ties going to the even code.

    Args:
        w_bf16: (K, N) BF16 weight matrix. K is the input (reduction)
            dimension. If K is not a multiple of ``block_size`` the last
            K-block is zero-padded.
        block_size: K-elements per FP8 block scale. Must be even so that
            K-elements pair up into bytes.

    Returns:
        w_fp4: (ceil(K/2), N) float4_e2m1fn_x2. K is the packed dimension:
            row r holds K-elements 2r (low nibble) and 2r+1 (high nibble).
        w_sf: (ceil(K/block_size), N) float8_e4m3fn block scales along K.
        global_scale: float32 0-dim tensor, amax / (6 * 448).
    """
    K, N = w_bf16.shape
    w_f32 = w_bf16.float()
    amax = w_f32.abs().max().clamp(min=1e-8).float()
    global_scale = amax / (6.0 * 448.0)
    w_norm = w_f32 / global_scale

    k_blocks = ceil_div(K, block_size)
    padded_k = k_blocks * block_size
    if padded_k != K:
        # Pad rows (dim 0) at the end; columns are untouched.
        w_norm = torch.nn.functional.pad(w_norm, (0, 0, 0, padded_k - K))
    w_reshaped = w_norm.reshape(k_blocks, block_size, N)

    w_block_amax = w_reshaped.abs().amax(dim=1).clamp(min=1e-8)
    w_sf = (w_block_amax / 6.0).to(torch.float8_e4m3fn)  # (k_blocks, N)

    # Scale by the FP8-rounded block scale that is actually stored.
    w_block_sf = w_sf.float().unsqueeze(1)
    w_scaled = w_reshaped / w_block_sf.clamp(min=1e-8)
    signs = torch.sign(w_scaled)
    abs_scaled = w_scaled.abs().clamp(max=6.0)

    # Nearest E2M1 on the non-uniform grid: spacing 0.5 on [0, 2), 1 on
    # [2, 4) and 2 on [4, 6]. torch.round is half-to-even, which yields the
    # even E2M1 code at every midpoint.
    nearest = torch.where(
        abs_scaled < 2.0,
        (abs_scaled * 2.0).round() * 0.5,
        torch.where(
            abs_scaled < 4.0,
            abs_scaled.round(),
            (abs_scaled * 0.5).round() * 2.0,
        ),
    )
    # `nearest` lies on the E2M1 grid, so nearest*2 is an exact half-step index.
    half_steps = (nearest * 2.0).clamp(0, 12).to(torch.int8)
    step_to_idx = _get_step_to_idx_lut(w_bf16.device)
    indices = step_to_idx[half_steps.long()]
    nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)

    # Pack consecutive K-elements: low nibble = even k, high nibble = odd k.
    even = nibbles[:, ::2, :]
    odd = nibbles[:, 1::2, :]
    packed = ((odd << 4) | even).reshape(padded_k // 2, N)
    # Drop rows that hold only zero padding so the packed height is ceil(K/2).
    packed = packed[: ceil_div(K, 2)]
    w_fp4 = packed.view(torch.float4_e2m1fn_x2)

    return w_fp4, w_sf, global_scale
# ── Scale Factor Assembly ─────────────────────────────────────────────
def assemble_scales_2d_side(raw_scales):
    """Build the activation-side (2D) scale tensor for the 2Dx3D grouped GEMM.

    Thin wrapper over ``assemble_raw_scales_2d3d_2d_side``. Padding and
    swizzling are done there.

    Args:
        raw_scales: per-expert list of (M_e, K_sf) float8_e4m3fn block scales.

    Returns:
        Whatever ``assemble_raw_scales_2d3d_2d_side`` produces, which is the
        assembled, swizzled scale tensor.
    """
    assembled = assemble_raw_scales_2d3d_2d_side(raw_scales)
    return assembled
def assemble_scales_3d_side(raw_scales):
    """Build the weight-side (3D) scale tensor for the 2Dx3D grouped GEMM.

    The weight quantizer yields scales as (K_sf, N) per expert, but the
    kernel takes N as the non-K dimension. Each expert's scales are
    therefore transposed to (N, K_sf) and materialized with ``.contiguous()``
    before being handed to ``assemble_raw_scales_2d3d_3d_side`` for padding
    and swizzling.

    Args:
        raw_scales: per-expert list of (K_sf, N) float8_e4m3fn block scales.

    Returns:
        The assembled, swizzled scale tensor.
    """
    n_major_scales = []
    for expert_sf in raw_scales:
        n_major_scales.append(expert_sf.T.contiguous())
    return assemble_raw_scales_2d3d_3d_side(n_major_scales)
# ── Tensor Layout Conversion ──────────────────────────────────────────
def make_b_k_major(b_tensor):
    """Return a copy of a 3-D B tensor whose middle (K) dimension is unit-stride.

    A freshly stacked (E, K, N) tensor has strides (K*N, N, 1), which is
    N-major. The kernel wants K-major strides (K*N, 1, K). The shape and
    values are unchanged; only the memory layout differs.

    Args:
        b_tensor: (experts, K_packed, N_packed) tensor, typically
            float4_e2m1fn_x2.

    Returns:
        Tensor with the same shape and values, strides (K*N, 1, K).
    """
    # View as (E, N, K), materialize so K becomes the innermost axis, then
    # view back as (E, K, N).
    n_then_k = b_tensor.permute(0, 2, 1)
    k_innermost = n_then_k.contiguous()
    return k_innermost.permute(0, 2, 1)
def compute_expert_offsets(tokens_per_expert, num_experts, device="cuda"):
    """Compute inclusive cumulative token offsets for the grouped GEMM.

    Entry ``e`` is ``sum(tokens_per_expert[:e + 1])``. This is computed as a
    single running sum (O(num_experts)) rather than by re-summing a prefix
    for every expert.

    Args:
        tokens_per_expert: sequence of int token counts, one per expert.
            Only the first ``num_experts`` entries are used. If it is shorter
            than ``num_experts``, the remaining offsets repeat the total.
        num_experts: number of offsets to produce.
        device: device of the returned tensor. Note that building it copies
            from host memory.

    Returns:
        offs: (num_experts,) int32 tensor of cumulative sums.
    """
    import itertools

    count = max(num_experts, 0)
    offsets = list(itertools.accumulate(tokens_per_expert[:count]))
    # Experts beyond the provided counts contribute no tokens.
    total = offsets[-1] if offsets else 0
    offsets.extend([total] * (count - len(offsets)))
    return torch.tensor(offsets, dtype=torch.int32, device=device)
# ── Kernel Launch ─────────────────────────────────────────────────────
def run_nvfp4_grouped_gemm(
    mat_a,  # (tokens_sum, K_packed) float4_e2m1fn_x2
    mat_b,  # (experts, K_packed, N_packed) float4_e2m1fn_x2, K-major
    scale_a,  # assembled 2D side (padded + swizzled)
    scale_b,  # assembled 3D side (padded + swizzled)
    expert_offsets,  # (experts,) int32 cumulative token offsets
    global_scale_a=None,  # (experts,) float32
    global_scale_b=None,  # (experts,) float32
    mma_tiler_mn=(128, 128),
    cluster_shape_mn=(1, 1),
):
    """Run the CuTeDSL NVFP4 scaled grouped GEMM (2Dx3D scenario).

    A(tokens, K) x B(experts, K, N) -> C(tokens, N), with C in bfloat16.

    On the first call for a given configuration the kernel is compiled with
    ``cute.compile``, run once, and stored in ``_compiled_kernel_cache``.
    Later calls with the same configuration only invoke the cached compiled
    kernel. The function issues no torch.cuda.synchronize() or .item().
    Compilation must not happen while a CUDA graph is being captured, so warm
    up every configuration before capture.

    Returns:
        out: (mat_a.shape[0], mat_b.shape[2]) bfloat16 tensor.
    """
    import cuda.bindings.driver as cuda

    # Normalize so list arguments still produce a hashable cache key.
    mma_tiler_mn = tuple(mma_tiler_mn)
    cluster_shape_mn = tuple(cluster_shape_mn)

    num_experts = mat_b.shape[0]
    n_dim = mat_b.shape[2]  # output columns
    tokens_sum = mat_a.shape[0]
    out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=mat_a.device)

    kernel = ScaledGroupedGemmKernel(
        scenario="2Dx3D",
        sf_vec_size=SF_VEC_SIZE,
        accumulate_on_output=False,
        separate_tensormap_init=True,
        consistent_token_padding=False,
        mma_tiler_mnk=(*mma_tiler_mn, 256),
        cluster_shape_mnk=(*cluster_shape_mn, 1),
    )

    # Convert to CuTe tensors with dynamic layout.
    def to_cute(t):
        ct = cutlass_torch.from_dlpack(t)
        return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))

    a_c = to_cute(mat_a)
    b_c = to_cute(mat_b)
    sfa_c = to_cute(scale_a)
    sfb_c = to_cute(scale_b)
    c_c = to_cute(out)
    offs_c = to_cute(expert_offsets)
    workspace_size = kernel.get_workspace_size(num_experts)
    workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=mat_a.device)
    ws_c = to_cute(workspace)
    gsa_c = to_cute(global_scale_a) if global_scale_a is not None else None
    gsb_c = to_cute(global_scale_b) if global_scale_b is not None else None
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # cute.compile is traced with gsa_c / gsb_c, each either None or a tensor.
    # A kernel traced with None cannot be assumed valid for a call that passes
    # a tensor (or vice versa), so their presence is part of the key.
    cache_key = (
        num_experts,
        str(mat_a.device),
        mma_tiler_mn,
        cluster_shape_mn,
        mat_a.shape[1],
        mat_b.shape[2],
        global_scale_a is not None,
        global_scale_b is not None,
    )
    compiled = _compiled_kernel_cache.get(cache_key)
    newly_compiled = compiled is None
    if newly_compiled:
        # max_active_clusters is only a compile-time input, so query the
        # hardware only when actually compiling.
        cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1]
        max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size)
        compiled = cute.compile(
            kernel, a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c,
            max_active_clusters, stream,
            global_scale_a=gsa_c, global_scale_b=gsb_c,
        )

    compiled(
        a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c,
        stream,
        global_scale_a=gsa_c, global_scale_b=gsb_c,
    )
    # Cache only after a successful first run, as before.
    if newly_compiled:
        _compiled_kernel_cache[cache_key] = compiled
    return out