diff --git a/cutedsl/bridge.py b/cutedsl/bridge.py index bd9969e5..09232f7d 100644 --- a/cutedsl/bridge.py +++ b/cutedsl/bridge.py @@ -298,7 +298,7 @@ def _get_compiled_kernel(num_experts, device, mma_tiler_mn, cluster_shape_mn): The kernel compilation is deterministic for a given (num_experts, device, tiler, cluster) config, so we cache it to avoid recompiling on every forward call. """ - cache_key = (num_experts, str(device), mma_tiler_mn, cluster_shape_mn) + cache_key = (num_experts, str(device), mma_tiler_mn, cluster_shape_mn, K_packed, N_packed) if cache_key in _compiled_kernel_cache: return _compiled_kernel_cache[cache_key] @@ -317,10 +317,11 @@ def _get_compiled_kernel(num_experts, device, mma_tiler_mn, cluster_shape_mn): max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size) stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - # We need dummy tensors to compile against — use minimal sizes - # The compiled kernel works with dynamic shapes via mark_layout_dynamic - K_packed = 256 # minimal - N_packed = 256 + # We need dummy tensors to compile against — shapes must match runtime tensors + # The compiled kernel uses mark_layout_dynamic but TMA descriptors + # are sized based on the compilation shapes + K_packed = mat_a.shape[1] # actual K packed dimension + N_packed = mat_b.shape[2] # actual N dimension tokens = 1 dummy_a = torch.zeros(tokens, K_packed, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2) dummy_b = torch.zeros(num_experts, K_packed, N_packed, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2)