fix: prevent int32 overflow in k-grouped GEMM size calculations (#226)

This commit is contained in:
Guoteng
2025-11-19 10:52:08 +08:00
committed by GitHub
parent ec5e9ed0b8
commit f63d7f24d6

View File

@@ -280,13 +280,13 @@ static void k_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torc
 // Shape checks
 const auto& [num_groups, m, n] = get_shape<3>(d);
-const auto& sum_mk = a.first.numel();
-const auto& sum_nk = b.first.numel();
-int sum_k = 0;
+const auto& sum_mk = static_cast<uint64_t>(a.first.numel());
+const auto& sum_nk = static_cast<uint64_t>(b.first.numel());
+uint64_t sum_k = 0;
 for (const auto& k: ks)
-sum_k += k;
-DG_HOST_ASSERT(sum_mk == m * sum_k);
-DG_HOST_ASSERT(sum_nk == n * sum_k);
+sum_k += static_cast<uint64_t>(k);
+DG_HOST_ASSERT(sum_mk == static_cast<uint64_t>(m) * sum_k);
+DG_HOST_ASSERT(sum_nk == static_cast<uint64_t>(n) * sum_k);
 // Contiguity checks
 DG_HOST_ASSERT(a.first.is_contiguous());