From 54a7de03a08726d90e0ec1afe03c3d6d39daa764 Mon Sep 17 00:00:00 2001
From: biondizzle
Date: Tue, 12 May 2026 16:48:06 +0000
Subject: [PATCH] fix: add UTCCP SMEM warp transpose for NVFP4 packed UE4M3 scales

---
 .../impls/sm100_fp8_nvfp4_mega_moe.cuh | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh
index 05b43c3..503c645 100644
--- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh
+++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_nvfp4_mega_moe.cuh
@@ -884,11 +884,27 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
 
                 // We need 2 UTCCP calls per SF: one per K-column
                 using cute_utccp_t = cute::SM100_UTCCP_4x32dp128bit_2cta;
+                // NVFP4: UTCCP requires SMEM transpose for packed UE4M3 scales
+                // The packed format (4 UE4M3/int32) must be transposed so UTCCP
+                // distributes K-group data to the right TMEM columns for scale_vec::4X
+                auto utccp_required_smem_warp_transpose = [&](uint32_t* smem_ptr) {
+                    uint32_t values[4];
+                    #pragma unroll
+                    for (uint32_t i = 0; i < 4; ++ i)
+                        values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx);
+                    __syncwarp();
+                    #pragma unroll
+                    for (uint32_t i = 0; i < 4; ++ i)
+                        ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]);
+                };
+
                 #pragma unroll
                 for (uint32_t kk = 0; kk < 2; ++kk) {
                     #pragma unroll
                     for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
                         auto smem_ptr = smem_sfa[stage_idx] + kk * SF_BLOCK_M + i * kNumUTCCPAlignedElems;
+                        utccp_required_smem_warp_transpose(smem_ptr);
+                        cutlass::arch::fence_view_async_shared();
                         mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
                         cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + (kk * (SF_BLOCK_M / kNumUTCCPAlignedElems) + i) * 4);
                     }
@@ -898,6 +914,8 @@ sm100_fp8_nvfp4_mega_moe_impl(void* y,
                     #pragma unroll
                     for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
                         auto smem_ptr = smem_sfb[stage_idx] + kk * BLOCK_N + i * kNumUTCCPAlignedElems;
+                        utccp_required_smem_warp_transpose(smem_ptr);
+                        cutlass::arch::fence_view_async_shared();
                         mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
                         cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + (kk * (SF_BLOCK_N / kNumUTCCPAlignedElems) + i) * 4);
                     }