diff --git a/csrc/apis/gemm.hpp b/csrc/apis/gemm.hpp index 076179c..bf1f758 100644 --- a/csrc/apis/gemm.hpp +++ b/csrc/apis/gemm.hpp @@ -280,13 +280,13 @@ static void k_grouped_fp8_gemm_nt_contiguous(const std::pair(d); - const auto& sum_mk = static_cast(a.first.numel()); - const auto& sum_nk = static_cast(b.first.numel()); - uint64_t sum_k = 0; + const auto& sum_mk = a.first.numel(); + const auto& sum_nk = b.first.numel(); + int sum_k = 0; for (const auto& k: ks) - sum_k += static_cast(k); - DG_HOST_ASSERT(sum_mk == static_cast(m) * sum_k); - DG_HOST_ASSERT(sum_nk == static_cast(n) * sum_k); + sum_k += k; + DG_HOST_ASSERT(sum_mk == static_cast(m) * sum_k); + DG_HOST_ASSERT(sum_nk == static_cast(n) * sum_k); // Contiguity checks DG_HOST_ASSERT(a.first.is_contiguous());