Clean up
This commit is contained in:
@@ -280,13 +280,13 @@ static void k_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torc
|
||||
|
||||
// Shape checks
|
||||
const auto& [num_groups, m, n] = get_shape<3>(d);
|
||||
const auto& sum_mk = static_cast<uint64_t>(a.first.numel());
|
||||
const auto& sum_nk = static_cast<uint64_t>(b.first.numel());
|
||||
uint64_t sum_k = 0;
|
||||
const auto& sum_mk = a.first.numel();
|
||||
const auto& sum_nk = b.first.numel();
|
||||
int sum_k = 0;
|
||||
for (const auto& k: ks)
|
||||
sum_k += static_cast<uint64_t>(k);
|
||||
DG_HOST_ASSERT(sum_mk == static_cast<uint64_t>(m) * sum_k);
|
||||
DG_HOST_ASSERT(sum_nk == static_cast<uint64_t>(n) * sum_k);
|
||||
sum_k += k;
|
||||
DG_HOST_ASSERT(sum_mk == static_cast<int64_t>(m) * sum_k);
|
||||
DG_HOST_ASSERT(sum_nk == static_cast<int64_t>(n) * sum_k);
|
||||
|
||||
// Contiguity checks
|
||||
DG_HOST_ASSERT(a.first.is_contiguous());
|
||||
|
||||
Reference in New Issue
Block a user