From 2f9d87877ed691a62796c25f2e9496a5e0b7123a Mon Sep 17 00:00:00 2001 From: Ray Wang Date: Fri, 14 Nov 2025 11:38:15 +0800 Subject: [PATCH] Use larger MMA shape (#227) --- .../jit_kernels/impls/smxx_fp8_mqa_logits.hpp | 2 +- .../deep_gemm/impls/sm100_fp8_mqa_logits.cuh | 100 ++++++++---------- 2 files changed, 46 insertions(+), 56 deletions(-) diff --git a/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp b/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp index f585639..ffd6f43 100644 --- a/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp +++ b/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp @@ -87,7 +87,7 @@ static void smxx_fp8_mqa_logits(const torch::Tensor& q, constexpr int block_qh = 128; constexpr int block_kv = 256; constexpr int num_specialized_threads = 128; - constexpr int num_math_threads = 512; + constexpr int num_math_threads = 256; constexpr int num_q_stages = 3, num_kv_stages = 3; const int block_q = block_qh / num_heads; DG_HOST_ASSERT(block_qh % num_heads == 0); diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh index 7890c26..5252ddb 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh @@ -134,7 +134,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, con // Register reconfigurations constexpr uint32_t kNumSpecializedRegisters = 32; - constexpr uint32_t kNumMathRegisters = 112; + constexpr uint32_t kNumMathRegisters = 232; // Block scheduler uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0; @@ -168,7 +168,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, con // UMMA settings // Construct instruction with layout F - constexpr uint32_t UMMA_M = 64; + constexpr uint32_t UMMA_M = 128; constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads; @@ -242,7 +242,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, con full_kv_barriers[kv_stage_idx]->wait(kv_phase); // Issue UMMA - DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads / 2, "Invalid block size"); + DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads, "Invalid block size"); DG_STATIC_ASSERT(kHeadDim % UMMA_K == 0, "Invalid head dim"); #pragma unroll for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { @@ -270,10 +270,11 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, con // Offsets const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0); - float weights[BLOCK_Q][kNumHeads / 4]; - const auto& warp_offset = warp_idx * 16; - const auto& v_0_offset = lane_idx / 4 + 0; - const auto& v_1_offset = lane_idx / 4 + 8; + // Number of registers is not sufficient for all weights, so we need to cache some of them + constexpr uint32_t kNumWeightsInReg = 48; + float weights[BLOCK_Q][kNumWeightsInReg]; + const auto& warp_offset = warp_idx * 32; + const auto& v_offset = lane_idx; while (block_q_idx < num_q_blocks) { CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks); @@ -284,9 +285,9 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, con // Read weights #pragma unroll for (uint32_t i = 0; i < BLOCK_Q; ++ i) { - #pragma unroll - for (uint32_t j = 0; j < kNumHeads / 4; ++ j) - weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); + for (uint32_t j = 0; j < kNumWeightsInReg; ++ j) { + weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); + } } // Compute over KV blocks @@ -298,8 +299,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, con full_kv_barriers[kv_stage_idx]->wait(kv_phase); // Read per-KV scales - auto scale_kv = make_float2(ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset), - ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset)); + float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_offset); // Wait UMMA arrival full_umma_barriers[warpgroup_idx]->wait((num_total_kv_blocks + kv_block_idx) & 1); @@ -311,59 +311,49 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, con const auto& kv_offset = kv_start + kv_block_idx * BLOCK_KV + warp_offset; static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2; DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); - #pragma unroll + constexpr uint32_t kNumLDTMElems = kNumHeads * BLOCK_Q; + DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid kNumLDTMElems"); + uint32_t shifted_accum[kNumLDTMElems]; + auto tmem_load = [&](auto... Is) { + if constexpr (kNumLDTMElems == 32) { + cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...); + } else if constexpr (kNumLDTMElems == 64) { + cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...); + } else if constexpr (kNumLDTMElems == 128) { + cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...); + } + }; + [&](cute::index_sequence) { tmem_load(Is...); }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + + empty_umma_barriers[warpgroup_idx]->arrive(); + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { - // Load from the tensor memory - constexpr uint32_t kNumLDTMElems = UMMA_M * kNumHeads / 128; - uint32_t shifted_accum[kNumLDTMElems]; - DG_STATIC_ASSERT(kNumLDTMElems == 16 or kNumLDTMElems == 32 or kNumLDTMElems == 64, "Invalid LDTM"); - auto tmem_load = [&](auto... Is) { - if constexpr (kNumLDTMElems == 16) { - cute::SM100_TMEM_LOAD_16dp256b4x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...); - } else if constexpr (kNumLDTMElems == 32) { - cute::SM100_TMEM_LOAD_16dp256b8x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...); - } else if constexpr (kNumLDTMElems == 64) { - cute::SM100_TMEM_LOAD_16dp256b16x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...); - } - }; - [&](cute::index_sequence) { tmem_load(Is...); }(cute::make_index_sequence{}); - cutlass::arch::fence_view_async_tmem_load(); + float* accum = reinterpret_cast(shifted_accum + i * kNumHeads); - // Release UMMA empty - if (i == BLOCK_Q - 1) - empty_umma_barriers[warpgroup_idx]->arrive(); + auto sum = make_float2(0, 0); - // Transform - const auto& transform_2 = [&](const uint32_t& j, const uint32_t& k, const float2& sum) { - auto a = make_float2(fmaxf(*reinterpret_cast(&shifted_accum[j * 4 + k]), 0), - fmaxf(*reinterpret_cast(&shifted_accum[j * 4 + k + 2]), 0)); - auto b = make_float2(weights[i][j * 2 + k], weights[i][j * 2 + k]); - return __ffma2_rn(a, b, sum); - }; - - // Intra-thread reduction - auto sum_0 = make_float2(0, 0); - auto sum_1 = make_float2(0, 0); - #pragma unroll - for (uint32_t j = 0; j < kNumHeads / 8; ++ j) { - sum_0 = transform_2(j, 0, sum_0); - sum_1 = transform_2(j, 1, sum_1); + for (int j = 0; j < kNumWeightsInReg; j += 2) { + auto a = make_float2(fmaxf(accum[j], 0), + fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + sum = __ffma2_rn(a, b, sum); } - auto v = __fmul2_rn(__fadd2_rn(sum_0, sum_1), scale_kv); - // Inter-thread reduction - #pragma unroll - for (uint32_t j = 0; j < 2; ++ j) { - const auto& offset = 1u << j; - v.x += __shfl_xor_sync(0xffffffffu, v.x, offset); - v.y += __shfl_xor_sync(0xffffffffu, v.y, offset); + for (int j = kNumWeightsInReg; j < kNumHeads; j += 2) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j), + ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1)); + sum = __ffma2_rn(a, b, sum); } + float result = sum.x + sum.y; + result *= scale_kv; + // Store into the global memory // NOTES: we have redundant writes here, consider more carefully const uint32_t& q_idx = block_q_idx * BLOCK_Q + i; - logits[q_idx * stride_kv + kv_offset + v_0_offset] = v.x; - logits[q_idx * stride_kv + kv_offset + v_1_offset] = v.y; + logits[q_idx * stride_kv + kv_offset + v_offset] = result; } } num_total_kv_blocks += num_kv_blocks;