Files
nvfp4-megamoe-kernel/cutedsl/bridge.py
2026-05-20 07:15:01 +00:00

817 lines
31 KiB
Python

"""
Bridge layer for the CuTeDSL NVFP4 MoE kernel.
Handles tensor layout conversion from our pipeline's format to what
the ScaledGroupedGemmKernel expects:
- BF16 → NVFP4 quantization (float4_e2m1fn_x2)
- Scale factor assembly (padding + swizzle)
- B tensor K-major stride conversion
- Expert offset computation
"""
import math
import torch
import cutlass
import cutlass.cute as cute
import cutlass.torch as cutlass_torch
import cutlass.utils as utils
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
ScaledGroupedGemmKernel,
pad_and_swizzle_single,
assemble_raw_scales_2d3d_2d_side,
assemble_raw_scales_2d3d_3d_side,
cat_byte_reinterpretable_tensors,
stack_byte_reinterpretable_tensors,
)
# ── Constants ──────────────────────────────────────────────────────────
# E2M1 (FP4) magnitudes in code order: a nibble's low three bits index this
# list and bit 3 (value 8) carries the sign, matching the encoders below.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
# Cache of compiled kernels + pre-allocated workspace, keyed by the tuple
# (num_experts, str(device), mma_tiler_mn, cluster_shape_mn, K_packed, N_packed)
# built identically in warmup_compilation and run_nvfp4_grouped_gemm.
# Each entry: {'compiled': callable, 'workspace': Tensor, 'workspace_size': int}
#
# Key design decisions (Bug #1 fix):
# - cute.compile does NOT corrupt GPU memory (verified 2026-05-20 on B200).
#   The original _needs_token_refill hack was a misdiagnosis. The real bug
#   was elsewhere (likely OOB write or weight loading).
# - Workspace is pre-allocated per cache entry during warmup_compilation()
#   and reused on subsequent calls. No torch.full() in the hot path.
# - CuTe tensor wrappers (from_dlpack + mark_layout_dynamic) are cheap
#   metadata wrappers. We re-create them per call from real tensors.
#   Caching them would hold stale references to tensors that get freed.
_compiled_kernel_cache = {}
# Cached LUT for E2M1 quantization (created once per device, cudagraph-safe)
_NVFP4_STEP_LUT_CACHE = {}
def _get_step_to_idx_lut(device):
"""Get or create the E2M1 step-to-index LUT for the given device.
Cached per device to avoid CPU->CUDA copies during cudagraph capture.
Must be pre-populated during warmup (before torch.compile/cudagraph capture)
so the lock is never entered on the compiled path.
"""
# Fast path: already cached — no lock needed (torch.compile-safe)
if device in _NVFP4_STEP_LUT_CACHE:
return _NVFP4_STEP_LUT_CACHE[device]
# Slow path: first call, create the LUT
lut = torch.as_tensor(
[0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 6, 7, 7],
dtype=torch.int8, device=device,
)
_NVFP4_STEP_LUT_CACHE[device] = lut
return lut
SF_VEC_SIZE = 16  # NVFP4 block size: consecutive elements sharing one FP8 block scale
def ceil_div(a, b):
    """Return ``a / b`` rounded up to an integer (intended for positive ``b``)."""
    numerator = a + b - 1
    return numerator // b
def round_up(a, b):
    """Round ``a`` up to the next multiple of ``b`` (intended for positive ``b``)."""
    quotient = (a + b - 1) // b  # ceiling division (same formula as ceil_div)
    return quotient * b
# ── Quantization ──────────────────────────────────────────────────────
def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE):
    """Quantize a BF16 tensor to NVFP4 along its last dimension.

    A per-tensor ``global_scale`` maps the tensor's amax to 6 * 448 (largest
    E2M1 magnitude times largest FP8 e4m3fn value). Each run of
    ``block_size`` elements along the last dimension then gets an FP8 block
    scale (block amax / 6). Every element is rounded to the nearest E2M1
    value (ties to even code) and packed two per byte: the even-indexed
    element goes in the low nibble, the odd-indexed one in the high nibble.
    Values dequantize as ``fp4 * block_scale * global_scale``.

    Args:
        x_bf16: (..., D) BF16 tensor.
        block_size: elements per block scale (default SF_VEC_SIZE). Must be
            even so that blocks pack into whole bytes.

    Returns:
        x_fp4: (..., D_pad // 2) float4_e2m1fn_x2 — native PyTorch FP4.
            D_pad is D rounded up to a multiple of ``block_size``
            (D_pad == D when D is already a multiple). Padding encodes as 0.
        x_sf: (..., D_pad // block_size) float8_e4m3fn — block scales.
        global_scale: 0-dim float32 tensor.
    """
    x_f32 = x_bf16.float()
    amax = x_f32.abs().max().clamp(min=1e-8).float()
    global_scale = amax / (6.0 * 448.0)
    x_norm = x_f32 / global_scale

    last_dim = x_norm.shape[-1]
    n_blocks = ceil_div(last_dim, block_size)
    padded_dim = n_blocks * block_size
    if padded_dim != last_dim:
        x_norm = torch.nn.functional.pad(x_norm, (0, padded_dim - last_dim))
    x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size)
    block_amax = x_reshaped.abs().amax(dim=-1)

    # Detect zero blocks and underflow blocks (amax > 0 but too small for FP8).
    # Smallest positive FP8 e4m3fn is 2^-9 ≈ 1.95e-3. If amax/6 < this,
    # the block scale underflows to 0, and dividing x by the clamped 1e-8
    # inflates values into nonzero FP4 buckets — producing wrong results.
    zero_block = block_amax < (6.0 * 2.0 ** -9)  # < ~0.0117
    # Zero out x for zero/underflow blocks so their nibbles encode as 0.
    x_reshaped = torch.where(zero_block.unsqueeze(-1),
                             torch.zeros_like(x_reshaped), x_reshaped)
    block_amax = block_amax.clamp(min=1e-8)
    block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)
    # Force zero/underflow blocks: FP8 scale = 0 (exact zero).
    block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale)

    x_scaled = x_reshaped / block_scale.float().unsqueeze(-1).clamp(min=1e-8)

    # Nearest E2M1 code index = number of midpoints between consecutive E2M1
    # magnitudes (0, .5, 1, 1.5, 2, 3, 4, 6) that |x| lies above. Rounding
    # to half-steps first and then mapping would round twice (e.g. 2.6 ->
    # 5 half-steps -> 2.0 although 3.0 is nearer). The inclusive flag sends
    # an exact tie to the even code. Magnitudes above 6 saturate to code 7;
    # NaN compares False everywhere and encodes as 0.
    # Pure elementwise ops on x_scaled — no host tensors, cudagraph-safe.
    magnitude = x_scaled.abs()
    indices = torch.zeros_like(magnitude, dtype=torch.uint8)
    e2m1_midpoints = ((0.25, False), (0.75, True), (1.25, False), (1.75, True),
                      (2.5, False), (3.5, True), (5.0, False))
    for midpoint, tie_goes_up in e2m1_midpoints:
        crossed = magnitude >= midpoint if tie_goes_up else magnitude > midpoint
        indices += crossed.to(torch.uint8)
    # Bit 3 of the nibble is the sign.
    nibbles = torch.where(x_scaled < 0, indices + 8, indices)

    even = nibbles[..., ::2]
    odd = nibbles[..., 1::2]
    packed = (odd << 4) | even

    packed_shape = list(x_bf16.shape)
    packed_shape[-1] = padded_dim // 2
    x_fp4 = packed.view(torch.float4_e2m1fn_x2).reshape(packed_shape)
    sf_shape = list(x_bf16.shape[:-1]) + [n_blocks]
    block_scale = block_scale.reshape(sf_shape)
    return x_fp4, block_scale, global_scale
