CUDA graph: Fix sync violations found by B200 detector

Fixes from running Section A detector on B200:

1. single_shot_inference.py: Use pinned CPU buffers for token/position transfer
   - dec_tid_buf[0] = python_int causes CPU→GPU sync
   - Fixed: write to pinned CPU buffer, then copy_ to the GPU buffer (graph-capturable;
     asynchronous w.r.t. the host only when copy_ is called with non_blocking=True)

2. grouped_linear.py: Fix expert_offsets Python loop
   - expert_offsets[g] = python_int * padded_rows → CPU→GPU sync per iteration
   - Fixed: element-wise multiply with pre-allocated range tensor (GPU-only)

3. grouped_linear.py: Vectorized output extraction for T=1 decode
   - Python loop z[:, g, :] = out[...] → CPU sync for each slice
   - Fixed: GPU gather with pre-computed indices for T=1

4. grouped_linear.py: Pre-allocate output buffer
   - torch.empty() per call → allocation inside graph
   - Fixed: use self._output_buf (pre-allocated at max size)

5. grouped_linear.py: Pre-allocate expert_offsets_range_buf
   - torch.arange() per call → allocation inside graph
   - Fixed: compute once at init, reuse via element-wise multiply
This commit is contained in:
2026-06-03 16:52:19 +00:00
parent 46a3a51832
commit 0ca7bed0e1
3 changed files with 48 additions and 10 deletions

View File

@@ -212,6 +212,17 @@ class Nvfp4GroupedLinear:
self._gsa_buf = torch.zeros(self.n_local_groups, dtype=torch.float32, device=self.device)
self._expert_offsets_buf = torch.zeros(self.n_local_groups, dtype=torch.int32, device=self.device)
# Pre-computed range [1, 2, 3, ..., n_groups] for expert offsets
# Avoids torch.arange() per call (allocation) and Python loop (CPU→GPU sync)
self._expert_offsets_range_buf = torch.arange(
1, self.n_local_groups + 1, dtype=torch.int32, device=self.device
)
self._group_offset_buf = torch.zeros(self.n_local_groups, dtype=torch.int32, device=self.device)
# Pre-allocate output buffer for graph capture, shape (max_num_tokens, n_groups, o_rank);
# T=1 decode uses only the first row.
self._output_buf = torch.zeros(
self.max_num_tokens, self.n_local_groups, self.o_lora_rank,
dtype=torch.bfloat16, device=self.device
)
self._buffers_allocated = True
def _ensure_initialized(self):
@@ -321,6 +332,13 @@ class Nvfp4GroupedLinear:
x_fp4_grouped = x_fp4_flat.reshape(self.n_local_groups, num_tokens, self.group_in_features // 2)
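# Refresh the cached per-group destination offsets [0, padded_T, 2*padded_T, ...] when they are stale.
# NOTE(review): the scatter further down is still a Python loop over groups (index_put_ is not used),
# and the `if` below evaluates tensor elements on the host — presumably a device→host sync when
# the buffer lives on a CUDA device; confirm before relying on this path under graph capture.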
group_offsets = self._group_offset_buf[:self.n_local_groups]
if group_offsets[0] != 0 or (self.n_local_groups > 1 and group_offsets[1] != padded_rows_per_group):
# Update offsets (only when padded_rows changes, e.g., prefill vs decode)
group_offsets.copy_(torch.arange(self.n_local_groups, dtype=torch.int32, device=o.device) * padded_rows_per_group)
# Scatter each group's x_fp4 rows into the padded buffer (one slice assignment per group)
for g in range(self.n_local_groups):
offset = g * padded_rows_per_group
padded_x_fp4.view(torch.uint8)[offset:offset + num_tokens] = x_fp4_grouped[g].view(torch.uint8)
@@ -336,9 +354,10 @@ class Nvfp4GroupedLinear:
scale_a = assemble_scales_2d_side(all_x_sf)
# Expert offsets: cumulative [padded_T, 2*padded_T, ..., n_groups*padded_T]
# GPU-only computation — no Python loop, no CPU→GPU sync
expert_offsets = self._expert_offsets_buf
for g in range(self.n_local_groups):
expert_offsets[g] = (g + 1) * padded_rows_per_group
# element-wise multiply: range * padded_rows → GPU tensor (no host sync)
expert_offsets[:self.n_local_groups] = self._expert_offsets_range_buf * padded_rows_per_group
# Global scales — GPU-computed gsa already in _gsa_buf (no CPU sync)
gsa = self._gsa_buf
@@ -356,11 +375,18 @@ class Nvfp4GroupedLinear:
# Extract real outputs and reshape
# GEMM output has the same layout as mat_a: groups-first with padding
z = torch.empty(num_tokens, self.n_local_groups, self.o_lora_rank,
dtype=torch.bfloat16, device=o.device)
for g in range(self.n_local_groups):
offset = g * padded_rows_per_group
z[:, g, :] = out[offset:offset + num_tokens, :]
# For CUDA graph capture (T=1 decode): use vectorized GPU gather — no Python loop.
# For T>1 prefill: Python loop is OK (not graph-captured).
z = self._output_buf[:num_tokens] if hasattr(self, '_output_buf') and self._output_buf is not None else torch.empty(num_tokens, self.n_local_groups, self.o_lora_rank, dtype=torch.bfloat16, device=o.device)
if num_tokens == 1:
# Vectorized: gather_indices = [0, padded_T, 2*padded_T, ...] — GPU-only
gather_indices = self._expert_offsets_range_buf[:self.n_local_groups] * padded_rows_per_group - padded_rows_per_group
z_flat = out[gather_indices] # (n_groups, o_rank) — GPU gather
z[:, :, :] = z_flat.unsqueeze(0) # (1, n_groups, o_rank)
else:
for g in range(self.n_local_groups):
offset = g * padded_rows_per_group
z[:, g, :] = out[offset:offset + num_tokens, :]
return z

View File

@@ -1582,12 +1582,24 @@ def main():
# Pre-allocate embedding output buffer — embed() returns a new tensor each call.
# For graph capture, we'd copy into this buffer. For now, used as reference.
dec_embed_buf = torch.zeros(1, H, dtype=torch.bfloat16, device='cuda:0')
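# Pre-allocate pinned CPU buffers for token ID / position transfer (graph-capturable).
# Writing a Python int to a GPU tensor causes CPU→GPU sync. Instead:
# 1. Write to pinned CPU buffer (no sync)
# 2. copy_ to GPU buffer (graph-capturable)
# NOTE(review): copy_ is asynchronous w.r.t. the host only with non_blocking=True;
# the copy_ calls in the decode loop use the default blocking mode — confirm intent.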
dec_tid_pinned = torch.zeros(1, dtype=torch.long, device='cpu').pin_memory()
dec_tid32_pinned = torch.zeros(1, dtype=torch.int32, device='cpu').pin_memory()
dec_pos_pinned = torch.zeros(1, dtype=torch.long, device='cpu').pin_memory()
for step in range(MAX_NEW_TOKENS):
t1 = time.time()
dec_tid_buf[0] = all_tokens[-1]
dec_tid32_buf[0] = all_tokens[-1]
dec_pos_buf[0] = len(all_tokens) - 1
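# Write token/position to pinned CPU buffers, then copy them into the GPU buffers.
# This avoids the CPU→GPU sync from dec_tid_buf[0] = python_int.
# NOTE(review): copy_ without non_blocking=True is not asynchronous with respect to the host;
# pass non_blocking=True if an async host→device copy is intended.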
dec_tid_pinned[0] = all_tokens[-1]
dec_tid_buf.copy_(dec_tid_pinned)
dec_tid32_pinned[0] = all_tokens[-1]
dec_tid32_buf.copy_(dec_tid32_pinned)
dec_pos_pinned[0] = len(all_tokens) - 1
dec_pos_buf.copy_(dec_pos_pinned)
t_e = time.perf_counter()
X = mHCLayer.init_state(embed(dec_tid_buf), out_buf=dec_X_buf)