Merge pull request #270 from yurekami/fix/sm90-archspec-bug
fix: use SM90ArchSpec instead of SM100ArchSpec in sm90_bf16_k_grouped_gemm
This commit is contained in:
@@ -258,18 +258,18 @@ static void sm90_bf16_k_grouped_gemm(const torch::Tensor& a,
 
 // Create tensor descriptors
 const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k,
-SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
+SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
 config.block_k,
 static_cast<int>(a.stride(0)), 1,
 config.smem_config.swizzle_a_mode);
 const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k,
-SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
+SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
 config.block_k,
 static_cast<int>(b.stride(0)), 1,
 config.smem_config.swizzle_b_mode);
 const auto& tensor_map_cd = make_tma_cd_desc(d, m, n,
-SM100ArchSpec::get_cd_store_block_m(config.block_m),
-SM100ArchSpec::get_cd_store_block_n(config.block_n),
+SM90ArchSpec::get_cd_store_block_m(config.block_m),
+SM90ArchSpec::get_cd_store_block_n(config.block_n),
 static_cast<int>(d.stride(1)), num_groups,
 config.smem_config.swizzle_cd_mode);
 
Reference in New Issue
Block a user