From ecc7b833341237da38a3f7e6b4752ea4617606f0 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 16 May 2026 19:58:13 +0000 Subject: [PATCH] fix: compile CuTeDSL kernel with actual tensor shapes, not dummy 256x256 The compiled kernel's TMA descriptors are sized based on compilation shapes. Using dummy 256x256 shapes caused wrong memory access patterns for the real 3584x6144 data. Now uses actual K_packed and N_packed from the runtime tensors. --- cutedsl/bridge.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/cutedsl/bridge.py b/cutedsl/bridge.py index bd9969e5..09232f7d 100644 --- a/cutedsl/bridge.py +++ b/cutedsl/bridge.py @@ -298,7 +298,7 @@ def _get_compiled_kernel(num_experts, device, mma_tiler_mn, cluster_shape_mn): The kernel compilation is deterministic for a given (num_experts, device, tiler, cluster) config, so we cache it to avoid recompiling on every forward call. """ - cache_key = (num_experts, str(device), mma_tiler_mn, cluster_shape_mn) + cache_key = (num_experts, str(device), mma_tiler_mn, cluster_shape_mn, K_packed, N_packed) if cache_key in _compiled_kernel_cache: return _compiled_kernel_cache[cache_key] @@ -317,10 +317,11 @@ def _get_compiled_kernel(num_experts, device, mma_tiler_mn, cluster_shape_mn): max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - # We need dummy tensors to compile against — use minimal sizes - # The compiled kernel works with dynamic shapes via mark_layout_dynamic - K_packed = 256 # minimal - N_packed = 256 + # We need dummy tensors to compile against — shapes must match runtime tensors + # The compiled kernel uses mark_layout_dynamic but TMA descriptors + # are sized based on the compilation shapes + K_packed = mat_a.shape[1] # actual K packed dimension + N_packed = mat_b.shape[2] # actual N dimension tokens = 1 dummy_a = torch.zeros(tokens, K_packed, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2) dummy_b = torch.zeros(num_experts, K_packed, N_packed, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2)