def quantize_activation_nvfp4(x_bf16, global_scale, block_size=SF_VEC_SIZE):
    """Quantize a BF16 activation tensor to NVFP4 (cudagraph-safe).

    Unlike quantize_to_nvfp4(), this takes a pre-computed ``global_scale``
    instead of deriving one from the tensor's own amax. Everything is
    elementwise/reduction work on device tensors: no .item(), no host-built
    tensors, no CPU-GPU sync.

    Each run of ``block_size`` elements along the last dimension gets an FP8
    block scale (block amax / 6). Elements are rounded to the nearest E2M1
    value (ties to even code) and packed two per byte, with the even-indexed
    element in the low nibble.

    Args:
        x_bf16: (..., D) BF16 tensor.
        global_scale: float32 scalar (pre-computed, NOT from .max()).
        block_size: NVFP4 block size (must be even).

    Returns:
        x_fp4: (..., D_pad // 2) float4_e2m1fn_x2, where D_pad is D rounded
            up to a multiple of ``block_size`` (D_pad == D when already a
            multiple). Padding encodes as 0.
        x_sf: (..., D_pad // block_size) float8_e4m3fn block scales.
    """
    x_f32 = x_bf16.float()
    x_norm = x_f32 / global_scale

    last_dim = x_norm.shape[-1]
    n_blocks = ceil_div(last_dim, block_size)
    padded_dim = n_blocks * block_size
    if padded_dim != last_dim:
        x_norm = torch.nn.functional.pad(x_norm, (0, padded_dim - last_dim))
    x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size)
    block_amax = x_reshaped.abs().amax(dim=-1)

    # Detect zero blocks and underflow blocks: if amax/6 is below the smallest
    # positive FP8 e4m3fn value (2^-9) the block scale would underflow to 0.
    zero_block = block_amax < (6.0 * 2.0 ** -9)
    x_reshaped = torch.where(zero_block.unsqueeze(-1),
                             torch.zeros_like(x_reshaped), x_reshaped)
    block_amax = block_amax.clamp(min=1e-8)
    block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)
    block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale)

    x_scaled = x_reshaped / block_scale.float().unsqueeze(-1).clamp(min=1e-8)

    # Nearest E2M1 code index = number of midpoints between consecutive E2M1
    # magnitudes that |x| lies above (single rounding; rounding to half-steps
    # first would send e.g. 2.6 to 2.0 instead of 3.0). Inclusive midpoints
    # send exact ties to the even code; values above 6 saturate to code 7.
    magnitude = x_scaled.abs()
    indices = torch.zeros_like(magnitude, dtype=torch.uint8)
    e2m1_midpoints = ((0.25, False), (0.75, True), (1.25, False), (1.75, True),
                      (2.5, False), (3.5, True), (5.0, False))
    for midpoint, tie_goes_up in e2m1_midpoints:
        crossed = magnitude >= midpoint if tie_goes_up else magnitude > midpoint
        indices += crossed.to(torch.uint8)
    # Bit 3 of the nibble is the sign.
    nibbles = torch.where(x_scaled < 0, indices + 8, indices)

    even = nibbles[..., ::2]
    odd = nibbles[..., 1::2]
    packed = (odd << 4) | even

    packed_shape = list(x_bf16.shape)
    packed_shape[-1] = padded_dim // 2
    x_fp4 = packed.view(torch.float4_e2m1fn_x2).reshape(packed_shape)
    sf_shape = list(x_bf16.shape[:-1]) + [n_blocks]
    block_scale = block_scale.reshape(sf_shape)
    return x_fp4, block_scale
