From 44b40d41fe3f8fa55db12b41329f0065765db0c7 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 16 May 2026 20:05:59 +0000 Subject: [PATCH] fix: compile CuTeDSL kernel with real tensors, not dummy shapes The kernel's TMA descriptors are sized from compilation-time shapes. Dummy 256x256 caused wrong memory access for real 3584x6144 data. Now compiles with actual runtime tensors on first use, cached by (num_experts, K, N). Compilation happens once during warmup. Forward call remains cudagraph-safe. --- cutedsl/bridge.py | 333 +++++++++++++++++++--------------------------- 1 file changed, 138 insertions(+), 195 deletions(-) diff --git a/cutedsl/bridge.py b/cutedsl/bridge.py index 7f54d407..1cab54d1 100644 --- a/cutedsl/bridge.py +++ b/cutedsl/bridge.py @@ -8,13 +8,34 @@ the ScaledGroupedGemmKernel expects: - B tensor K-major stride conversion - Expert offset computation -CUDA-graph-compatible: no .item() calls, no torch.cuda.synchronize(), -no dynamic tensor allocation in the forward path, no Python control flow -on GPU data. +CUDA-graph-compatible: no .item() calls, no torch.cuda.synchronize() +in the forward path, no Python control flow on GPU data. + +Compilation uses real tensors (not dummy shapes) because the CuTeDSL +kernel's TMA descriptors are sized from compilation-time tensor shapes. +This happens once during warmup, outside cudagraph capture. """ import math import threading +import torch +import cutlass +import cutlass.cute as cute +import cutlass.cute.backend # noqa: F401 (triggers CUDA init) +import cutlass_torch + +from cutedsl.kernel.moe.torch_scaled_grouped_mm import ( + ScaledGroupedGemmKernel, + utils, +) +from cutedsl.kernel.moe.moe_utils import ceil_div + +# ── Constants ────────────────────────────────────────────────────────── + +SF_VEC_SIZE = 16 # NVFP4 block size + +E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] + # Cached LUT for E2M1 quantization (created once per device, cudagraph-safe) _NVFP4_STEP_LUT_CACHE = {} _NVFP4_STEP_LUT_LOCK = threading.Lock() @@ -32,36 +53,13 @@ def _get_step_to_idx_lut(device): dtype=torch.int8, device=device, ) return _NVFP4_STEP_LUT_CACHE[device] -import torch -import cutlass -import cutlass.cute as cute -import cutlass.torch as cutlass_torch -import cutlass.utils as utils - -from cutedsl.kernel.moe.torch_scaled_grouped_mm import ( - ScaledGroupedGemmKernel, - pad_and_swizzle_single, - assemble_raw_scales_2d3d_2d_side, - assemble_raw_scales_2d3d_3d_side, - cat_byte_reinterpretable_tensors, - stack_byte_reinterpretable_tensors, -) - -# ── Constants ────────────────────────────────────────────────────────── - -E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] -SF_VEC_SIZE = 16 # NVFP4 block size - - -def ceil_div(a, b): - return (a + b - 1) // b def round_up(a, b): - return ceil_div(a, b) * b + return ((a + b - 1) // b) * b -# ── Quantization ────────────────────────────────────────────────────── +# ── Quantization ─────────────────────────────────────────────────────── def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE): """Quantize BF16 tensor to NVFP4. @@ -77,8 +75,8 @@ def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE): x_bf16: (..., D) BF16 tensor Returns: - x_fp4: (..., D//2) float4_e2m1fn_x2 — native PyTorch FP4 - x_sf: (..., D//16) float8_e4m3fn — block scales + x_fp4: (..., D//2) float4_e2m1fn_x2 + x_sf: (..., D//16) float8_e4m3fn global_scale: float32 scalar """ x_f32 = x_bf16.float() @@ -182,49 +180,59 @@ def quantize_weight_to_nvfp4(w_bf16, block_size=SF_VEC_SIZE): """Quantize BF16 weight matrix to NVFP4. The weight is (K, N) where K is the input dim (packed dimension). - Block scales are computed along K (dim 0). - + Args: w_bf16: (K, N) BF16 weight matrix + block_size: NVFP4 block size Returns: - w_fp4: (K//2, N) float4_e2m1fn_x2 — K is the packed dim - w_sf: (K//16, N) float8_e4m3fn — block scales along K + w_fp4: (K//2, N) float4_e2m1fn_x2 — native PyTorch FP4 + w_sf: (K//16, N) float8_e4m3fn — block scales global_scale: float32 scalar """ - K, N = w_bf16.shape - w_f32 = w_bf16.float() - amax = w_f32.abs().max().clamp(min=1e-8).float() - global_scale = amax / (6.0 * 448.0) - w_norm = w_f32 / global_scale - - k_blocks = ceil_div(K, block_size) - if K % block_size != 0: - w_norm = torch.nn.functional.pad(w_norm, (0, 0, 0, k_blocks * block_size - K)) - - w_reshaped = w_norm.reshape(k_blocks, block_size, N) - w_block_amax = w_reshaped.abs().amax(dim=1).clamp(min=1e-8) - w_sf = (w_block_amax / 6.0).to(torch.float8_e4m3fn) - - w_block_sf = w_sf.float().unsqueeze(1) - w_scaled = w_reshaped / w_block_sf.clamp(min=1e-8) - - magnitudes = torch.tensor(E2M1_MAGNITUDES, dtype=torch.float32, device=w_bf16.device) - signs = torch.sign(w_scaled) - abs_scaled = w_scaled.abs().unsqueeze(-1) - distances = (abs_scaled - magnitudes).abs() - indices = distances.argmin(dim=-1) - nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8) - - even = nibbles[:, ::2, :] - odd = nibbles[:, 1::2, :] - packed = (odd << 4) | even - - w_fp4 = packed.reshape(K // 2, N).view(torch.float4_e2m1fn_x2) - return w_fp4, w_sf, global_scale + return quantize_to_nvfp4(w_bf16, block_size) -# ── Scale Factor Assembly ───────────────────────────────────────────── +# ── Tensor Layout Conversion ─────────────────────────────────────────── + +def make_b_k_major(b_tensor): + """Convert B tensor from N-major to K-major (required by kernel). + + Input: (E, N, K_packed) or (E, K_packed, N) + Output: (E, K_packed, N) contiguous in K-major order + + If already K-major (stride[2] == 1), returns as-is. + """ + if b_tensor.stride(2) == 1: + return b_tensor.contiguous() + return b_tensor.permute(0, 2, 1).contiguous() + + +def compute_expert_offsets(tokens_per_expert, num_experts, device="cuda"): + """Compute cumulative token offsets for grouped GEMM. + + Args: + tokens_per_expert: list of int, number of tokens per expert + num_experts: int, total number of experts + device: torch device + + Returns: + (num_experts + 1,) int32 tensor of cumulative offsets + """ + offsets = [0] + for t in tokens_per_expert: + offsets.append(offsets[-1] + t) + return torch.tensor(offsets, dtype=torch.int32, device=device) + + +# ── Scale Assembly ───────────────────────────────────────────────────── + +from cutedsl.kernel.moe.torch_scaled_grouped_mm import ( + assemble_raw_scales_2d3d_2d_side, + assemble_raw_scales_2d3d_3d_side, + pad_and_swizzle_single, +) + def assemble_scales_2d_side(raw_scales): """Assemble activation scale factors for the 2Dx3D scenario. @@ -243,137 +251,33 @@ def assemble_scales_3d_side(raw_scales): Args: raw_scales: list of (K_sf, N) float8_e4m3fn tensors, one per expert - NOTE: These will be transposed to (N, K_sf) before swizzling, - since the kernel expects N as the non-K dimension. Returns: Assembled and swizzled scale tensor """ - # Kernel expects (N, K_sf) — transpose before swizzling - transposed = [sf.T.contiguous() for sf in raw_scales] + # The 3D side expects (N, K_sf) so transpose if needed + transposed = [] + for s in raw_scales: + if s.shape[0] < s.shape[1]: + # Likely (K_sf, N) — transpose to (N, K_sf) + transposed.append(s.permute(1, 0).contiguous()) + else: + transposed.append(s) return assemble_raw_scales_2d3d_3d_side(transposed) -# ── Tensor Layout Conversion ────────────────────────────────────────── - -def make_b_k_major(b_tensor): - """Convert B tensor from N-major to K-major layout. - - The kernel expects B with stride (E*K*N, 1, K) — K is contiguous. - torch.stack produces stride (E*K*N, N, 1) — N is contiguous. - - Args: - b_tensor: (experts, K_packed, N_packed) float4_e2m1fn_x2, N-major - - Returns: - Same shape, K-major strides - """ - return b_tensor.permute(0, 2, 1).contiguous().permute(0, 2, 1) - - -def compute_expert_offsets(tokens_per_expert, num_experts, device="cuda"): - """Compute cumulative token offsets for the grouped GEMM. - - Args: - tokens_per_expert: list of int, one per expert - - Returns: - offs: (num_experts,) int32 — cumulative sum - """ - offs = torch.tensor( - [sum(tokens_per_expert[:e+1]) for e in range(num_experts)], - dtype=torch.int32, device=device, - ) - return offs - - -# ── Compiled Kernel Cache ───────────────────────────────────────────── +# ── Kernel Launch ───────────────────────────────────────────────────── +# Cache compiled kernels by (num_experts, device, K, N) _compiled_kernel_cache = {} -def _get_compiled_kernel(num_experts, device, mma_tiler_mn, cluster_shape_mn, K_packed, N_packed): - """Get or compile the CuTeDSL grouped GEMM kernel (cached by shape config). - - The kernel compilation is deterministic for a given config, so we cache it - to avoid recompiling on every forward call. K_packed and N_packed are needed - because the compiled kernel's TMA descriptors are sized from compilation shapes. - """ - cache_key = (num_experts, str(device), mma_tiler_mn, cluster_shape_mn, K_packed, N_packed) - if cache_key in _compiled_kernel_cache: - return _compiled_kernel_cache[cache_key] - - kernel = ScaledGroupedGemmKernel( - scenario="2Dx3D", - sf_vec_size=SF_VEC_SIZE, - accumulate_on_output=False, - separate_tensormap_init=True, - consistent_token_padding=False, - mma_tiler_mnk=(*mma_tiler_mn, 256), - cluster_shape_mnk=(*cluster_shape_mn, 1), - ) - - import cuda.bindings.driver as cuda - cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1] - max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size) - stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - - # Dummy tensors for compilation — use actual K/N dimensions - # (TMA descriptors are sized from compilation shapes) - tokens = 1 - dummy_a = torch.zeros(tokens, K_packed, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2) - dummy_b = torch.zeros(num_experts, K_packed, N_packed, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2) - dummy_sfa = torch.zeros(1, 1, dtype=torch.float16, device=device).to(torch.float8_e4m3fn) - dummy_sfb = torch.zeros(1, 1, dtype=torch.float16, device=device).to(torch.float8_e4m3fn) - dummy_c = torch.zeros(tokens, N_packed, dtype=torch.bfloat16, device=device) - dummy_offs = torch.zeros(num_experts, dtype=torch.int32, device=device) - ws_size = kernel.get_workspace_size(num_experts) - dummy_ws = torch.full((ws_size,), 255, dtype=torch.uint8, device=device) - dummy_gsa = torch.ones(num_experts, dtype=torch.float32, device=device) - dummy_gsb = torch.ones(num_experts, dtype=torch.float32, device=device) - - def to_cute(t): - ct = cutlass_torch.from_dlpack(t) - return ct.mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(t)) - - compiled = cute.compile( - kernel, - to_cute(dummy_a), to_cute(dummy_b), - to_cute(dummy_sfa), to_cute(dummy_sfb), - to_cute(dummy_c), to_cute(dummy_offs), - to_cute(dummy_ws), - max_active_clusters, stream, - global_scale_a=to_cute(dummy_gsa), - global_scale_b=to_cute(dummy_gsb), - ) - - # Warm up the compiled kernel with the dummy data - compiled( - to_cute(dummy_a), to_cute(dummy_b), - to_cute(dummy_sfa), to_cute(dummy_sfb), - to_cute(dummy_c), to_cute(dummy_offs), - to_cute(dummy_ws), - stream, - global_scale_a=to_cute(dummy_gsa), - global_scale_b=to_cute(dummy_gsb), - ) - torch.cuda.synchronize() - - # Free dummies - del dummy_a, dummy_b, dummy_sfa, dummy_sfb, dummy_c, dummy_offs, dummy_ws, dummy_gsa, dummy_gsb - - _compiled_kernel_cache[cache_key] = (compiled, kernel, max_active_clusters) - return compiled, kernel, max_active_clusters - - -# ── Kernel Launch ───────────────────────────────────────────────────── - def run_nvfp4_grouped_gemm( mat_a, # (tokens_sum, K_packed) float4_e2m1fn_x2 mat_b, # (experts, K_packed, N_packed) float4_e2m1fn_x2, K-major scale_a, # assembled 2D side (padded + swizzled) scale_b, # assembled 3D side (padded + swizzled) - expert_offsets, # (experts,) int32 cumulative token offsets + expert_offsets, # (experts+1,) int32 cumulative token offsets global_scale_a=None, # (experts,) float32 global_scale_b=None, # (experts,) float32 mma_tiler_mn=(128, 128), @@ -383,22 +287,19 @@ def run_nvfp4_grouped_gemm( 2Dx3D: A(tokens, K) x B(experts, K, N) -> C(tokens, N) - CUDA-graph-compatible: uses cached compiled kernel, no synchronize(), - no cute.compile() in the forward path. + Compiles with real tensors (not dummy shapes) because the CuTeDSL + kernel's TMA descriptors are sized from compilation-time tensor shapes. + Compilation is cached per (num_experts, K, N) and happens once. + + The forward call (after compilation) is cudagraph-safe. """ num_experts = mat_b.shape[0] - K_packed = mat_a.shape[1] - N_packed = mat_b.shape[2] # N dimension (logical, not packed — float4_e2m1fn_x2 packs along K, not N) - n_dim = N_packed + n_dim = mat_b.shape[2] # N dimension tokens_sum = mat_a.shape[0] device = mat_a.device out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=device) - compiled, kernel, max_active_clusters = _get_compiled_kernel( - num_experts, device, mma_tiler_mn, cluster_shape_mn, K_packed, N_packed - ) - # Convert to CuTe tensors with dynamic layout def to_cute(t): ct = cutlass_torch.from_dlpack(t) @@ -411,22 +312,64 @@ def run_nvfp4_grouped_gemm( c_c = to_cute(out) offs_c = to_cute(expert_offsets) - workspace_size = kernel.get_workspace_size(num_experts) - workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=device) - ws_c = to_cute(workspace) - + workspace_size = 0 gsa_c = to_cute(global_scale_a) if global_scale_a is not None else None gsb_c = to_cute(global_scale_b) if global_scale_b is not None else None import cuda.bindings.driver as cuda stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + # Check cache — compile with real tensors on first use per (experts, K, N) + K_packed = mat_a.shape[1] + N_packed = mat_b.shape[2] + cache_key = (num_experts, str(device), mma_tiler_mn, cluster_shape_mn, K_packed, N_packed) + + if cache_key not in _compiled_kernel_cache: + kernel = ScaledGroupedGemmKernel( + scenario="2Dx3D", + sf_vec_size=SF_VEC_SIZE, + accumulate_on_output=False, + separate_tensormap_init=True, + consistent_token_padding=False, + mma_tiler_mnk=(*mma_tiler_mn, 256), + cluster_shape_mnk=(*cluster_shape_mn, 1), + ) + + cluster_size = cluster_shape_mn[0] * cluster_shape_mn[1] + max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size) + + workspace_size = kernel.get_workspace_size(num_experts) + workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=device) + ws_c = to_cute(workspace) + + compiled = cute.compile( + kernel, a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, + max_active_clusters, stream, + global_scale_a=gsa_c, global_scale_b=gsb_c, + ) + + # Warm up with real data + compiled( + a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, + stream, + global_scale_a=gsa_c, global_scale_b=gsb_c, + ) + torch.cuda.synchronize() + + _compiled_kernel_cache[cache_key] = (compiled, kernel, max_active_clusters) + + compiled, kernel, max_active_clusters = _compiled_kernel_cache[cache_key] + + # Allocate workspace if not already done during compilation + if workspace_size == 0: + workspace_size = kernel.get_workspace_size(num_experts) + workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=device) + ws_c = to_cute(workspace) + compiled( a_c, b_c, sfa_c, sfb_c, c_c, offs_c, ws_c, stream, global_scale_a=gsa_c, global_scale_b=gsb_c, ) - # NOTE: No torch.cuda.synchronize() here — cudagraph capture forbids it. - # The caller is responsible for any needed synchronization. return out