Merge pull request #270 from yurekami/fix/sm90-archspec-bug

fix: use SM90ArchSpec instead of SM100ArchSpec in sm90_bf16_k_grouped_gemm
This commit is contained in:
Zhean Xu
2026-01-06 09:56:33 +08:00
committed by GitHub

View File

@@ -258,18 +258,18 @@ static void sm90_bf16_k_grouped_gemm(const torch::Tensor& a,
// Create tensor descriptors
const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k,
SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k,
static_cast<int>(a.stride(0)), 1,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k,
SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k,
static_cast<int>(b.stride(0)), 1,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_cd = make_tma_cd_desc(d, m, n,
SM100ArchSpec::get_cd_store_block_m(config.block_m),
SM100ArchSpec::get_cd_store_block_n(config.block_n),
SM90ArchSpec::get_cd_store_block_m(config.block_m),
SM90ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(1)), num_groups,
config.smem_config.swizzle_cd_mode);