Fix SM90 MQA logits (#229)
This commit is contained in:
@@ -87,7 +87,7 @@ static void smxx_fp8_mqa_logits(const torch::Tensor& q,
|
||||
constexpr int block_qh = 128;
|
||||
constexpr int block_kv = 256;
|
||||
constexpr int num_specialized_threads = 128;
|
||||
-constexpr int num_math_threads = 256;
|
||||
+const int num_math_threads = (device_runtime->get_arch_major() == 10 ? 256 : 512);
|
||||
constexpr int num_q_stages = 3, num_kv_stages = 3;
|
||||
const int block_q = block_qh / num_heads;
|
||||
DG_HOST_ASSERT(block_qh % num_heads == 0);
|
||||
|
||||
@@ -333,7 +333,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, con
|
||||
|
||||
auto sum = make_float2(0, 0);
|
||||
|
||||
for (int j = 0; j < kNumWeightsInReg; j += 2) {
|
||||
for (int j = 0; j < kNumWeightsInReg; j += 2) {
|
||||
auto a = make_float2(fmaxf(accum[j], 0),
|
||||
fmaxf(accum[j + 1], 0));
|
||||
auto b = make_float2(weights[i][j], weights[i][j + 1]);
|
||||
|
||||
Reference in New Issue
Block a user