diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh index cdd28fc..2c24c5e 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh @@ -167,7 +167,8 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { const cute::TmaDescriptor* current_tensor_map_a = &tensor_map_a_base; const cute::TmaDescriptor* current_tensor_map_b = &tensor_map_b_base; - uint32_t last_group_idx = kNumGroups, sum_k = 0; + uint32_t last_group_idx = kNumGroups; + uint32_t prefetched_next_group_idx = kNumGroups; // Track which group was prefetched // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { @@ -187,16 +188,45 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, const uint32_t& next_stage_idx = stage_idx ^ 1; last_group_idx = scheduler.current_group_idx; - // Prepare next tensor map - sum_k += scheduler.current_shape_k; + // Check if the current group matches the prefetched group + // If not, we need to prepare the correct tensor map for the current group + if (scheduler.current_num_valid_groups > 0 && + scheduler.current_group_idx != prefetched_next_group_idx) { + // The prefetched tensor map doesn't match current group + // This happens when block count is small (< num_SMs) and scheduler skips groups + // Need to prepare the correct tensor map for current group + // Use scheduler.current_k_cumsum which correctly tracks k offset even when groups are skipped + const uint64_t current_k_offset = scheduler.current_k_cumsum; + tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[stage_idx], + gmem_a_ptr + current_k_offset * shape_m); + tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[stage_idx], + gmem_b_ptr + current_k_offset * shape_n); + tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a[stage_idx], + scheduler.current_shape_k, scheduler.current_shape_k); + tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b[stage_idx], + scheduler.current_shape_k, scheduler.current_shape_k); + *(gmem_tensor_map_a[stage_idx]) = *(smem_tensor_map_a[stage_idx]); + *(gmem_tensor_map_b[stage_idx]) = *(smem_tensor_map_b[stage_idx]); + // NOTE: Don't call tensor_map_release_cta() here! + // We're preparing the current tensor map, not the next one. + // It will be acquired immediately in the "Get current tensor map" section below. + } + + // Prepare next tensor map (prefetch for next group) if (scheduler.next_group_idx < kNumGroups) { - tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[next_stage_idx], gmem_a_ptr + static_cast(sum_k) * shape_m); - tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[next_stage_idx], gmem_b_ptr + static_cast(sum_k) * shape_n); + // Calculate next group's k offset using scheduler-provided information + // This ensures consistency even when groups are skipped + const uint64_t next_k_offset = static_cast(scheduler.current_k_cumsum) + scheduler.current_shape_k; + tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[next_stage_idx], gmem_a_ptr + next_k_offset * shape_m); + tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[next_stage_idx], gmem_b_ptr + next_k_offset * shape_n); tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k); tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k); *(gmem_tensor_map_a[next_stage_idx]) = *(smem_tensor_map_a[next_stage_idx]); *(gmem_tensor_map_b[next_stage_idx]) = *(smem_tensor_map_b[next_stage_idx]); tensor_map_release_cta(); + prefetched_next_group_idx = scheduler.next_group_idx; // Record which group was prefetched + } else { + prefetched_next_group_idx = kNumGroups; // No more groups to prefetch } // Get current tensor map