Use larger MMA shape (#227)

This commit is contained in:
Ray Wang
2025-11-14 11:38:15 +08:00
committed by GitHub
parent c9f8b34dcd
commit 2f9d87877e
2 changed files with 46 additions and 56 deletions

View File

@@ -87,7 +87,7 @@ static void smxx_fp8_mqa_logits(const torch::Tensor& q,
constexpr int block_qh = 128;
constexpr int block_kv = 256;
constexpr int num_specialized_threads = 128;
constexpr int num_math_threads = 512;
constexpr int num_math_threads = 256;
constexpr int num_q_stages = 3, num_kv_stages = 3;
const int block_q = block_qh / num_heads;
DG_HOST_ASSERT(block_qh % num_heads == 0);

View File

@@ -134,7 +134,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, con
// Register reconfigurations
constexpr uint32_t kNumSpecializedRegisters = 32;
constexpr uint32_t kNumMathRegisters = 112;
constexpr uint32_t kNumMathRegisters = 232;
// Block scheduler
uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0;
@@ -168,7 +168,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, con
// UMMA settings
// Construct instruction with layout F
constexpr uint32_t UMMA_M = 64;
constexpr uint32_t UMMA_M = 128;
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads;
@@ -242,7 +242,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, con
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
// Issue UMMA
DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads / 2, "Invalid block size");
DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads, "Invalid block size");
DG_STATIC_ASSERT(kHeadDim % UMMA_K == 0, "Invalid head dim");
#pragma unroll
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
@@ -270,10 +270,11 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, con
// Offsets
const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0);
float weights[BLOCK_Q][kNumHeads / 4];
const auto& warp_offset = warp_idx * 16;
const auto& v_0_offset = lane_idx / 4 + 0;
const auto& v_1_offset = lane_idx / 4 + 8;
// Registers cannot hold every weight: keep the first `kNumWeightsInReg` weights of each query in `weights` and re-read the remaining ones from shared memory when they are used
constexpr uint32_t kNumWeightsInReg = 48;
float weights[BLOCK_Q][kNumWeightsInReg];
const auto& warp_offset = warp_idx * 32;
const auto& v_offset = lane_idx;
while (block_q_idx < num_q_blocks) {
CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks);
@@ -284,9 +285,9 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, con
// Read weights
#pragma unroll
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
#pragma unroll
for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
for (uint32_t j = 0; j < kNumWeightsInReg; ++ j) {
weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
}
}
// Compute over KV blocks
@@ -298,8 +299,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, con
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
// Read per-KV scales
auto scale_kv = make_float2(ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset),
ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset));
float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_offset);
// Wait UMMA arrival
full_umma_barriers[warpgroup_idx]->wait((num_total_kv_blocks + kv_block_idx) & 1);
@@ -311,59 +311,49 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, con
const auto& kv_offset = kv_start + kv_block_idx * BLOCK_KV + warp_offset;
static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2;
DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
#pragma unroll
constexpr uint32_t kNumLDTMElems = kNumHeads * BLOCK_Q;
DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid kNumLDTMElems");
uint32_t shifted_accum[kNumLDTMElems];
auto tmem_load = [&](auto... Is) {
if constexpr (kNumLDTMElems == 32) {
cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...);
} else if constexpr (kNumLDTMElems == 64) {
cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...);
} else if constexpr (kNumLDTMElems == 128) {
cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...);
}
};
[&]<size_t... Is>(cute::index_sequence<Is...>) { tmem_load(Is...); }(cute::make_index_sequence<kNumLDTMElems>{});
cutlass::arch::fence_view_async_tmem_load();
empty_umma_barriers[warpgroup_idx]->arrive();
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
// Load from the tensor memory
constexpr uint32_t kNumLDTMElems = UMMA_M * kNumHeads / 128;
uint32_t shifted_accum[kNumLDTMElems];
DG_STATIC_ASSERT(kNumLDTMElems == 16 or kNumLDTMElems == 32 or kNumLDTMElems == 64, "Invalid LDTM");
auto tmem_load = [&](auto... Is) {
if constexpr (kNumLDTMElems == 16) {
cute::SM100_TMEM_LOAD_16dp256b4x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...);
} else if constexpr (kNumLDTMElems == 32) {
cute::SM100_TMEM_LOAD_16dp256b8x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...);
} else if constexpr (kNumLDTMElems == 64) {
cute::SM100_TMEM_LOAD_16dp256b16x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...);
}
};
[&]<size_t... Is>(cute::index_sequence<Is...>) { tmem_load(Is...); }(cute::make_index_sequence<kNumLDTMElems>{});
cutlass::arch::fence_view_async_tmem_load();
float* accum = reinterpret_cast<float*>(shifted_accum + i * kNumHeads);
// Release UMMA empty
if (i == BLOCK_Q - 1)
empty_umma_barriers[warpgroup_idx]->arrive();
auto sum = make_float2(0, 0);
// Transform
const auto& transform_2 = [&](const uint32_t& j, const uint32_t& k, const float2& sum) {
auto a = make_float2(fmaxf(*reinterpret_cast<float*>(&shifted_accum[j * 4 + k]), 0),
fmaxf(*reinterpret_cast<float*>(&shifted_accum[j * 4 + k + 2]), 0));
auto b = make_float2(weights[i][j * 2 + k], weights[i][j * 2 + k]);
return __ffma2_rn(a, b, sum);
};
// Intra-thread reduction
auto sum_0 = make_float2(0, 0);
auto sum_1 = make_float2(0, 0);
#pragma unroll
for (uint32_t j = 0; j < kNumHeads / 8; ++ j) {
sum_0 = transform_2(j, 0, sum_0);
sum_1 = transform_2(j, 1, sum_1);
for (int j = 0; j < kNumWeightsInReg; j += 2) {
auto a = make_float2(fmaxf(accum[j], 0),
fmaxf(accum[j + 1], 0));
auto b = make_float2(weights[i][j], weights[i][j + 1]);
sum = __ffma2_rn(a, b, sum);
}
auto v = __fmul2_rn(__fadd2_rn(sum_0, sum_1), scale_kv);
// Inter-thread reduction
#pragma unroll
for (uint32_t j = 0; j < 2; ++ j) {
const auto& offset = 1u << j;
v.x += __shfl_xor_sync(0xffffffffu, v.x, offset);
v.y += __shfl_xor_sync(0xffffffffu, v.y, offset);
for (int j = kNumWeightsInReg; j < kNumHeads; j += 2) {
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j),
ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1));
sum = __ffma2_rn(a, b, sum);
}
float result = sum.x + sum.y;
result *= scale_kv;
// Store into the global memory
// NOTES: we have redundant writes here, consider more carefully
const uint32_t& q_idx = block_q_idx * BLOCK_Q + i;
logits[q_idx * stride_kv + kv_offset + v_0_offset] = v.x;
logits[q_idx * stride_kv + kv_offset + v_1_offset] = v.y;
logits[q_idx * stride_kv + kv_offset + v_offset] = result;
}
}
num_total_kv_blocks += num_kv_blocks;