fix: k_grouped_fp8_gemm_nt_contiguous crashes with n = 768 on H100 (#238)
This commit is contained in:
@@ -167,7 +167,8 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
|
||||
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
|
||||
const cute::TmaDescriptor* current_tensor_map_a = &tensor_map_a_base;
|
||||
const cute::TmaDescriptor* current_tensor_map_b = &tensor_map_b_base;
|
||||
uint32_t last_group_idx = kNumGroups, sum_k = 0;
|
||||
uint32_t last_group_idx = kNumGroups;
|
||||
uint32_t prefetched_next_group_idx = kNumGroups; // Track which group was prefetched
|
||||
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
@@ -187,16 +188,45 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
|
||||
const uint32_t& next_stage_idx = stage_idx ^ 1;
|
||||
last_group_idx = scheduler.current_group_idx;
|
||||
|
||||
// Prepare next tensor map
|
||||
sum_k += scheduler.current_shape_k;
|
||||
// Check if the current group matches the prefetched group
|
||||
// If not, we need to prepare the correct tensor map for the current group
|
||||
if (scheduler.current_num_valid_groups > 0 &&
|
||||
scheduler.current_group_idx != prefetched_next_group_idx) {
|
||||
// The prefetched tensor map doesn't match current group
|
||||
// This happens when block count is small (< num_SMs) and scheduler skips groups
|
||||
// Need to prepare the correct tensor map for current group
|
||||
// Use scheduler.current_k_cumsum which correctly tracks k offset even when groups are skipped
|
||||
const uint64_t current_k_offset = scheduler.current_k_cumsum;
|
||||
tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[stage_idx],
|
||||
gmem_a_ptr + current_k_offset * shape_m);
|
||||
tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[stage_idx],
|
||||
gmem_b_ptr + current_k_offset * shape_n);
|
||||
tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a[stage_idx],
|
||||
scheduler.current_shape_k, scheduler.current_shape_k);
|
||||
tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b[stage_idx],
|
||||
scheduler.current_shape_k, scheduler.current_shape_k);
|
||||
*(gmem_tensor_map_a[stage_idx]) = *(smem_tensor_map_a[stage_idx]);
|
||||
*(gmem_tensor_map_b[stage_idx]) = *(smem_tensor_map_b[stage_idx]);
|
||||
// NOTE: Don't call tensor_map_release_cta() here!
|
||||
// We're preparing the current tensor map, not the next one.
|
||||
// It will be acquired immediately in the "Get current tensor map" section below.
|
||||
}
|
||||
|
||||
// Prepare next tensor map (prefetch for next group)
|
||||
if (scheduler.next_group_idx < kNumGroups) {
|
||||
tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[next_stage_idx], gmem_a_ptr + static_cast<uint64_t>(sum_k) * shape_m);
|
||||
tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[next_stage_idx], gmem_b_ptr + static_cast<uint64_t>(sum_k) * shape_n);
|
||||
// Calculate next group's k offset using scheduler-provided information
|
||||
// This ensures consistency even when groups are skipped
|
||||
const uint64_t next_k_offset = static_cast<uint64_t>(scheduler.current_k_cumsum) + scheduler.current_shape_k;
|
||||
tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[next_stage_idx], gmem_a_ptr + next_k_offset * shape_m);
|
||||
tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[next_stage_idx], gmem_b_ptr + next_k_offset * shape_n);
|
||||
tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k);
|
||||
tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k);
|
||||
*(gmem_tensor_map_a[next_stage_idx]) = *(smem_tensor_map_a[next_stage_idx]);
|
||||
*(gmem_tensor_map_b[next_stage_idx]) = *(smem_tensor_map_b[next_stage_idx]);
|
||||
tensor_map_release_cta();
|
||||
prefetched_next_group_idx = scheduler.next_group_idx; // Record which group was prefetched
|
||||
} else {
|
||||
prefetched_next_group_idx = kNumGroups; // No more groups to prefetch
|
||||
}
|
||||
|
||||
// Get current tensor map
|
||||
|
||||
Reference in New Issue
Block a user