From 6be0eb31d9bf897b2396157a410fc06d8e5b0cdc Mon Sep 17 00:00:00 2001 From: yurekami Date: Thu, 1 Jan 2026 05:06:36 +0900 Subject: [PATCH] fix: use SM90ArchSpec instead of SM100ArchSpec in sm90_bf16_k_grouped_gemm MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The function sm90_bf16_k_grouped_gemm was incorrectly using SM100ArchSpec to calculate TMA descriptor block sizes. Since this file is the SM90 implementation, it should consistently use SM90ArchSpec like the other functions in this file (sm90_bf16_gemm, sm90_m_grouped_bf16_gemm_contiguous, etc.). This fixes a copy-paste error that could cause incorrect block size calculations on SM90 (Hopper) GPUs. Fixes #242 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- csrc/jit_kernels/impls/sm90_bf16_gemm.hpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp b/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp index 6499af2..9745019 100644 --- a/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp +++ b/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp @@ -258,18 +258,18 @@ static void sm90_bf16_k_grouped_gemm(const torch::Tensor& a, // Create tensor descriptors const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k, - SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), config.block_k, static_cast(a.stride(0)), 1, config.smem_config.swizzle_a_mode); const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k, - SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), config.block_k, static_cast(b.stride(0)), 1, config.smem_config.swizzle_b_mode); const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, - SM100ArchSpec::get_cd_store_block_m(config.block_m), - SM100ArchSpec::get_cd_store_block_n(config.block_n), + SM90ArchSpec::get_cd_store_block_m(config.block_m), + SM90ArchSpec::get_cd_store_block_n(config.block_n), static_cast(d.stride(1)), num_groups, config.smem_config.swizzle_cd_mode);