From b95f9eb446fe6e30a59cfc344619420bae41ffa9 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 12 May 2026 17:11:19 +0000 Subject: [PATCH] revert: remove SMEM warp transpose (deadlock in elect_one_sync, not needed with transform_sf_token_idx) --- .../impls/sm100_fp8_nvfp4_mega_moe.cuh | 20 ++----------------- 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh index 503c645..49aaf35 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh @@ -882,29 +882,15 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, // NVFP4: group=16 → 2 SF K-columns per BLOCK_K (128/16/4=2) // Each UTCCP call moves 128 int32s → 4 TMEM cols // We need 2 UTCCP calls per SF: one per K-column + // NOTE: No SMEM warp transpose needed — transform_sf_token_idx + // pre-arranges the data in the correct UTCCP layout via global memory using cute_utccp_t = cute::SM100_UTCCP_4x32dp128bit_2cta; - // NVFP4: UTCCP requires SMEM transpose for packed UE4M3 scales - // The packed format (4 UE4M3/int32) must be transposed so UTCCP - // distributes K-group data to the right TMEM columns for scale_vec::4X - auto utccp_required_smem_warp_transpose = [&](uint32_t* smem_ptr) { - uint32_t values[4]; - #pragma unroll - for (uint32_t i = 0; i < 4; ++ i) - values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx); - __syncwarp(); - #pragma unroll - for (uint32_t i = 0; i < 4; ++ i) - ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]); - }; - #pragma unroll for (uint32_t kk = 0; kk < 2; ++kk) { #pragma unroll for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { auto smem_ptr = smem_sfa[stage_idx] + kk * SF_BLOCK_M + i * kNumUTCCPAlignedElems; - utccp_required_smem_warp_transpose(smem_ptr); - cutlass::arch::fence_view_async_shared(); mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + (kk * (SF_BLOCK_M / kNumUTCCPAlignedElems) + i) * 4); } @@ -914,8 +900,6 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y, #pragma unroll for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { auto smem_ptr = smem_sfb[stage_idx] + kk * BLOCK_N + i * kNumUTCCPAlignedElems; - utccp_required_smem_warp_transpose(smem_ptr); - cutlass::arch::fence_view_async_shared(); mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + (kk * (SF_BLOCK_N / kNumUTCCPAlignedElems) + i) * 4); }