From f63d7f24d62462e5f7bedc38d94b3877b5762df9 Mon Sep 17 00:00:00 2001 From: Guoteng <32697156+SolenoidWGT@users.noreply.github.com> Date: Wed, 19 Nov 2025 10:52:08 +0800 Subject: [PATCH] fix: prevent int32 overflow in k-grouped GEMM size calculations (#226) --- csrc/apis/gemm.hpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/apis/gemm.hpp b/csrc/apis/gemm.hpp index 8d06292..076179c 100644 --- a/csrc/apis/gemm.hpp +++ b/csrc/apis/gemm.hpp @@ -280,13 +280,13 @@ static void k_grouped_fp8_gemm_nt_contiguous(const std::pair(d); - const auto& sum_mk = a.first.numel(); - const auto& sum_nk = b.first.numel(); - int sum_k = 0; + const auto& sum_mk = static_cast(a.first.numel()); + const auto& sum_nk = static_cast(b.first.numel()); + uint64_t sum_k = 0; for (const auto& k: ks) - sum_k += k; - DG_HOST_ASSERT(sum_mk == m * sum_k); - DG_HOST_ASSERT(sum_nk == n * sum_k); + sum_k += static_cast(k); + DG_HOST_ASSERT(sum_mk == static_cast(m) * sum_k); + DG_HOST_ASSERT(sum_nk == static_cast(n) * sum_k); // Contiguity checks DG_HOST_ASSERT(a.first.is_contiguous());