From e82c4139da2e3174ff8e0e122d550e4111053369 Mon Sep 17 00:00:00 2001 From: yukuai Date: Mon, 23 Jun 2025 17:13:36 +0800 Subject: [PATCH] Revert "Fixed the bug in get_swizzle_mode function related to elem_size setting. (#115)" This reverts commit ac428e25e0dbbb44a302ccf7c23a724206addbb3. PR #115 causes wgrad to hang during testing. Revert it until the issue is resolved. --- deep_gemm/jit_kernels/gemm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index f85da84..574f821 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -17,8 +17,8 @@ def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: in return divisible and num_sms % num_tma_multicast == 0 -def get_swizzle_mode(block_n: int, is_fp32_out: bool) -> int: - elem_size = 4 if is_fp32_out else 2 +def get_swizzle_mode(block_n: int) -> int: + elem_size = 2 for mode_bytes in (128, 64, 32): if (block_n * elem_size) % mode_bytes == 0: return mode_bytes @@ -38,7 +38,7 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k assert block_k == 128 # Try swizzle first, as it does not waste shared memory - swizzle_mode = get_swizzle_mode(block_n, is_fp32_out) + swizzle_mode = get_swizzle_mode(block_n) block_n_padding = get_block_n_padding_for_smem_d( block_n) if swizzle_mode == 0 else 0