revert: remove SMEM warp transpose (deadlock in elect_one_sync, not needed with transform_sf_token_idx)
This commit is contained in:
@@ -882,29 +882,15 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
// NVFP4: group=16 → 2 SF K-columns per BLOCK_K (128/16/4=2)
|
||||
// Each UTCCP call moves 128 int32s → 4 TMEM cols
|
||||
// We need 2 UTCCP calls per SF: one per K-column
|
||||
// NOTE: No SMEM warp transpose needed — transform_sf_token_idx
|
||||
// pre-arranges the data in the correct UTCCP layout via global memory
|
||||
using cute_utccp_t = cute::SM100_UTCCP_4x32dp128bit_2cta;
|
||||
|
||||
// NVFP4: UTCCP requires SMEM transpose for packed UE4M3 scales
|
||||
// The packed format (4 UE4M3/int32) must be transposed so UTCCP
|
||||
// distributes K-group data to the right TMEM columns for scale_vec::4X
|
||||
auto utccp_required_smem_warp_transpose = [&](uint32_t* smem_ptr) {
|
||||
uint32_t values[4];
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < 4; ++ i)
|
||||
values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx);
|
||||
__syncwarp();
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < 4; ++ i)
|
||||
ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]);
|
||||
};
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t kk = 0; kk < 2; ++kk) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
|
||||
auto smem_ptr = smem_sfa[stage_idx] + kk * SF_BLOCK_M + i * kNumUTCCPAlignedElems;
|
||||
utccp_required_smem_warp_transpose(smem_ptr);
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
|
||||
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + (kk * (SF_BLOCK_M / kNumUTCCPAlignedElems) + i) * 4);
|
||||
}
|
||||
@@ -914,8 +900,6 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
|
||||
auto smem_ptr = smem_sfb[stage_idx] + kk * BLOCK_N + i * kNumUTCCPAlignedElems;
|
||||
utccp_required_smem_warp_transpose(smem_ptr);
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
|
||||
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + (kk * (SF_BLOCK_N / kNumUTCCPAlignedElems) + i) * 4);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user