From ec5e9ed0b87590d456cc017b41697d54cc6a2629 Mon Sep 17 00:00:00 2001 From: Ray Wang Date: Wed, 19 Nov 2025 10:50:36 +0800 Subject: [PATCH] Fix SM90 MQA logits (#229) --- csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp | 2 +- deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp b/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp index ffd6f43..3a43165 100644 --- a/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp +++ b/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp @@ -87,7 +87,7 @@ static void smxx_fp8_mqa_logits(const torch::Tensor& q, constexpr int block_qh = 128; constexpr int block_kv = 256; constexpr int num_specialized_threads = 128; - constexpr int num_math_threads = 256; + const int num_math_threads = (device_runtime->get_arch_major() == 10 ? 256 : 512); constexpr int num_q_stages = 3, num_kv_stages = 3; const int block_q = block_qh / num_heads; DG_HOST_ASSERT(block_qh % num_heads == 0); diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh index 5252ddb..7474295 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh @@ -333,7 +333,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, con auto sum = make_float2(0, 0); - for (int j = 0; j < kNumWeightsInReg; j += 2) { + for (int j = 0; j < kNumWeightsInReg; j += 2) { auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); auto b = make_float2(weights[i][j], weights[i][j + 1]);