The _NVFP4_STEP_LUT_LOCK caused 'Unsupported context manager' errors under torch.compile/CUDA-graph capture. The LUT is now pre-populated during warmup, so the fast path (a cache hit) never takes a lock. Also removed all init/warmup debug prints from the CuTeDSL kernels.
372 lines
13 KiB
Python
372 lines
13 KiB
Python
"""
|
|
Bridge layer for the CuTeDSL NVFP4 MoE kernel.
|
|
|
|
Handles tensor layout conversion from our pipeline's format to what
|
|
the ScaledGroupedGemmKernel expects:
|
|
- BF16 → NVFP4 quantization (float4_e2m1fn_x2)
|
|
- Scale factor assembly (padding + swizzle)
|
|
- B tensor K-major stride conversion
|
|
- Expert offset computation
|
|
"""
|
|
import math
|
|
import torch
|
|
import cutlass
|
|
import cutlass.cute as cute
|
|
import cutlass.torch as cutlass_torch
|
|
import cutlass.utils as utils
|
|
|
|
from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
|
|
ScaledGroupedGemmKernel,
|
|
pad_and_swizzle_single,
|
|
assemble_raw_scales_2d3d_2d_side,
|
|
assemble_raw_scales_2d3d_3d_side,
|
|
cat_byte_reinterpretable_tensors,
|
|
stack_byte_reinterpretable_tensors,
|
|
)
|
|
|
|
# ── Constants ──────────────────────────────────────────────────────────
|
|
|
|
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
|
|
|
|
# Cache compiled kernels by (num_experts, device, K, N)
|
|
_compiled_kernel_cache = {}
|
|
|
|
# Cached LUT for E2M1 quantization (created once per device, cudagraph-safe)
|
|
_NVFP4_STEP_LUT_CACHE = {}
|
|
def _get_step_to_idx_lut(device):
|
|
"""Get or create the E2M1 step-to-index LUT for the given device.
|
|
|
|
Cached per device to avoid CPU->CUDA copies during cudagraph capture.
|
|
Must be pre-populated during warmup (before torch.compile/cudagraph capture)
|
|
so the lock is never entered on the compiled path.
|
|
"""
|
|
# Fast path: already cached — no lock needed (torch.compile-safe)
|
|
if device in _NVFP4_STEP_LUT_CACHE:
|
|
return _NVFP4_STEP_LUT_CACHE[device]
|
|
# Slow path: first call, create the LUT
|
|
lut = torch.as_tensor(
|
|
[0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 6, 7, 7],
|
|
dtype=torch.int8, device=device,
|
|
)
|
|
_NVFP4_STEP_LUT_CACHE[device] = lut
|
|
return lut
|
|
SF_VEC_SIZE = 16 # NVFP4 block size
|
|
|
|
|
|
def ceil_div(a, b):
    """Divide ``a`` by ``b``, rounding the quotient up (for positive ``b``)."""
    numerator = a + b - 1
    return numerator // b


def round_up(a, b):
    """Round ``a`` up to the next multiple of ``b`` (for positive ``b``)."""
    return b * ceil_div(a, b)
