Fix SM90 MQA logits (#229)

This commit is contained in:
Ray Wang
2025-11-19 10:50:36 +08:00
committed by GitHub
parent 2f9d87877e
commit ec5e9ed0b8
2 changed files with 2 additions and 2 deletions

View File

@@ -87,7 +87,7 @@ static void smxx_fp8_mqa_logits(const torch::Tensor& q,
constexpr int block_qh = 128;
constexpr int block_kv = 256;
constexpr int num_specialized_threads = 128;
constexpr int num_math_threads = 256;
const int num_math_threads = (device_runtime->get_arch_major() == 10 ? 256 : 512);
constexpr int num_q_stages = 3, num_kv_stages = 3;
const int block_q = block_qh / num_heads;
DG_HOST_ASSERT(block_qh % num_heads == 0);

View File

@@ -333,7 +333,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, con
auto sum = make_float2(0, 0);
for (int j = 0; j < kNumWeightsInReg; j += 2) {
for (int j = 0; j < kNumWeightsInReg; j += 2) {
auto a = make_float2(fmaxf(accum[j], 0),
fmaxf(accum[j + 1], 0));
auto b = make_float2(weights[i][j], weights[i][j + 1]);