def quantize_weight_to_nvfp4(w_bf16, block_size=SF_VEC_SIZE):
    """Quantize a BF16 weight matrix to NVFP4, packing along K.

    The weight is (K, N) where K is the input dim (packed dimension).
    Block scales are computed along K (dim 0): every ``block_size``
    consecutive rows of a column share one FP8 scale. Elements are rounded
    to the nearest E2M1 value (ties to even code). Rows 2i and 2i+1 share a
    byte, with row 2i in the low nibble.

    Args:
        w_bf16: (K, N) BF16 weight matrix.
        block_size: rows per block scale (must be even).

    Returns:
        w_fp4: (K_pad // 2, N) float4_e2m1fn_x2, where K_pad is K rounded up
            to a multiple of ``block_size`` (K_pad == K when already a
            multiple). Padding rows encode as 0.
        w_sf: (K_pad // block_size, N) float8_e4m3fn — block scales along K.
        global_scale: 0-dim float32 tensor; dequantize as
            fp4 * block_scale * global_scale.
    """
    K, N = w_bf16.shape
    w_f32 = w_bf16.float()
    amax = w_f32.abs().max().clamp(min=1e-8).float()
    global_scale = amax / (6.0 * 448.0)
    w_norm = w_f32 / global_scale

    k_blocks = ceil_div(K, block_size)
    padded_k = k_blocks * block_size
    if padded_k != K:
        # Pad spec is (last dim left/right, dim 0 top/bottom): add rows at the end.
        w_norm = torch.nn.functional.pad(w_norm, (0, 0, 0, padded_k - K))
    w_reshaped = w_norm.reshape(k_blocks, block_size, N)
    w_block_amax = w_reshaped.abs().amax(dim=1)

    # Detect zero blocks and underflow blocks: amax/6 below the smallest
    # positive FP8 e4m3fn value (2^-9) would give a zero block scale.
    zero_block = w_block_amax < (6.0 * 2.0 ** -9)
    w_reshaped = torch.where(zero_block.unsqueeze(1),
                             torch.zeros_like(w_reshaped), w_reshaped)
    w_block_amax = w_block_amax.clamp(min=1e-8)
    w_sf = (w_block_amax / 6.0).to(torch.float8_e4m3fn)
    w_sf = torch.where(zero_block, torch.zeros_like(w_sf), w_sf)

    w_scaled = w_reshaped / w_sf.float().unsqueeze(1).clamp(min=1e-8)

    # Nearest E2M1 code index = number of midpoints between consecutive E2M1
    # magnitudes that |w| lies above (single rounding; rounding to half-steps
    # first would send e.g. 2.6 to 2.0 instead of 3.0). Inclusive midpoints
    # send exact ties to the even code; values above 6 saturate to code 7.
    magnitude = w_scaled.abs()
    indices = torch.zeros_like(magnitude, dtype=torch.uint8)
    e2m1_midpoints = ((0.25, False), (0.75, True), (1.25, False), (1.75, True),
                      (2.5, False), (3.5, True), (5.0, False))
    for midpoint, tie_goes_up in e2m1_midpoints:
        crossed = magnitude >= midpoint if tie_goes_up else magnitude > midpoint
        indices += crossed.to(torch.uint8)
    # Bit 3 of the nibble is the sign.
    nibbles = torch.where(w_scaled < 0, indices + 8, indices)

    even = nibbles[:, ::2, :]
    odd = nibbles[:, 1::2, :]
    packed = (odd << 4) | even
    w_fp4 = packed.reshape(padded_k // 2, N).view(torch.float4_e2m1fn_x2)
    return w_fp4, w_sf, global_scale
# ── Weight Interleaving & Scale Factor Assembly ───────────────────────
def interleave_l1_weights(w_ekn, granularity_bf16=8):
    """Interleave the gate and up halves of fused L1 weights along N.

    The fused SwiGLU epilogue requires gate/up pairs to be adjacent in the
    MMA accumulator. The N axis is therefore regrouped from

        [gate_0 .. gate_{N/2-1} | up_0 .. up_{N/2-1}]

    into alternating groups of ``g = granularity_bf16`` columns:

        [gate_0..gate_{g-1}, up_0..up_{g-1}, gate_g..gate_{2g-1}, up_g.., ...]

    FP4 packing runs along K, not N, so each N column is one logical element
    and the group size is used directly.

    Args:
        w_ekn: (E, K_packed, N) weight tensor in K-major layout, with gate
            columns in the first half of N and up columns in the second.
        granularity_bf16: columns per interleave group (default 8); N / 2
            must be divisible by it.

    Returns:
        A new (E, K_packed, N) tensor with gate/up groups interleaved.
    """
    num_experts, k_dim, n_dim = w_ekn.shape
    half = n_dim // 2
    grouped_shape = (num_experts, k_dim, half // granularity_bf16, granularity_bf16)
    gate_groups = w_ekn[..., :half].reshape(grouped_shape)
    up_groups = w_ekn[..., half:].reshape(grouped_shape)
    # The new axis (-2) alternates gate/up within each group; flattening the
    # trailing axes restores the original width N.
    pairs = torch.stack((gate_groups, up_groups), dim=-2)
    return pairs.reshape(num_experts, k_dim, n_dim)
def deinterleave_l1_weights(w_ekn, granularity_bf16=8):
    """Undo interleave_l1_weights: regroup N back into [gate half | up half].

    Used for testing/verification only.

    Args:
        w_ekn: (E, K, N) tensor whose N axis alternates gate and up groups of
            ``granularity_bf16`` columns.
        granularity_bf16: group width used when interleaving (default 8).

    Returns:
        A new (E, K, N) tensor with all gate columns first, then all up
        columns.
    """
    g = granularity_bf16  # N-axis group width: each N column is one element
    num_experts, k_dim, n_dim = w_ekn.shape
    # Axis 3 of this view selects gate (0) or up (1) within each group.
    grouped = w_ekn.reshape(num_experts, k_dim, n_dim // (2 * g), 2, g)
    halves = [grouped.select(3, which).reshape(num_experts, k_dim, n_dim // 2)
              for which in (0, 1)]
    return torch.cat(halves, dim=2)
def assemble_scales_2d_side(raw_scales):
    """Build the activation-side (2D) scale tensor for the 2Dx3D grouped GEMM.

    Thin wrapper over ``assemble_raw_scales_2d3d_2d_side``.

    Args:
        raw_scales: list with one (M_e, K_sf) float8_e4m3fn tensor per expert.

    Returns:
        The assembled and swizzled scale tensor produced by the helper.
    """
    assembled = assemble_raw_scales_2d3d_2d_side(raw_scales)
    return assembled
def assemble_scales_3d_side(raw_scales):
    """Build the weight-side (3D) scale tensor for the 2Dx3D grouped GEMM.

    Args:
        raw_scales: list with one (K_sf, N) float8_e4m3fn tensor per expert.
            Each is transposed to a contiguous (N, K_sf) tensor before
            assembly, because the kernel treats N as the non-K dimension.

    Returns:
        The assembled and swizzled scale tensor produced by
        ``assemble_raw_scales_2d3d_3d_side``.
    """
    per_expert_nk = []
    for sf in raw_scales:
        per_expert_nk.append(sf.T.contiguous())
    return assemble_raw_scales_2d3d_3d_side(per_expert_nk)
# ── Tensor Layout Conversion ──────────────────────────────────────────
def make_b_k_major(b_tensor):
    """Give a 3-D B tensor a K-major (K-contiguous) memory layout.

    For shape (E, K, N), the result has the same shape and values but
    strides (K*N, 1, K): dim 1 (K) has unit stride. A row-major tensor (as
    produced by torch.stack) has strides (K*N, N, 1). The data is
    materialised in (E, N, K) order and viewed back as (E, K, N). It is
    copied only when it is not already laid out that way.

    Args:
        b_tensor: (experts, K_packed, N_packed) float4_e2m1fn_x2, N-major.

    Returns:
        Same shape, K-major strides.
    """
    swap_kn = (0, 2, 1)
    return b_tensor.permute(swap_kn).contiguous().permute(swap_kn)
def compute_expert_offsets(tokens_per_expert, num_experts, device="cuda"):
    """Compute cumulative (inclusive prefix-sum) token offsets for the grouped GEMM.

    Entry ``e`` is the total number of tokens routed to experts ``0..e``.

    Args:
        tokens_per_expert: sliceable sequence of int token counts, one per
            expert. Entries beyond ``num_experts`` are ignored. If there are
            fewer than ``num_experts`` entries, the missing experts count as
            empty and their offsets repeat the running total.
        num_experts: number of offsets to produce.
        device: device of the returned tensor (default "cuda").

    Returns:
        offs: (num_experts,) int32 — cumulative sum.
    """
    from itertools import accumulate

    count = max(num_experts, 0)
    # Single O(E) pass instead of re-summing a growing prefix per expert.
    cumulative = list(accumulate(tokens_per_expert[:count]))
    # Pad short inputs with the final total, matching slice-and-sum semantics.
    total = cumulative[-1] if cumulative else 0
    cumulative.extend([total] * (count - len(cumulative)))
    return torch.tensor(cumulative, dtype=torch.int32, device=device)
# ── Kernel Launch ─────────────────────────────────────────────────────
def warmup_compilation(num_experts, K_packed, N_packed, device,
                       mma_tiler_mn=(128, 128), cluster_shape_mn=(1, 1)):
    """Eagerly JIT-compile the GEMM kernel for a specific shape.

    Call this BEFORE model weights are loaded to ensure cute.compile
    runs exactly once per shape. The compiled kernel is cached and
    reused by run_nvfp4_grouped_gemm on the forward path.

    Uses random non-zero data. Zero-filled FP4/FP8 tensors cause
    cudaErrorIllegalInstruction because the GEMM arithmetic hits
    invalid values (division by zero in scale dequantization, NaN
    propagation). Random data produces valid intermediate results
    that exercise the kernel's full arithmetic path.
    NOTE(review): this root-cause explanation is the original author's and
    is not confirmed here; see the layout note in the body.

    The warmup tensors are freed immediately after compilation. Only the
    workspace is kept, inside the cache entry.
    Memory cost is minimal (~50MB for typical DeepSeek-V4 shapes).

    Args:
        num_experts: number of experts (local, after expert parallelism).
            Used for the cache key and the workspace size. The dummy data
            itself uses a single expert.
        K_packed: K dimension in float4 elements (i.e. K_original // 2).
            Must equal mat_a.shape[1] of the run-path tensors for this
            cache entry to be reused by run_nvfp4_grouped_gemm.
        N_packed: N dimension in float4 elements (i.e. N_original // 2).
            Must equal mat_b.shape[2] of the run-path tensors for this
            cache entry to be reused.
            NOTE(review): interleave_l1_weights states FP4 packing is along
            K, not N. Confirm which value callers actually pass here.
        device: 'cuda:X' or 'cuda'
        mma_tiler_mn: GEMM tiling (default (128,128)). Part of the cache key,
            so it must be hashable (e.g. a tuple, not a list).
        cluster_shape_mn: cluster shape (default (1,1)). Also part of the
            cache key.

    Returns:
        None. The compiled kernel and its workspace are stored in the
        module-level _compiled_kernel_cache.
    """
    # Same tuple layout as the key built in run_nvfp4_grouped_gemm.
    cache_key = (num_experts, str(device), mma_tiler_mn, cluster_shape_mn,
                 K_packed, N_packed)
    if cache_key in _compiled_kernel_cache:
        return  # Already compiled
    # Generate VALID FP4 data by quantizing random BF16 through our pipeline.
    # Random uint8 bytes as FP4 bit patterns produce NaN/Inf when the GEMM
    # dequantizes them (FP4 value * FP8 scale * FP32 global scale), which
    # causes cudaErrorIllegalInstruction in the Blackwell MMA hardware.
    # The ONLY safe approach: generate random BF16, quantize through our
    # quantize_to_nvfp4, producing mathematically consistent FP4 + FP8 scales.
    _warmup_a_bf16 = torch.randn(128, K_packed * 2, dtype=torch.bfloat16, device=device) * 0.1
    mat_a, scale_a, _ = quantize_to_nvfp4(_warmup_a_bf16)
    del _warmup_a_bf16
    # NOTE(review): quantize_to_nvfp4 packs along the LAST dim, so mat_b is
    # (1, 2*K_packed, N_packed) and N-contiguous. run_nvfp4_grouped_gemm
    # documents mat_b as (experts, K_packed, N_packed), K-major (see
    # make_b_k_major). scale_a/scale_b here are also the raw per-block
    # scales, not the padded + swizzled tensors the run path receives. If
    # cute.compile specializes on rank or leading dimension, these dummies
    # may not match run-path tensors. TODO confirm.
    _warmup_b_bf16 = torch.randn(1, K_packed * 2, N_packed * 2, dtype=torch.bfloat16, device=device) * 0.1  # 1 expert: kernel compiles same regardless of count
    mat_b, scale_b, _ = quantize_to_nvfp4(_warmup_b_bf16)
    del _warmup_b_bf16
    out = torch.zeros(128, N_packed, dtype=torch.bfloat16, device=device)
    # All 128 dummy tokens belong to the single dummy expert.
    expert_offsets = torch.full((1,), 128, dtype=torch.int32, device=device)
    # NOTE(review): compiled WITH global-scale tensors. run_nvfp4_grouped_gemm
    # also accepts None for them, but its cache key does not record that.
    # Confirm the compiled kernel tolerates None arguments.
    global_scale_a = torch.ones(1, dtype=torch.float32, device=device)
    global_scale_b = torch.ones(1, dtype=torch.float32, device=device)
    kernel = ScaledGroupedGemmKernel(
        scenario="2Dx3D",
        sf_vec_size=SF_VEC_SIZE,
        accumulate_on_output=False,
        separate_tensormap_init=True,
        consistent_token_padding=False,
        mma_tiler_mnk=(*mma_tiler_mn, 256),
        cluster_shape_mnk=(*cluster_shape_mn, 1),
    )

    # Wrap a torch tensor as a CuTe tensor whose layout is dynamic except for
    # the tensor's own unit-stride (leading) dimension.
    def to_cute(t):
        ct = cutlass_torch.from_dlpack(t)
        return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))

    a_c = to_cute(mat_a)
    b_c = to_cute(mat_b)
    sfa_c = to_cute(scale_a)
    sfb_c = to_cute(scale_b)
    c_c = to_cute(out)
    offs_c = to_cute(expert_offsets)
    # Workspace is sized for the real expert count and filled with 0xFF
    # (255) rather than zeros. Presumably the kernel expects this
    # initialization. TODO confirm.
    workspace_size = kernel.get_workspace_size(num_experts)
    workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=device)
    ws_c = to_cute(workspace)
    gsa_c = to_cute(global_scale_a)
    gsb_c = to_cute(global_scale_b)
    import cuda.bindings.driver as cuda
    cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1]
    max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size)
    # Compile and launch on the caller's current torch CUDA stream.
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    compiled = cute.compile(
        kernel, a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c,
        max_active_clusters, stream,
        global_scale_a=gsa_c, global_scale_b=gsb_c,
    )
    # Warmup run (required by CuTeDSL)
    compiled(
        a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, stream,
        global_scale_a=gsa_c, global_scale_b=gsb_c,
    )
    # Wait for the warmup launch before the dummy inputs are released below.
    torch.cuda.synchronize()
    # Cache compiled kernel + pre-allocated workspace
    # (NOT CuTe wrappers — they hold refs to dummy tensors)
    _compiled_kernel_cache[cache_key] = {
        'compiled': compiled,
        'workspace': workspace,  # pre-allocated, reused
        'workspace_size': workspace_size,
    }
    # Free dummy data tensors (workspace is kept in cache)
    del mat_a, mat_b, scale_a, scale_b, out, expert_offsets
    del global_scale_a, global_scale_b
    del a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, gsa_c, gsb_c
    torch.cuda.empty_cache()
