diff --git a/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp b/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp index 6499af2..9745019 100644 --- a/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp +++ b/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp @@ -258,18 +258,18 @@ static void sm90_bf16_k_grouped_gemm(const torch::Tensor& a, // Create tensor descriptors const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k, - SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), config.block_k, static_cast(a.stride(0)), 1, config.smem_config.swizzle_a_mode); const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k, - SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), config.block_k, static_cast(b.stride(0)), 1, config.smem_config.swizzle_b_mode); const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, - SM100ArchSpec::get_cd_store_block_m(config.block_m), - SM100ArchSpec::get_cd_store_block_n(config.block_n), + SM90ArchSpec::get_cd_store_block_m(config.block_m), + SM90ArchSpec::get_cd_store_block_n(config.block_n), static_cast(d.stride(1)), num_groups, config.smem_config.swizzle_cd_mode);