diff --git a/cutedsl/bridge.py b/cutedsl/bridge.py index 3ea58ebb..a061430a 100644 --- a/cutedsl/bridge.py +++ b/cutedsl/bridge.py @@ -7,75 +7,48 @@ the ScaledGroupedGemmKernel expects: - Scale factor assembly (padding + swizzle) - B tensor K-major stride conversion - Expert offset computation - -CUDA-graph-compatible: no .item() calls, no torch.cuda.synchronize() -in the forward path, no Python control flow on GPU data. - -Compilation uses real tensors (not dummy shapes) because the CuTeDSL -kernel's TMA descriptors are sized from compilation-time tensor shapes. -This happens once during warmup, outside cudagraph capture. """ import math -import threading - import torch import cutlass import cutlass.cute as cute import cutlass.torch as cutlass_torch +import cutlass.utils as utils from cutedsl.kernel.moe.torch_scaled_grouped_mm import ( ScaledGroupedGemmKernel, - utils, - ceil_div, + pad_and_swizzle_single, + assemble_raw_scales_2d3d_2d_side, + assemble_raw_scales_2d3d_3d_side, + cat_byte_reinterpretable_tensors, + stack_byte_reinterpretable_tensors, ) # ── Constants ────────────────────────────────────────────────────────── +E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] SF_VEC_SIZE = 16 # NVFP4 block size -E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] -# Cached LUT for E2M1 quantization (created once per device, cudagraph-safe) -_NVFP4_STEP_LUT_CACHE = {} -_NVFP4_STEP_LUT_LOCK = threading.Lock() - - -def _get_step_to_idx_lut(device): - """Get or create the E2M1 step-to-index LUT for the given device. - - Cached per device to avoid CPU→CUDA copies during cudagraph capture. - """ - with _NVFP4_STEP_LUT_LOCK: - if device not in _NVFP4_STEP_LUT_CACHE: - _NVFP4_STEP_LUT_CACHE[device] = torch.as_tensor( - [0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 6, 7, 7], - dtype=torch.int8, device=device, - ) - return _NVFP4_STEP_LUT_CACHE[device] +def ceil_div(a, b): + return (a + b - 1) // b def round_up(a, b): - return ((a + b - 1) // b) * b + return ceil_div(a, b) * b -# ── Quantization ─────────────────────────────────────────────────────── +# ── Quantization ────────────────────────────────────────────────────── def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE): """Quantize BF16 tensor to NVFP4. - - NOTE: This function is NOT cudagraph-safe because it uses .max() - which forces a CPU-GPU sync. It should only be called during - weight preparation (offline), NOT during the forward pass. - For activation quantization during forward, use - quantize_activation_nvfp4() instead (cudagraph-safe, fixed global scale). - Args: x_bf16: (..., D) BF16 tensor Returns: - x_fp4: (..., D//2) float4_e2m1fn_x2 - x_sf: (..., D//16) float8_e4m3fn + x_fp4: (..., D//2) float4_e2m1fn_x2 — native PyTorch FP4 + x_sf: (..., D//16) float8_e4m3fn — block scales global_scale: float32 scalar """ x_f32 = x_bf16.float() @@ -94,15 +67,15 @@ def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE): block_amax = x_reshaped.abs().amax(dim=-1).clamp(min=1e-8) block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn) - # Nearest E2M1 — memory-efficient clamp approach + # Nearest E2M1 block_sf_expanded = block_scale.float().unsqueeze(-1) x_scaled = x_reshaped / block_sf_expanded.clamp(min=1e-8) + + magnitudes = torch.tensor(E2M1_MAGNITUDES, dtype=torch.float32, device=x_bf16.device) signs = torch.sign(x_scaled) - abs_scaled = x_scaled.abs().clamp(max=6.0) - - half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8) - step_to_idx = _get_step_to_idx_lut(x_bf16.device) - indices = step_to_idx[half_steps.long()] + abs_scaled = x_scaled.abs().unsqueeze(-1) + distances = (abs_scaled - magnitudes).abs() + indices = distances.argmin(dim=-1) nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8) even = nibbles[..., ::2] @@ -119,70 +92,12 @@ def quantize_to_nvfp4(x_bf16, block_size=SF_VEC_SIZE): return x_fp4, block_scale, global_scale -def quantize_activation_nvfp4(x_bf16, global_scale, block_size=SF_VEC_SIZE): - """Quantize BF16 activation tensor to NVFP4 (cudagraph-safe). - - Unlike quantize_to_nvfp4(), this takes a pre-computed global_scale - instead of computing it via .max() (which forces CPU-GPU sync). - The global_scale should be computed once during warmup and cached. - - All operations are pure GPU with no CPU-GPU syncs. - - Args: - x_bf16: (..., D) BF16 tensor - global_scale: float32 scalar (pre-computed, NOT from .max()) - block_size: NVFP4 block size - - Returns: - x_fp4: (..., D//2) float4_e2m1fn_x2 - x_sf: (..., D//16) float8_e4m3fn - """ - x_f32 = x_bf16.float() - x_norm = x_f32 / global_scale - - last_dim = x_norm.shape[-1] - n_blocks = ceil_div(last_dim, block_size) - - if last_dim % block_size != 0: - pad_size = n_blocks * block_size - last_dim - x_norm = torch.nn.functional.pad(x_norm, (0, pad_size)) - - x_reshaped = x_norm.reshape(*x_norm.shape[:-1], n_blocks, block_size) - block_amax = x_reshaped.abs().amax(dim=-1).clamp(min=1e-8) - block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn) - - block_sf_expanded = block_scale.float().unsqueeze(-1) - x_scaled = x_reshaped / block_sf_expanded.clamp(min=1e-8) - signs = torch.sign(x_scaled) - abs_scaled = x_scaled.abs().clamp(max=6.0) - - half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8) - step_to_idx = _get_step_to_idx_lut(x_bf16.device) - indices = step_to_idx[half_steps.long()] - - nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8) - even = nibbles[..., ::2] - odd = nibbles[..., 1::2] - packed = (odd << 4) | even - - packed_shape = list(x_bf16.shape) - packed_shape[-1] = last_dim // 2 - x_fp4 = packed.view(torch.float4_e2m1fn_x2).reshape(packed_shape) - - sf_shape = list(x_bf16.shape[:-1]) + [n_blocks] - block_scale = block_scale.reshape(sf_shape) - - return x_fp4, block_scale - - def quantize_weight_to_nvfp4(w_bf16, block_size=SF_VEC_SIZE): """Quantize BF16 weight matrix to NVFP4. The weight is (K, N) where K is the input dim (packed dimension). Block scales are computed along K (dim 0). - NOTE: NOT cudagraph-safe — uses .max() for global scale. - Args: w_bf16: (K, N) BF16 weight matrix @@ -203,67 +118,27 @@ def quantize_weight_to_nvfp4(w_bf16, block_size=SF_VEC_SIZE): w_reshaped = w_norm.reshape(k_blocks, block_size, N) w_block_amax = w_reshaped.abs().amax(dim=1).clamp(min=1e-8) - block_scale = (w_block_amax / 6.0).to(torch.float8_e4m3fn) + w_sf = (w_block_amax / 6.0).to(torch.float8_e4m3fn) - block_sf_expanded = block_scale.float().unsqueeze(1) - w_scaled = w_reshaped / block_sf_expanded.clamp(min=1e-8) + w_block_sf = w_sf.float().unsqueeze(1) + w_scaled = w_reshaped / w_block_sf.clamp(min=1e-8) - # Nearest E2M1 — memory-efficient clamp approach + magnitudes = torch.tensor(E2M1_MAGNITUDES, dtype=torch.float32, device=w_bf16.device) signs = torch.sign(w_scaled) - abs_scaled = w_scaled.abs().clamp(max=6.0) - - half_steps = (abs_scaled * 2.0).round().clamp(0, 12).to(torch.int8) - step_to_idx = _get_step_to_idx_lut(w_bf16.device) - indices = step_to_idx[half_steps.long()] - + abs_scaled = w_scaled.abs().unsqueeze(-1) + distances = (abs_scaled - magnitudes).abs() + indices = distances.argmin(dim=-1) nibbles = torch.where(signs < 0, indices + 8, indices).to(torch.uint8) + even = nibbles[:, ::2, :] odd = nibbles[:, 1::2, :] packed = (odd << 4) | even w_fp4 = packed.reshape(K // 2, N).view(torch.float4_e2m1fn_x2) - return w_fp4, block_scale, global_scale + return w_fp4, w_sf, global_scale -# ── Tensor Layout Conversion ─────────────────────────────────────────── - -def make_b_k_major(b_tensor): - """Convert B tensor from N-major to K-major layout. - - The kernel expects B with stride (E*K*N, 1, K) — K is contiguous. - torch.stack produces stride (E*K*N, N, 1) — N is contiguous. - - double-permute trick: transpose, make contiguous, transpose back. - Same shape, but K-contiguous memory layout. - """ - return b_tensor.permute(0, 2, 1).contiguous().permute(0, 2, 1) - - -def compute_expert_offsets(tokens_per_expert, num_experts, device="cuda"): - """Compute cumulative token offsets for grouped GEMM. - - Args: - tokens_per_expert: list of int, number of tokens per expert - num_experts: int, total number of experts - device: torch device - - Returns: - (num_experts + 1,) int32 tensor of cumulative offsets - """ - offsets = [0] - for t in tokens_per_expert: - offsets.append(offsets[-1] + t) - return torch.tensor(offsets, dtype=torch.int32, device=device) - - -# ── Scale Assembly ───────────────────────────────────────────────────── - -from cutedsl.kernel.moe.torch_scaled_grouped_mm import ( - assemble_raw_scales_2d3d_2d_side, - assemble_raw_scales_2d3d_3d_side, - pad_and_swizzle_single, -) - +# ── Scale Factor Assembly ───────────────────────────────────────────── def assemble_scales_2d_side(raw_scales): """Assemble activation scale factors for the 2Dx3D scenario. @@ -282,33 +157,58 @@ def assemble_scales_3d_side(raw_scales): Args: raw_scales: list of (K_sf, N) float8_e4m3fn tensors, one per expert + NOTE: These will be transposed to (N, K_sf) before swizzling, + since the kernel expects N as the non-K dimension. Returns: Assembled and swizzled scale tensor """ - # The 3D side expects (N, K_sf) so transpose if needed - transposed = [] - for s in raw_scales: - if s.shape[0] < s.shape[1]: - # Likely (K_sf, N) — transpose to (N, K_sf) - transposed.append(s.permute(1, 0).contiguous()) - else: - transposed.append(s) + # Kernel expects (N, K_sf) — transpose before swizzling + transposed = [sf.T.contiguous() for sf in raw_scales] return assemble_raw_scales_2d3d_3d_side(transposed) +# ── Tensor Layout Conversion ────────────────────────────────────────── + +def make_b_k_major(b_tensor): + """Convert B tensor from N-major to K-major layout. + + The kernel expects B with stride (E*K*N, 1, K) — K is contiguous. + torch.stack produces stride (E*K*N, N, 1) — N is contiguous. + + Args: + b_tensor: (experts, K_packed, N_packed) float4_e2m1fn_x2, N-major + + Returns: + Same shape, K-major strides + """ + return b_tensor.permute(0, 2, 1).contiguous().permute(0, 2, 1) + + +def compute_expert_offsets(tokens_per_expert, num_experts, device="cuda"): + """Compute cumulative token offsets for the grouped GEMM. + + Args: + tokens_per_expert: list of int, one per expert + + Returns: + offs: (num_experts,) int32 — cumulative sum + """ + offs = torch.tensor( + [sum(tokens_per_expert[:e+1]) for e in range(num_experts)], + dtype=torch.int32, device=device, + ) + return offs + + # ── Kernel Launch ───────────────────────────────────────────────────── -# Cache compiled kernels by (num_experts, device, K, N) -_compiled_kernel_cache = {} - - def run_nvfp4_grouped_gemm( mat_a, # (tokens_sum, K_packed) float4_e2m1fn_x2 mat_b, # (experts, K_packed, N_packed) float4_e2m1fn_x2, K-major scale_a, # assembled 2D side (padded + swizzled) scale_b, # assembled 3D side (padded + swizzled) - expert_offsets, # (experts+1,) int32 cumulative token offsets + expert_offsets, # (experts,) int32 cumulative token offsets global_scale_a=None, # (experts,) float32 global_scale_b=None, # (experts,) float32 mma_tiler_mn=(128, 128), @@ -317,19 +217,12 @@ def run_nvfp4_grouped_gemm( """Run the CuTeDSL NVFP4 scaled grouped GEMM. 2Dx3D: A(tokens, K) x B(experts, K, N) -> C(tokens, N) - - Compiles with real tensors each call. The CuTeDSL kernel's TMA - descriptors are bound to the compilation-time tensor addresses, - so caching across different tensor allocations produces wrong results. - - The forward call (after compilation) is cudagraph-safe. """ num_experts = mat_b.shape[0] - n_dim = mat_b.shape[2] + n_dim = mat_b.shape[2] # packed N (in float4 elements) tokens_sum = mat_a.shape[0] - device = mat_a.device - out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=device) + out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=mat_a.device) kernel = ScaledGroupedGemmKernel( scenario="2Dx3D", @@ -354,7 +247,7 @@ def run_nvfp4_grouped_gemm( offs_c = to_cute(expert_offsets) workspace_size = kernel.get_workspace_size(num_experts) - workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=device) + workspace = torch.full((workspace_size,), 255, dtype=torch.uint8, device=mat_a.device) ws_c = to_cute(workspace) gsa_c = to_cute(global_scale_a) if global_scale_a is not None else None @@ -376,5 +269,6 @@ def run_nvfp4_grouped_gemm( stream, global_scale_a=gsa_c, global_scale_b=gsb_c, ) + torch.cuda.synchronize() return out