def run_nvfp4_grouped_gemm(
    mat_a,  # (tokens_sum, K_packed) float4_e2m1fn_x2
    mat_b,  # (experts, K_packed, N_packed) float4_e2m1fn_x2, K-major
    scale_a,  # assembled 2D side (padded + swizzled)
    scale_b,  # assembled 3D side (padded + swizzled)
    expert_offsets,  # (experts,) int32 cumulative token offsets
    global_scale_a=None,  # (experts,) float32
    global_scale_b=None,  # (experts,) float32
    mma_tiler_mn=(128, 128),
    cluster_shape_mn=(1, 1),
):
    """Run the CuTeDSL NVFP4 scaled grouped GEMM.

    2Dx3D: A(tokens, K) x B(experts, K, N) -> C(tokens, N)

    Kernel is compiled once (either via warmup_compilation() during init
    or lazily on first call), then cached. Workspace is pre-allocated
    during warmup and reused — no torch.full() in the hot path.
    No torch.cuda.synchronize() or .item() in the forward path.

    The cache key is (num_experts, str(device), mma_tiler_mn,
    cluster_shape_mn, mat_a.shape[1], mat_b.shape[2]), the same tuple layout
    warmup_compilation builds. The tiling arguments must therefore be
    hashable.

    Returns:
        out: (tokens_sum, mat_b.shape[2]) BF16 tensor, zero-initialized
            before the launch. The kernel is enqueued on the current torch
            CUDA stream; this function does not wait for it to finish.
    """
    num_experts = mat_b.shape[0]
    n_dim = mat_b.shape[2]
    tokens_sum = mat_a.shape[0]
    # Zero-initialized. Rows the kernel does not write (if expert_offsets
    # ends before tokens_sum) presumably stay 0. TODO confirm.
    out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=mat_a.device)
    # Must mirror the key layout in warmup_compilation so a warmed-up kernel
    # is found here.
    # NOTE(review): the key does not record whether global_scale_a/b are
    # None, yet the kernel is compiled with whatever the compiling call
    # passed. Confirm that mixing None and tensor callers is safe.
    cache_key = (num_experts, str(mat_a.device), mma_tiler_mn, cluster_shape_mn,
                 mat_a.shape[1], mat_b.shape[2])
    if cache_key not in _compiled_kernel_cache:
        # Lazy compilation — safety net if warmup_compilation wasn't called.
        # This should NOT happen in production (warmup is called during init).
        kernel = ScaledGroupedGemmKernel(
            scenario="2Dx3D",
            sf_vec_size=SF_VEC_SIZE,
            accumulate_on_output=False,
            separate_tensormap_init=True,
            consistent_token_padding=False,
            mma_tiler_mnk=(*mma_tiler_mn, 256),
            cluster_shape_mnk=(*cluster_shape_mn, 1),
        )

        # Wrap a torch tensor as a CuTe tensor with a dynamic layout (except
        # its unit-stride leading dimension).
        def to_cute(t):
            ct = cutlass_torch.from_dlpack(t)
            return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))

        a_c = to_cute(mat_a)
        b_c = to_cute(mat_b)
        sfa_c = to_cute(scale_a)
        sfb_c = to_cute(scale_b)
        c_c = to_cute(out)
        offs_c = to_cute(expert_offsets)
        workspace_size = kernel.get_workspace_size(num_experts)
        workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=mat_a.device)
        ws_c = to_cute(workspace)
        gsa_c = to_cute(global_scale_a) if global_scale_a is not None else None
        gsb_c = to_cute(global_scale_b) if global_scale_b is not None else None
        import cuda.bindings.driver as cuda
        cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1]
        max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size)
        stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
        compiled = cute.compile(
            kernel, a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c,
            max_active_clusters, stream,
            global_scale_a=gsa_c, global_scale_b=gsb_c,
        )
        # NOTE(review): this launch, plus the unconditional launch below,
        # means a cache miss runs the GEMM twice. With
        # accumulate_on_output=False the second launch presumably rewrites
        # `out` with the same values.
        compiled(
            a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, stream,
            global_scale_a=gsa_c, global_scale_b=gsb_c,
        )
        _compiled_kernel_cache[cache_key] = {
            'compiled': compiled,
            'workspace': workspace,
            'workspace_size': workspace_size,
        }
    # --- Invoke cached kernel ---
    entry = _compiled_kernel_cache[cache_key]
    compiled = entry['compiled']
    # NOTE(review): this workspace is shared by every call with this key.
    # Launches on different streams would share it. Confirm callers use one
    # stream per key.
    workspace = entry['workspace']
    # Re-create CuTe wrappers from current tensors each call.
    # This is cheap (metadata only, no GPU work) and avoids stale
    # references to tensors from previous calls that may have been freed.

    def to_cute(t):
        ct = cutlass_torch.from_dlpack(t)
        return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))

    a_c = to_cute(mat_a)
    b_c = to_cute(mat_b)
    sfa_c = to_cute(scale_a)
    sfb_c = to_cute(scale_b)
    c_c = to_cute(out)
    offs_c = to_cute(expert_offsets)
    ws_c = to_cute(workspace)
    gsa_c = to_cute(global_scale_a) if global_scale_a is not None else None
    gsb_c = to_cute(global_scale_b) if global_scale_b is not None else None
    import cuda.bindings.driver as cuda
    # Launch on the caller's current torch CUDA stream (asynchronous).
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    compiled(
        a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, stream,
        global_scale_a=gsa_c, global_scale_b=gsb_c,
    )
    return out
