fix: use SM90ArchSpec instead of SM100ArchSpec in sm90_bf16_k_grouped_gemm

The function sm90_bf16_k_grouped_gemm was incorrectly using SM100ArchSpec
to calculate TMA descriptor block sizes. Since this file is the SM90
implementation, it should consistently use SM90ArchSpec like the other
functions in this file (sm90_bf16_gemm, sm90_m_grouped_bf16_gemm_contiguous,
etc.).

This fixes a copy-paste error that could cause incorrect block size
calculations on SM90 (Hopper) GPUs.

Fixes #242

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
yurekami
2026-01-01 05:06:36 +09:00
parent 9b680f4284
commit 6be0eb31d9

View File

@@ -258,18 +258,18 @@ static void sm90_bf16_k_grouped_gemm(const torch::Tensor& a,
// Create tensor descriptors
const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k,
SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k,
static_cast<int>(a.stride(0)), 1,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k,
SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k,
static_cast<int>(b.stride(0)), 1,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_cd = make_tma_cd_desc(d, m, n,
SM100ArchSpec::get_cd_store_block_m(config.block_m),
SM100ArchSpec::get_cd_store_block_n(config.block_n),
SM90ArchSpec::get_cd_store_block_m(config.block_m),
SM90ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(1)), num_groups,
config.smem_config.swizzle_cd_mode);