Support TMA multicast on B with m_grouped_gemm_contiguous. (#88)

This commit is contained in:
yukuai26
2025-04-21 09:43:17 +08:00
committed by GitHub
parent 83aa960b9b
commit 891f35adf5
5 changed files with 74 additions and 31 deletions

View File

@@ -86,14 +86,14 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
// Prefetch TMA descriptors at the very beginning
if (threadIdx.x == kNumMathThreads) {
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_a));
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_b));
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_scales_a));
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_scales_a);
// `tensor_map_d` is only used in swizzling mode
// For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode
if constexpr (kSwizzleDMode > 0)
cute::prefetch_tma_descriptor(reinterpret_cast<cute::TmaDescriptor const*>(&tensor_map_d));
cute::prefetch_tma_descriptor(&tensor_map_d);
}
__syncwarp();
@@ -212,8 +212,16 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K));
// Issue TMA B
tma_copy<kNumTMAMulticastOnB>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
if (kNumTMAMulticastOnB > 1 and scheduler.is_tma_multicast_b_valid(m_block_idx)) {
// NOTES: in grouped contiguous GEMM, different `m_block_idx` values may correspond to blocks of different groups (B),
// requiring additional checks before multicast operations.
DG_STATIC_ASSERT(kNumTMAMulticastOnB <= 2, "Scheduler does not support > 2 TMA multicast");
tma_copy<kNumTMAMulticastOnB>(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
} else {
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx));
}
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
}

View File

@@ -43,6 +43,18 @@ struct Scheduler {
}
}
// Decide whether the B tile used by `m_block_idx` may be fetched with TMA multicast.
//   - `GemmType::Normal`: always allowed.
//   - `GemmType::GroupedMasked`: never allowed.
//   - `GemmType::GroupedContiguous`: allowed only when this block and its peer block
//     (`m_block_idx ^ 1`) read the same value from `grouped_layout` at their first row,
//     i.e. both blocks are tagged with the same group.
// NOTES: the peer row index is not bounds-checked here; presumably the launch configuration
// guarantees that the peer block exists whenever multicast on B is enabled - TODO confirm.
__device__ __forceinline__ bool is_tma_multicast_b_valid(const uint32_t& m_block_idx) {
    if constexpr (kGemmType == GemmType::GroupedContiguous) {
        // The two blocks of a multicast pair differ only in the lowest bit of the M block index
        const uint32_t peer_m_block_idx = m_block_idx ^ 1;
        const auto own_group = __ldg(grouped_layout + m_block_idx * BLOCK_M);
        const auto peer_group = __ldg(grouped_layout + peer_m_block_idx * BLOCK_M);
        return own_group == peer_group;
    } else if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::GroupedMasked) {
        // Compile-time constant: `true` for normal GEMM, `false` for grouped masked GEMM
        return kGemmType == GemmType::Normal;
    }
}
__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t num_m_blocks, int block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
@@ -72,10 +84,10 @@ struct Scheduler {
const uint32_t& block_idx, const uint32_t& m_block_idx=0) {
if constexpr (kGemmType == GemmType::Normal) {
return block_idx * block_size;
} else if (kGemmType == GemmType::GroupedContiguous) {
} else if constexpr (kGemmType == GemmType::GroupedContiguous) {
auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : __ldg(grouped_layout + m_block_idx * BLOCK_M);
return offset * shape_dim + block_idx * block_size;
} else if (kGemmType == GemmType::GroupedMasked) {
} else if constexpr (kGemmType == GemmType::GroupedMasked) {
return curr_group_idx * shape_dim + block_idx * block_size;
}
}