fix: add UTCCP SMEM warp transpose for NVFP4 packed UE4M3 scales

This commit is contained in:
2026-05-12 16:48:06 +00:00
parent 8a53228745
commit 54a7de03a0

View File

@@ -884,11 +884,27 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
// We need 2 UTCCP calls per SF: one per K-column
using cute_utccp_t = cute::SM100_UTCCP_4x32dp128bit_2cta;
// NVFP4: UTCCP requires SMEM transpose for packed UE4M3 scales
// The packed format (4 UE4M3/int32) must be transposed so UTCCP
// distributes K-group data to the right TMEM columns for scale_vec::4X
auto utccp_required_smem_warp_transpose = [&](uint32_t* smem_ptr) {
uint32_t values[4];
#pragma unroll
for (uint32_t i = 0; i < 4; ++ i)
values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx);
__syncwarp();
#pragma unroll
for (uint32_t i = 0; i < 4; ++ i)
ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]);
};
#pragma unroll
for (uint32_t kk = 0; kk < 2; ++kk) {
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfa[stage_idx] + kk * SF_BLOCK_M + i * kNumUTCCPAlignedElems;
utccp_required_smem_warp_transpose(smem_ptr);
cutlass::arch::fence_view_async_shared();
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + (kk * (SF_BLOCK_M / kNumUTCCPAlignedElems) + i) * 4);
}
@@ -898,6 +914,8 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfb[stage_idx] + kk * BLOCK_N + i * kNumUTCCPAlignedElems;
utccp_required_smem_warp_transpose(smem_ptr);
cutlass::arch::fence_view_async_shared();
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + (kk * (SF_BLOCK_N / kNumUTCCPAlignedElems) + i) * 4);
}