revert: remove SMEM warp transpose (deadlock in elect_one_sync, not needed with transform_sf_token_idx)

This commit is contained in:
2026-05-12 17:11:19 +00:00
parent 54a7de03a0
commit b95f9eb446

View File

@@ -882,29 +882,15 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
// NVFP4: group=16 → 2 SF K-columns per BLOCK_K (128/16/4=2)
// Each UTCCP call moves 128 int32s → 4 TMEM cols
// We need 2 UTCCP calls per SF: one per K-column
// NOTE: No SMEM warp transpose needed — transform_sf_token_idx
// pre-arranges the data in the correct UTCCP layout via global memory
using cute_utccp_t = cute::SM100_UTCCP_4x32dp128bit_2cta;
// NVFP4: UTCCP requires SMEM transpose for packed UE4M3 scales
// The packed format (4 UE4M3/int32) must be transposed so UTCCP
// distributes K-group data to the right TMEM columns for scale_vec::4X
auto utccp_required_smem_warp_transpose = [&](uint32_t* smem_ptr) {
uint32_t values[4];
#pragma unroll
for (uint32_t i = 0; i < 4; ++ i)
values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx);
__syncwarp();
#pragma unroll
for (uint32_t i = 0; i < 4; ++ i)
ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]);
};
#pragma unroll
for (uint32_t kk = 0; kk < 2; ++kk) {
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfa[stage_idx] + kk * SF_BLOCK_M + i * kNumUTCCPAlignedElems;
utccp_required_smem_warp_transpose(smem_ptr);
cutlass::arch::fence_view_async_shared();
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + (kk * (SF_BLOCK_M / kNumUTCCPAlignedElems) + i) * 4);
}
@@ -914,8 +900,6 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfb[stage_idx] + kk * BLOCK_N + i * kNumUTCCPAlignedElems;
utccp_required_smem_warp_transpose(smem_ptr);
cutlass::arch::fence_view_async_shared();
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + (kk * (SF_BLOCK_N / kNumUTCCPAlignedElems) + i) * 4);
}