# ── Fused SwiGLU GEMM (Stage 1: SiLU in registers, BF16 output) ──────
# Cache for the fused kernel (separate from the standard GEMM cache). Keys
# start with the string 'fused' and also include swiglu_limit. Each value is
# a dict with 'compiled', 'workspace' and 'workspace_size'.
_fused_kernel_cache = {}
def warmup_fused_swiglu_compilation(num_experts, K_packed, N_packed, device,
                                    swiglu_limit=0.0,
                                    mma_tiler_mn=(128, 128),
                                    cluster_shape_mn=(1, 1)):
    """Eagerly JIT-compile the fused SwiGLU GEMM kernel.

    Must be called during model initialization. See warmup_compilation()
    for the standard GEMM equivalent. This follows the same sequence:
    quantized random dummies -> cute.compile -> one warmup launch -> cache
    the compiled kernel + workspace. The differences are that it builds a
    FusedSwiGLUScaledGroupedGemmKernel with fused_swiglu=True and stores the
    result in _fused_kernel_cache.

    Args:
        num_experts: number of local experts. Part of the cache key and used
            for the workspace size; the dummy data uses a single expert.
        K_packed: K in float4 elements. Must equal mat_a.shape[1] in
            run_fused_swiglu_grouped_gemm for this entry to be reused.
        N_packed: must equal mat_b.shape[2] in
            run_fused_swiglu_grouped_gemm for this entry to be reused.
        device: 'cuda:X' or 'cuda'.
        swiglu_limit: forwarded to the kernel constructor; part of the key.
        mma_tiler_mn: GEMM tiling. Part of the cache key, so it must be
            hashable.
        cluster_shape_mn: cluster shape; part of the cache key.

    Returns:
        None.
    """
    from cutedsl.kernel.moe.fused_swiglu_grouped_mm import FusedSwiGLUScaledGroupedGemmKernel
    # Same tuple layout as the key built in run_fused_swiglu_grouped_gemm.
    cache_key = ('fused', num_experts, str(device), mma_tiler_mn, cluster_shape_mn,
                 K_packed, N_packed, swiglu_limit)
    if cache_key in _fused_kernel_cache:
        return
    # Generate VALID FP4 data by quantizing random BF16 (same as warmup_compilation)
    _warmup_a_bf16 = torch.randn(128, K_packed * 2, dtype=torch.bfloat16, device=device) * 0.1
    mat_a, scale_a, _ = quantize_to_nvfp4(_warmup_a_bf16)
    del _warmup_a_bf16
    # NOTE(review): as in warmup_compilation, this B is (1, 2*K_packed,
    # N_packed) and N-contiguous, and the scales are raw (not padded or
    # swizzled). The run path documents K-major (experts, K_packed,
    # N_packed). Confirm the compiled kernel is layout-agnostic.
    _warmup_b_bf16 = torch.randn(1, K_packed * 2, N_packed * 2, dtype=torch.bfloat16, device=device) * 0.1  # 1 expert: kernel compiles same regardless of count
    mat_b, scale_b, _ = quantize_to_nvfp4(_warmup_b_bf16)
    del _warmup_b_bf16
    # BF16 output (Stage 1: we still write BF16).
    # NOTE(review): an earlier comment said the fused kernel writes only the
    # intermediate (N/2) columns. However, `out` is allocated at the full
    # N_packed width here and in run_fused_swiglu_grouped_gemm, and
    # deinterleave_quantize_nvfp4_cuda expects a (M, 2*intermediate) fused
    # output. Confirm the actual output width.
    out = torch.zeros(128, N_packed, dtype=torch.bfloat16, device=device)
    expert_offsets = torch.full((1,), 128, dtype=torch.int32, device=device)
    global_scale_a = torch.ones(1, dtype=torch.float32, device=device)
    global_scale_b = torch.ones(1, dtype=torch.float32, device=device)
    kernel = FusedSwiGLUScaledGroupedGemmKernel(
        scenario="2Dx3D",
        sf_vec_size=SF_VEC_SIZE,
        accumulate_on_output=False,
        separate_tensormap_init=True,
        consistent_token_padding=False,
        mma_tiler_mnk=(*mma_tiler_mn, 256),
        cluster_shape_mnk=(*cluster_shape_mn, 1),
        fused_swiglu=True,
        swiglu_limit=swiglu_limit,
    )

    # Wrap a torch tensor as a CuTe tensor with a dynamic layout (except its
    # unit-stride leading dimension).
    def to_cute(t):
        ct = cutlass_torch.from_dlpack(t)
        return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))

    a_c = to_cute(mat_a)
    b_c = to_cute(mat_b)
    sfa_c = to_cute(scale_a)
    sfb_c = to_cute(scale_b)
    c_c = to_cute(out)
    offs_c = to_cute(expert_offsets)
    # Sized for the real expert count; filled with 0xFF like the standard path.
    workspace_size = kernel.get_workspace_size(num_experts)
    workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=device)
    ws_c = to_cute(workspace)
    gsa_c = to_cute(global_scale_a)
    gsb_c = to_cute(global_scale_b)
    import cuda.bindings.driver as cuda
    cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1]
    max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size)
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    compiled = cute.compile(
        kernel, a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c,
        max_active_clusters, stream,
        global_scale_a=gsa_c, global_scale_b=gsb_c,
    )
    # Warmup launch, then wait for it before the dummies are released.
    compiled(
        a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, stream,
        global_scale_a=gsa_c, global_scale_b=gsb_c,
    )
    torch.cuda.synchronize()
    # Keep only the compiled kernel and workspace; CuTe wrappers reference
    # the dummy tensors and are dropped.
    _fused_kernel_cache[cache_key] = {
        'compiled': compiled,
        'workspace': workspace,
        'workspace_size': workspace_size,
    }
    del mat_a, mat_b, scale_a, scale_b, out, expert_offsets
    del global_scale_a, global_scale_b
    del a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, gsa_c, gsb_c
    torch.cuda.empty_cache()