|
|
|
|
|
|
# ── Quantization ──────────────────────────────────────────────────────
|
|
|
|
def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE):
    """Quantize a BF16 tensor to NVFP4 with a global scale derived from its amax.

    The global scale is ``amax / (6 * 448)``. After dividing by it, every
    per-block scale ``block_amax / 6`` is at most 448. Each group of
    ``block_size`` elements along the last dim gets its own float8_e4m3fn
    scale. Elements are then rounded to the nearest E2M1 value (ties go to
    the even code) and packed two per byte: element ``2i`` goes in the low
    nibble, ``2i + 1`` in the high nibble.

    Args:
        x_bf16: (..., D) BF16 tensor. D need not be a multiple of
            ``block_size``; the tail block is zero-padded for scaling only.
        block_size: elements per block scale (NVFP4 uses 16). Must be even.

    Returns:
        x_fp4: (..., ceil(D/2)) float4_e2m1fn_x2 (i.e. D//2 for even D)
        x_sf: (..., ceil(D/block_size)) float8_e4m3fn — block scales
        global_scale: 0-dim float32 tensor
    """
    x_f32 = x_bf16.float()
    amax = x_f32.abs().max().clamp(min=1e-8).float()
    global_scale = amax / (6.0 * 448.0)
    x_norm = x_f32 / global_scale

    last_dim = x_norm.shape[-1]
    n_blocks = ceil_div(last_dim, block_size)
    pad_size = n_blocks * block_size - last_dim
    if pad_size:
        x_norm = torch.nn.functional.pad(x_norm, (0, pad_size))

    x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size)
    block_amax = x_reshaped.abs().amax(dim=-1).clamp(min=1e-8)
    block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)

    # Divide by the e4m3-rounded scale actually stored, not the exact one.
    block_sf_expanded = block_scale.float().unsqueeze(-1)
    x_scaled = x_reshaped / block_sf_expanded.clamp(min=1e-8)

    signs = torch.sign(x_scaled)
    abs_scaled = x_scaled.abs().clamp(max=6.0)

    # Nearest E2M1 magnitude. The grid {0, .5, 1, 1.5, 2, 3, 4, 6} has
    # spacing 0.5 below 2, 1 on [2, 4) and 2 on [4, 6]. Rounding straight to
    # half steps is wrong above 2 (2.6 -> 5 half steps -> 2.0, not 3.0), so
    # round at the local spacing first. torch.round breaks ties to even,
    # which lands on the even E2M1 code. Snapped values are exact half steps.
    spacing = torch.where(
        abs_scaled < 2.0, 0.5, torch.where(abs_scaled < 4.0, 1.0, 2.0)
    )
    snapped = (abs_scaled / spacing).round() * spacing
    half_steps = (snapped * 2.0).round().clamp(0, 12).to(torch.int8)
    step_to_idx = _get_step_to_idx_lut(x_bf16.device)
    indices = step_to_idx[half_steps.long()]

    # Sign bit is bit 3 of the nibble.
    nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)
    even = nibbles[..., ::2]
    odd = nibbles[..., 1::2]
    packed = (odd << 4) | even

    # (..., n_blocks, block_size // 2) -> (..., bytes). When the last dim
    # was padded, drop the bytes that hold only padding nibbles.
    packed = packed.reshape(*x_bf16.shape[:-1], n_blocks * (block_size // 2))
    if pad_size:
        packed = packed[..., : ceil_div(last_dim, 2)].contiguous()
    x_fp4 = packed.view(torch.float4_e2m1fn_x2)

    sf_shape = list(x_bf16.shape[:-1]) + [n_blocks]
    block_scale = block_scale.reshape(sf_shape)

    return x_fp4, block_scale, global_scale
|
|
|
|
|
|
def quantize_activation_nvfp4(x_bf16, global_scale, block_size=SF_VEC_SIZE):
    """Quantize a BF16 activation tensor to NVFP4 with a caller-supplied global scale.

    Unlike quantize_to_nvfp4(), the global scale is not derived from this
    tensor's amax; it is passed in pre-computed. Every step below is a
    device-side tensor op with no .item() or other host read-back. The
    function can therefore run under CUDA-graph capture, provided the E2M1
    LUT for the device was created beforehand (see _get_step_to_idx_lut).

    Each ``block_size`` group along the last dim gets a float8_e4m3fn scale
    ``block_amax / 6`` of the globally scaled values. Elements are rounded
    to the nearest E2M1 value (ties go to the even code) and packed two per
    byte: element ``2i`` goes in the low nibble, ``2i + 1`` in the high one.

    Args:
        x_bf16: (..., D) BF16 tensor. D need not be a multiple of
            ``block_size``; the tail block is zero-padded for scaling only.
        global_scale: pre-computed float32 scalar. A 0-dim tensor or a
            Python float both work with the division below.
        block_size: NVFP4 block size (16). Must be even.

    Returns:
        x_fp4: (..., ceil(D/2)) float4_e2m1fn_x2 (i.e. D//2 for even D)
        x_sf: (..., ceil(D/block_size)) float8_e4m3fn
    """
    x_f32 = x_bf16.float()
    x_norm = x_f32 / global_scale

    last_dim = x_norm.shape[-1]
    n_blocks = ceil_div(last_dim, block_size)
    pad_size = n_blocks * block_size - last_dim
    if pad_size:
        x_norm = torch.nn.functional.pad(x_norm, (0, pad_size))

    x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size)
    block_amax = x_reshaped.abs().amax(dim=-1).clamp(min=1e-8)
    block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)

    # Divide by the e4m3-rounded scale actually stored, not the exact one.
    block_sf_expanded = block_scale.float().unsqueeze(-1)
    x_scaled = x_reshaped / block_sf_expanded.clamp(min=1e-8)
    signs = torch.sign(x_scaled)
    abs_scaled = x_scaled.abs().clamp(max=6.0)

    # Nearest E2M1 magnitude. The grid {0, .5, 1, 1.5, 2, 3, 4, 6} has
    # spacing 0.5 below 2, 1 on [2, 4) and 2 on [4, 6]. Rounding straight to
    # half steps is wrong above 2 (2.6 -> 5 half steps -> 2.0, not 3.0), so
    # round at the local spacing first. torch.round breaks ties to even,
    # which lands on the even E2M1 code. Snapped values are exact half steps.
    # The scalar torch.where branches need no host->device copies.
    spacing = torch.where(
        abs_scaled < 2.0, 0.5, torch.where(abs_scaled < 4.0, 1.0, 2.0)
    )
    snapped = (abs_scaled / spacing).round() * spacing
    half_steps = (snapped * 2.0).round().clamp(0, 12).to(torch.int8)
    step_to_idx = _get_step_to_idx_lut(x_bf16.device)
    indices = step_to_idx[half_steps.long()]

    # Sign bit is bit 3 of the nibble.
    nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)
    even = nibbles[..., ::2]
    odd = nibbles[..., 1::2]
    packed = (odd << 4) | even

    # (..., n_blocks, block_size // 2) -> (..., bytes). When the last dim
    # was padded, drop the bytes that hold only padding nibbles.
    packed = packed.reshape(*x_bf16.shape[:-1], n_blocks * (block_size // 2))
    if pad_size:
        packed = packed[..., : ceil_div(last_dim, 2)].contiguous()
    x_fp4 = packed.view(torch.float4_e2m1fn_x2)

    sf_shape = list(x_bf16.shape[:-1]) + [n_blocks]
    block_scale = block_scale.reshape(sf_shape)

    return x_fp4, block_scale
|
|
|
|
|
|
def quantize_weight_to_nvfp4(w_bf16, block_size=SF_VEC_SIZE):
    """Quantize a BF16 weight matrix to NVFP4, blocking and packing along K.

    The weight is (K, N), where K is the input dim and the packed dim. The
    global scale is ``amax / (6 * 448)``. Block scales (float8_e4m3fn,
    ``block_amax / 6``) are computed over ``block_size`` consecutive rows,
    i.e. along K (dim 0). Elements are rounded to the nearest E2M1 value
    (ties go to the even code). Rows ``2i`` and ``2i + 1`` of the same
    column are packed into one byte (low and high nibble respectively).

    Args:
        w_bf16: (K, N) BF16 weight matrix. K need not be a multiple of
            ``block_size``; the tail block is zero-padded for scaling only.
        block_size: NVFP4 block size (16). Must be even.

    Returns:
        w_fp4: (ceil(K/2), N) float4_e2m1fn_x2 (i.e. K//2 for even K)
        w_sf: (ceil(K/block_size), N) float8_e4m3fn — block scales along K
        global_scale: 0-dim float32 tensor
    """
    K, N = w_bf16.shape
    w_f32 = w_bf16.float()
    amax = w_f32.abs().max().clamp(min=1e-8).float()
    global_scale = amax / (6.0 * 448.0)
    w_norm = w_f32 / global_scale

    k_blocks = ceil_div(K, block_size)
    pad_size = k_blocks * block_size - K
    if pad_size:
        # Pad dim 0 (rows / K) with zeros up to a whole number of blocks.
        w_norm = torch.nn.functional.pad(w_norm, (0, 0, 0, pad_size))

    w_reshaped = w_norm.reshape(k_blocks, block_size, N)
    w_block_amax = w_reshaped.abs().amax(dim=1).clamp(min=1e-8)
    w_sf = (w_block_amax / 6.0).to(torch.float8_e4m3fn)

    # Divide by the e4m3-rounded scale actually stored, not the exact one.
    w_block_sf = w_sf.float().unsqueeze(1)
    w_scaled = w_reshaped / w_block_sf.clamp(min=1e-8)

    signs = torch.sign(w_scaled)
    abs_scaled = w_scaled.abs().clamp(max=6.0)

    # Nearest E2M1 magnitude. The grid {0, .5, 1, 1.5, 2, 3, 4, 6} has
    # spacing 0.5 below 2, 1 on [2, 4) and 2 on [4, 6]. Rounding straight to
    # half steps is wrong above 2 (2.6 -> 5 half steps -> 2.0, not 3.0), so
    # round at the local spacing first. torch.round breaks ties to even,
    # which lands on the even E2M1 code. Snapped values are exact half steps.
    spacing = torch.where(
        abs_scaled < 2.0, 0.5, torch.where(abs_scaled < 4.0, 1.0, 2.0)
    )
    snapped = (abs_scaled / spacing).round() * spacing
    half_steps = (snapped * 2.0).round().clamp(0, 12).to(torch.int8)
    step_to_idx = _get_step_to_idx_lut(w_bf16.device)
    indices = step_to_idx[half_steps.long()]
    # Sign bit is bit 3 of the nibble.
    nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8)

    even = nibbles[:, ::2, :]
    odd = nibbles[:, 1::2, :]
    packed = (odd << 4) | even

    # (k_blocks, block_size // 2, N) -> (rows, N). When K was padded, drop
    # the rows that hold only padding nibbles.
    packed = packed.reshape(k_blocks * (block_size // 2), N)
    if pad_size:
        packed = packed[: ceil_div(K, 2)].contiguous()
    w_fp4 = packed.view(torch.float4_e2m1fn_x2)
    return w_fp4, w_sf, global_scale
|
|
|
|
|
|
# ── Scale Factor Assembly ─────────────────────────────────────────────
|
|
|
|
def assemble_scales_2d_side(raw_scales):
    """Build the activation-side (2D) scale tensor for the 2Dx3D grouped GEMM.

    This is a thin pass-through to ``assemble_raw_scales_2d3d_2d_side``,
    which assembles, pads and swizzles the per-expert scales.

    Args:
        raw_scales: per-expert list of (M_e, K_sf) float8_e4m3fn tensors.

    Returns:
        The assembled, swizzled scale tensor produced by the helper.
    """
    assembled = assemble_raw_scales_2d3d_2d_side(raw_scales)
    return assembled
|
|
|
|
|
|
def assemble_scales_3d_side(raw_scales):
    """Build the weight-side (3D) scale tensor for the 2Dx3D grouped GEMM.

    Each per-expert scale arrives as (K_sf, N). The kernel expects N as the
    non-K dimension, so each one is transposed into a contiguous (N, K_sf)
    tensor. The list is then handed to ``assemble_raw_scales_2d3d_3d_side``
    for assembly and swizzling.

    Args:
        raw_scales: per-expert list of (K_sf, N) float8_e4m3fn tensors.

    Returns:
        The assembled, swizzled scale tensor produced by the helper.
    """
    nk_scales = [
        expert_scale.T.contiguous()
        for expert_scale in raw_scales
    ]
    return assemble_raw_scales_2d3d_3d_side(nk_scales)
|
|
|
|
|
|
# ── Tensor Layout Conversion ──────────────────────────────────────────
|
|
|
|
def make_b_k_major(b_tensor):
    """Re-lay out a 3-D B tensor so that dim 1 (K) is the contiguous dim.

    Shape and values are unchanged; only the memory layout differs. For an
    (E, K, N) input the result has strides (K*N, 1, K). By contrast,
    torch.stack of (K, N) matrices gives strides (K*N, N, 1), i.e. N
    contiguous.

    Args:
        b_tensor: (experts, K_packed, N_packed) tensor, expected to be
            float4_e2m1fn_x2 in N-major layout.

    Returns:
        Same shape and values with K-major strides. This may share storage
        with the input if it was already K-major.
    """
    swap_k_n = (0, 2, 1)
    n_by_k = b_tensor.permute(swap_k_n).contiguous()
    return n_by_k.permute(swap_k_n)
|
|
|
|
|
|
def compute_expert_offsets(tokens_per_expert, num_experts, device="cuda"):
    """Compute inclusive cumulative token offsets for the grouped GEMM.

    Entry ``e`` is the total token count of experts ``0..e``.

    Args:
        tokens_per_expert: sequence of int token counts, one per expert.
        num_experts: number of offsets to produce. Counts beyond
            ``num_experts`` are ignored. If fewer counts are given, the
            remaining offsets repeat the final running total.
        device: device for the returned tensor.

    Returns:
        offs: (num_experts,) int32 tensor of inclusive prefix sums.

    Note:
        The tensor is built from host data (a host->device copy on CUDA),
        so call this outside CUDA-graph capture.
    """
    from itertools import accumulate

    # One linear pass instead of re-summing every prefix (was O(n^2)).
    offsets = list(accumulate(tokens_per_expert[: max(num_experts, 0)]))
    # Experts past the end of tokens_per_expert contribute nothing, so their
    # offsets equal the grand total (same as summing an over-long slice).
    total = offsets[-1] if offsets else 0
    offsets.extend([total] * (num_experts - len(offsets)))
    return torch.tensor(offsets, dtype=torch.int32, device=device)
|
|
|
|
|
|
# ── Kernel Launch ─────────────────────────────────────────────────────
|
|
|
|
def run_nvfp4_grouped_gemm(
    mat_a,  # (tokens_sum, K_packed) float4_e2m1fn_x2
    mat_b,  # (experts, K_packed, N_packed) float4_e2m1fn_x2, K-major
    scale_a,  # assembled 2D side (padded + swizzled)
    scale_b,  # assembled 3D side (padded + swizzled)
    expert_offsets,  # (experts,) int32 cumulative token offsets
    global_scale_a=None,  # (experts,) float32
    global_scale_b=None,  # (experts,) float32
    mma_tiler_mn=(128, 128),
    cluster_shape_mn=(1, 1),
):
    """Run the CuTeDSL NVFP4 scaled grouped GEMM.

    2Dx3D: A(tokens, K) x B(experts, K, N) -> C(tokens, N)

    The kernel is JIT-compiled with ``cute.compile`` the first time a given
    configuration (``cache_key`` below) is seen. The compiled object is
    cached in ``_compiled_kernel_cache`` and reused afterwards. Every call
    launches it exactly once on the current torch CUDA stream. No
    torch.cuda.synchronize() or .item() is issued on this path.

    Args:
        mat_a, mat_b, scale_a, scale_b, expert_offsets, global_scale_a,
            global_scale_b: see the comments on the signature.
        mma_tiler_mn: MMA tile (M, N); the K tile is fixed at 256.
        cluster_shape_mn: cluster shape (M, N); cluster K is 1.

    Returns:
        (tokens_sum, mat_b.shape[2]) bfloat16 tensor, zero-initialised
        before the launch.
    """
    import cuda.bindings.driver as cuda

    device = mat_a.device
    num_experts = mat_b.shape[0]
    n_dim = mat_b.shape[2]  # output width, taken from B's last dim
    tokens_sum = mat_a.shape[0]

    out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=device)

    kernel = ScaledGroupedGemmKernel(
        scenario="2Dx3D",
        sf_vec_size=SF_VEC_SIZE,
        accumulate_on_output=False,
        separate_tensormap_init=True,
        consistent_token_padding=False,
        mma_tiler_mnk=(*mma_tiler_mn, 256),
        cluster_shape_mnk=(*cluster_shape_mn, 1),
    )

    # Convert to CuTe tensors with dynamic layout, keeping the torch
    # tensor's leading dim.
    def to_cute(t):
        ct = cutlass_torch.from_dlpack(t)
        return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t))

    a_c = to_cute(mat_a)
    b_c = to_cute(mat_b)
    sfa_c = to_cute(scale_a)
    sfb_c = to_cute(scale_b)
    c_c = to_cute(out)
    offs_c = to_cute(expert_offsets)

    workspace_size = kernel.get_workspace_size(num_experts)
    workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=device)
    ws_c = to_cute(workspace)

    gsa_c = to_cute(global_scale_a) if global_scale_a is not None else None
    gsb_c = to_cute(global_scale_b) if global_scale_b is not None else None

    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # Everything the compiled kernel may be specialised on. Two additions:
    # - Global-scale presence: compiling with None vs a tensor yields a
    #   different kernel signature.
    # - The operands' leading dims: added on the assumption that
    #   mark_layout_dynamic keeps the leading dim static (TODO confirm).
    # Tuples keep the key hashable even if lists are passed for the shapes.
    cache_key = (
        num_experts,
        str(device),
        tuple(mma_tiler_mn),
        tuple(cluster_shape_mn),
        mat_a.shape[1],
        mat_b.shape[2],
        cutlass_torch.get_leading_dim(mat_a),
        cutlass_torch.get_leading_dim(mat_b),
        global_scale_a is not None,
        global_scale_b is not None,
    )

    compiled = _compiled_kernel_cache.get(cache_key)
    if compiled is None:
        # First call for this configuration. The occupancy query is only
        # needed for compilation, so it is skipped on cache hits.
        cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1]
        max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size)
        compiled = cute.compile(
            kernel, a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c,
            max_active_clusters, stream,
            global_scale_a=gsa_c, global_scale_b=gsb_c,
        )
        _compiled_kernel_cache[cache_key] = compiled

    compiled(
        a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c,
        stream,
        global_scale_a=gsa_c, global_scale_b=gsb_c,
    )

    return out
|