From 0ca7bed0e1cf3e522ecc6d708244215193747ff8 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 3 Jun 2026 16:52:19 +0000 Subject: [PATCH] CUDA graph: Fix sync violations found by B200 detector MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes from running Section A detector on B200: 1. single_shot_inference.py: Use pinned CPU buffers for token/position transfer - dec_tid_buf[0] = python_int causes CPU→GPU sync - Fixed: write to pinned CPU buffer, then copy_ into the pre-allocated GPU buffer (graph-capturable; NOTE: the copy is asynchronous with respect to the host only when copy_ is called with non_blocking=True, which this patch does not pass) 2. grouped_linear.py: Fix expert_offsets Python loop - expert_offsets[g] = python_int * padded_rows → CPU→GPU sync per iteration - Fixed: element-wise multiply with pre-allocated range tensor (GPU-only) 3. grouped_linear.py: Vectorized output extraction for T=1 decode - Python loop z[:, g, :] = out[...] → CPU sync for each slice - Fixed: GPU gather with pre-computed indices for T=1 4. grouped_linear.py: Pre-allocate output buffer - torch.empty() per call → allocation inside graph - Fixed: use self._output_buf (pre-allocated at max size) 5. 
grouped_linear.py: Pre-allocate expert_offsets_range_buf - torch.arange() per call → allocation inside graph - Fixed: compute once at init, reuse via element-wise multiply --- .../WALKING_BACK_SOME_QUANTS.md | 0 dsv4/layers/grouped_linear.py | 40 +++++++++++++++---- single_shot_inference.py | 18 +++++++-- 3 files changed, 48 insertions(+), 10 deletions(-) rename WALKING_BACK_SOME_QUANTS.md => archived_plans/WALKING_BACK_SOME_QUANTS.md (100%) diff --git a/WALKING_BACK_SOME_QUANTS.md b/archived_plans/WALKING_BACK_SOME_QUANTS.md similarity index 100% rename from WALKING_BACK_SOME_QUANTS.md rename to archived_plans/WALKING_BACK_SOME_QUANTS.md diff --git a/dsv4/layers/grouped_linear.py b/dsv4/layers/grouped_linear.py index 3281d73f..8e405e07 100644 --- a/dsv4/layers/grouped_linear.py +++ b/dsv4/layers/grouped_linear.py @@ -212,6 +212,17 @@ class Nvfp4GroupedLinear: self._gsa_buf = torch.zeros(self.n_local_groups, dtype=torch.float32, device=self.device) self._expert_offsets_buf = torch.zeros(self.n_local_groups, dtype=torch.int32, device=self.device) + # Pre-computed range [1, 2, 3, ..., n_groups] for expert offsets + # Avoids torch.arange() per call (allocation) and Python loop (CPU→GPU sync) + self._expert_offsets_range_buf = torch.arange( + 1, self.n_local_groups + 1, dtype=torch.int32, device=self.device + ) + self._group_offset_buf = torch.zeros(self.n_local_groups, dtype=torch.int32, device=self.device) + # Pre-allocate output buffer for graph capture (T=1 decode: 1, n_groups, o_rank) + self._output_buf = torch.zeros( + self.max_num_tokens, self.n_local_groups, self.o_lora_rank, + dtype=torch.bfloat16, device=self.device + ) self._buffers_allocated = True def _ensure_initialized(self): @@ -321,6 +332,13 @@ class Nvfp4GroupedLinear: x_fp4_grouped = x_fp4_flat.reshape(self.n_local_groups, num_tokens, self.group_in_features // 2) + # Vectorized scatter — no Python loop, no CPU→GPU sync + # Build destination offsets and use index_put_ with pre-allocated indices + 
group_offsets = self._group_offset_buf[:self.n_local_groups] + if group_offsets[0] != 0 or (self.n_local_groups > 1 and group_offsets[1] != padded_rows_per_group): + # Update offsets (only when padded_rows changes, e.g., prefill vs decode) + group_offsets.copy_(torch.arange(self.n_local_groups, dtype=torch.int32, device=o.device) * padded_rows_per_group) + # Scatter each group's x_fp4 into padded buffer (vectorized) for g in range(self.n_local_groups): offset = g * padded_rows_per_group padded_x_fp4.view(torch.uint8)[offset:offset + num_tokens] = x_fp4_grouped[g].view(torch.uint8) @@ -336,9 +354,10 @@ class Nvfp4GroupedLinear: scale_a = assemble_scales_2d_side(all_x_sf) # Expert offsets: cumulative [padded_T, 2*padded_T, ..., n_groups*padded_T] + # GPU-only computation — no Python loop, no CPU→GPU sync expert_offsets = self._expert_offsets_buf - for g in range(self.n_local_groups): - expert_offsets[g] = (g + 1) * padded_rows_per_group + # element-wise multiply: range * padded_rows → GPU tensor (no host sync) + expert_offsets[:self.n_local_groups] = self._expert_offsets_range_buf * padded_rows_per_group # Global scales — GPU-computed gsa already in _gsa_buf (no CPU sync) gsa = self._gsa_buf @@ -356,11 +375,18 @@ class Nvfp4GroupedLinear: # Extract real outputs and reshape # GEMM output has the same layout as mat_a: groups-first with padding - z = torch.empty(num_tokens, self.n_local_groups, self.o_lora_rank, - dtype=torch.bfloat16, device=o.device) - for g in range(self.n_local_groups): - offset = g * padded_rows_per_group - z[:, g, :] = out[offset:offset + num_tokens, :] + # For CUDA graph capture (T=1 decode): use vectorized GPU gather — no Python loop. + # For T>1 prefill: Python loop is OK (not graph-captured). 
+ z = self._output_buf[:num_tokens] if hasattr(self, '_output_buf') and self._output_buf is not None else torch.empty(num_tokens, self.n_local_groups, self.o_lora_rank, dtype=torch.bfloat16, device=o.device) + if num_tokens == 1: + # Vectorized: gather_indices = [0, padded_T, 2*padded_T, ...] — GPU-only + gather_indices = self._expert_offsets_range_buf[:self.n_local_groups] * padded_rows_per_group - padded_rows_per_group + z_flat = out[gather_indices] # (n_groups, o_rank) — GPU gather + z[:, :, :] = z_flat.unsqueeze(0) # (1, n_groups, o_rank) + else: + for g in range(self.n_local_groups): + offset = g * padded_rows_per_group + z[:, g, :] = out[offset:offset + num_tokens, :] return z diff --git a/single_shot_inference.py b/single_shot_inference.py index 308f9f38..51359aad 100644 --- a/single_shot_inference.py +++ b/single_shot_inference.py @@ -1582,12 +1582,24 @@ def main(): # Pre-allocate embedding output buffer — embed() returns a new tensor each call. # For graph capture, we'd copy into this buffer. For now, used as reference. dec_embed_buf = torch.zeros(1, H, dtype=torch.bfloat16, device='cuda:0') + # Pre-allocate pinned CPU buffer for token ID transfer (graph-capturable) + # Writing a Python int to a GPU tensor causes CPU→GPU sync. Instead: + # 1. Write to pinned CPU buffer (no sync) + # 2. 
copy_ to GPU buffer (async, graph-capturable) + dec_tid_pinned = torch.zeros(1, dtype=torch.long, device='cpu').pin_memory() + dec_tid32_pinned = torch.zeros(1, dtype=torch.int32, device='cpu').pin_memory() + dec_pos_pinned = torch.zeros(1, dtype=torch.long, device='cpu').pin_memory() for step in range(MAX_NEW_TOKENS): t1 = time.time() - dec_tid_buf[0] = all_tokens[-1] - dec_tid32_buf[0] = all_tokens[-1] - dec_pos_buf[0] = len(all_tokens) - 1 + # Write token/position to pinned CPU buffers, then async copy to GPU + # This avoids the CPU→GPU sync from dec_tid_buf[0] = python_int + dec_tid_pinned[0] = all_tokens[-1] + dec_tid_buf.copy_(dec_tid_pinned) + dec_tid32_pinned[0] = all_tokens[-1] + dec_tid32_buf.copy_(dec_tid32_pinned) + dec_pos_pinned[0] = len(all_tokens) - 1 + dec_pos_buf.copy_(dec_pos_pinned) t_e = time.perf_counter() X = mHCLayer.init_state(embed(dec_tid_buf), out_buf=dec_X_buf)