diff --git a/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp b/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp index ffd6f43..3a43165 100644 --- a/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp +++ b/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp @@ -87,7 +87,7 @@ static void smxx_fp8_mqa_logits(const torch::Tensor& q, constexpr int block_qh = 128; constexpr int block_kv = 256; constexpr int num_specialized_threads = 128; - constexpr int num_math_threads = 256; + const int num_math_threads = (device_runtime->get_arch_major() == 10 ? 256 : 512); constexpr int num_q_stages = 3, num_kv_stages = 3; const int block_q = block_qh / num_heads; DG_HOST_ASSERT(block_qh % num_heads == 0); diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh index 5252ddb..7474295 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh @@ -333,7 +333,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, con auto sum = make_float2(0, 0); - for (int j = 0; j < kNumWeightsInReg; j += 2) { + for (int j = 0; j < kNumWeightsInReg; j += 2) { auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); auto b = make_float2(weights[i][j], weights[i][j + 1]);