fix: compile CuTeDSL kernel with actual tensor shapes, not dummy 256x256

The compiled kernel's TMA descriptors are sized from the shapes used at
compilation time. Compiling against dummy 256x256 shapes caused wrong memory
access patterns for the real 3584x6144 data. The kernel is now compiled with
the actual K_packed and N_packed taken from the runtime tensors.
This commit is contained in:
2026-05-16 19:58:13 +00:00
parent cc75a55bd9
commit ecc7b83334

View File

@@ -298,7 +298,7 @@ def _get_compiled_kernel(num_experts, device, mma_tiler_mn, cluster_shape_mn):
The kernel compilation is deterministic for a given (num_experts, device, tiler, cluster)
config, so we cache it to avoid recompiling on every forward call.
"""
cache_key = (num_experts, str(device), mma_tiler_mn, cluster_shape_mn)
cache_key = (num_experts, str(device), mma_tiler_mn, cluster_shape_mn, K_packed, N_packed)
if cache_key in _compiled_kernel_cache:
return _compiled_kernel_cache[cache_key]
@@ -317,10 +317,11 @@ def _get_compiled_kernel(num_experts, device, mma_tiler_mn, cluster_shape_mn):
max_active_clusters = utils.HardwareInfo().get_max_active_clusters(cluster_size)
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
# We need dummy tensors to compile against — use minimal sizes
# The compiled kernel works with dynamic shapes via mark_layout_dynamic
K_packed = 256 # minimal
N_packed = 256
# We need dummy tensors to compile against — shapes must match runtime tensors
# The compiled kernel uses mark_layout_dynamic but TMA descriptors
# are sized based on the compilation shapes
K_packed = mat_a.shape[1] # actual K packed dimension
N_packed = mat_b.shape[2] # actual N dimension
tokens = 1
dummy_a = torch.zeros(tokens, K_packed, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2)
dummy_b = torch.zeros(num_experts, K_packed, N_packed, dtype=torch.uint8, device=device).view(torch.float4_e2m1fn_x2)