fix: use SM90ArchSpec instead of SM100ArchSpec in sm90_bf16_k_grouped_gemm
The function sm90_bf16_k_grouped_gemm was incorrectly using SM100ArchSpec to calculate TMA descriptor block sizes. Since this file is the SM90 implementation, it should consistently use SM90ArchSpec like the other functions in this file (sm90_bf16_gemm, sm90_m_grouped_bf16_gemm_contiguous, etc.). This fixes a copy-paste error that could cause incorrect block size calculations on SM90 (Hopper) GPUs. Fixes #242 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -258,18 +258,18 @@ static void sm90_bf16_k_grouped_gemm(const torch::Tensor& a,
|
||||
|
||||
// Create tensor descriptors
|
||||
const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k,
|
||||
SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
|
||||
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
|
||||
config.block_k,
|
||||
static_cast<int>(a.stride(0)), 1,
|
||||
config.smem_config.swizzle_a_mode);
|
||||
const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k,
|
||||
SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
|
||||
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(0)), 1,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_cd = make_tma_cd_desc(d, m, n,
|
||||
SM100ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM100ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
SM90ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM90ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(1)), num_groups,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user