def run_fused_swiglu_grouped_gemm(
    mat_a,  # (tokens_sum, K_packed) float4_e2m1fn_x2
    mat_b,  # (experts, K_packed, N_packed) float4_e2m1fn_x2, K-major
    scale_a,  # assembled 2D side (padded + swizzled)
    scale_b,  # assembled 3D side (padded + swizzled)
    expert_offsets,  # (experts,) int32 cumulative token offsets
    global_scale_a=None,  # (experts,) float32
    global_scale_b=None,  # (experts,) float32
    swiglu_limit=0.0,
    mma_tiler_mn=(128, 128),
    cluster_shape_mn=(1, 1),
):
    """Run the fused SwiGLU NVFP4 scaled grouped GEMM.

    Stage 1: SiLU is applied to the full accumulator in registers,
    then written as BF16 to C. Gate/up pairing is not yet implemented.

    Caching mirrors run_nvfp4_grouped_gemm, but uses _fused_kernel_cache
    with key ('fused', num_experts, str(device), mma_tiler_mn,
    cluster_shape_mn, mat_a.shape[1], mat_b.shape[2], swiglu_limit). This is
    the layout warmup_fused_swiglu_compilation builds.

    Returns:
        out: (tokens_sum, mat_b.shape[2]) BF16 tensor, zero-initialized before
            the launch. The kernel is enqueued asynchronously on the current
            torch CUDA stream.
    """
    from cutedsl.kernel.moe.fused_swiglu_grouped_mm import FusedSwiGLUScaledGroupedGemmKernel
    num_experts = mat_b.shape[0]
    n_dim = mat_b.shape[2]
    tokens_sum = mat_a.shape[0]
    out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=mat_a.device)
    # NOTE(review): as in run_nvfp4_grouped_gemm, the key does not record
    # whether the global scales are None.
    cache_key = ('fused', num_experts, str(mat_a.device), mma_tiler_mn, cluster_shape_mn,
                 mat_a.shape[1], mat_b.shape[2], swiglu_limit)
    if cache_key not in _fused_kernel_cache:
        # Lazy compilation
        kernel = FusedSwiGLUScaledGroupedGemmKernel(
            scenario="2Dx3D",
            sf_vec_size=SF_VEC_SIZE,
            accumulate_on_output=False,
            separate_tensormap_init=True,
            consistent_token_padding=False,
            mma_tiler_mnk=(*mma_tiler_mn, 256),
            cluster_shape_mnk=(*cluster_shape_mn, 1),
            fused_swiglu=True,
            swiglu_limit=swiglu_limit,
        )

        # Wrap a torch tensor as a CuTe tensor with a dynamic layout (except
        # its unit-stride leading dimension).
        def to_cute(t):
            ct = cutlass_torch.from_dlpack(t)
            return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))

        a_c = to_cute(mat_a)
        b_c = to_cute(mat_b)
        sfa_c = to_cute(scale_a)
        sfb_c = to_cute(scale_b)
        c_c = to_cute(out)
        offs_c = to_cute(expert_offsets)
        workspace_size = kernel.get_workspace_size(num_experts)
        workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=mat_a.device)
        ws_c = to_cute(workspace)
        gsa_c = to_cute(global_scale_a) if global_scale_a is not None else None
        gsb_c = to_cute(global_scale_b) if global_scale_b is not None else None
        import cuda.bindings.driver as cuda
        cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1]
        max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size)
        stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
        compiled = cute.compile(
            kernel, a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c,
            max_active_clusters, stream,
            global_scale_a=gsa_c, global_scale_b=gsb_c,
        )
        # NOTE(review): together with the launch below, a cache miss runs
        # the fused GEMM twice (presumably the same result, since
        # accumulate_on_output=False).
        compiled(
            a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, stream,
            global_scale_a=gsa_c, global_scale_b=gsb_c,
        )
        _fused_kernel_cache[cache_key] = {
            'compiled': compiled,
            'workspace': workspace,
            'workspace_size': workspace_size,
        }
    entry = _fused_kernel_cache[cache_key]
    compiled = entry['compiled']
    # Shared per-key workspace (see run_nvfp4_grouped_gemm for the caveat).
    workspace = entry['workspace']

    # Fresh CuTe wrappers per call; only metadata, no GPU work.
    def to_cute(t):
        ct = cutlass_torch.from_dlpack(t)
        return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))

    a_c = to_cute(mat_a)
    b_c = to_cute(mat_b)
    sfa_c = to_cute(scale_a)
    sfb_c = to_cute(scale_b)
    c_c = to_cute(out)
    offs_c = to_cute(expert_offsets)
    ws_c = to_cute(workspace)
    gsa_c = to_cute(global_scale_a) if global_scale_a is not None else None
    gsb_c = to_cute(global_scale_b) if global_scale_b is not None else None
    import cuda.bindings.driver as cuda
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    compiled(
        a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, stream,
        global_scale_a=gsa_c, global_scale_b=gsb_c,
    )
    return out
# Process-wide handle for the JIT-built deinterleave+quantize extension
# (None until the first call).
_DEINTERLEAVE_QUANTIZE_EXT = None


def _load_deinterleave_quantize_ext():
    """Build (first call only) and return the deinterleave+quantize extension.

    torch.utils.cpp_extension.load hashes the sources, checks the build
    directory and imports the compiled library every time it is called.
    Memoizing the module here keeps that work out of the per-call path.
    """
    global _DEINTERLEAVE_QUANTIZE_EXT
    if _DEINTERLEAVE_QUANTIZE_EXT is None:
        import os
        from torch.utils.cpp_extension import load

        kernel_dir = os.path.join(os.path.dirname(__file__), "kernels")
        _DEINTERLEAVE_QUANTIZE_EXT = load(
            name="deinterleave_quantize_nvfp4",
            sources=[os.path.join(kernel_dir, "deinterleave_quantize.cu")],
            extra_cuda_cflags=["-gencode=arch=compute_100a,code=sm_100a"],
            verbose=False,
        )
    return _DEINTERLEAVE_QUANTIZE_EXT


def deinterleave_quantize_nvfp4_cuda(fused_bf16, intermediate, global_scale, granularity=8):
    """De-interleave + quantize fused SwiGLU output using a custom CUDA kernel.

    Single kernel launch, no Python loop. 4x faster than the Python path.
    The extension is JIT-built (sm_100a) on the first call and reused for
    the rest of the process.

    Args:
        fused_bf16: (M, 2*intermediate) BF16 — fused L1 output with interleaved gate/up
        intermediate: intermediate dimension (e.g., 3072)
        global_scale: pre-computed global scale for quantization
        granularity: interleave granularity in BF16 columns (default 8)

    Returns:
        x_fp4: (M, intermediate//2) float4_e2m1fn_x2 — quantized SwiGLU
        x_sf: (M, intermediate//16) float8_e4m3fn — block scales
    """
    ext = _load_deinterleave_quantize_ext()
    # The compiled op takes granularity BEFORE global_scale, unlike this
    # function's signature.
    return ext.deinterleave_quantize_nvfp4(fused_bf16, intermediate, granularity, global_scale)