CUDA graph: Fix sync violations found by B200 detector
Fixes from running Section A detector on B200:

1. single_shot_inference.py: Use pinned CPU buffers for token/position transfer
   - dec_tid_buf[0] = python_int causes CPU→GPU sync
   - Fixed: write to pinned CPU buffer, then copy_ (async, graph-capturable)
2. grouped_linear.py: Fix expert_offsets Python loop
   - expert_offsets[g] = python_int * padded_rows → CPU→GPU sync per iteration
   - Fixed: element-wise multiply with pre-allocated range tensor (GPU-only)
3. grouped_linear.py: Vectorized output extraction for T=1 decode
   - Python loop z[:, g, :] = out[...] → CPU sync for each slice
   - Fixed: GPU gather with pre-computed indices for T=1
4. grouped_linear.py: Pre-allocate output buffer
   - torch.empty() per call → allocation inside graph
   - Fixed: use self._output_buf (pre-allocated at max size)
5. grouped_linear.py: Pre-allocate expert_offsets_range_buf
   - torch.arange() per call → allocation inside graph
   - Fixed: compute once at init, reuse via element-wise multiply
This commit is contained in:
@@ -212,6 +212,17 @@ class Nvfp4GroupedLinear:
|
||||
|
||||
self._gsa_buf = torch.zeros(self.n_local_groups, dtype=torch.float32, device=self.device)
|
||||
self._expert_offsets_buf = torch.zeros(self.n_local_groups, dtype=torch.int32, device=self.device)
|
||||
# Pre-computed range [1, 2, 3, ..., n_groups] for expert offsets
|
||||
# Avoids torch.arange() per call (allocation) and Python loop (CPU→GPU sync)
|
||||
self._expert_offsets_range_buf = torch.arange(
|
||||
1, self.n_local_groups + 1, dtype=torch.int32, device=self.device
|
||||
)
|
||||
self._group_offset_buf = torch.zeros(self.n_local_groups, dtype=torch.int32, device=self.device)
|
||||
# Pre-allocate output buffer for graph capture, shaped (max_num_tokens, n_groups, o_rank); T=1 decode uses only the first row
|
||||
self._output_buf = torch.zeros(
|
||||
self.max_num_tokens, self.n_local_groups, self.o_lora_rank,
|
||||
dtype=torch.bfloat16, device=self.device
|
||||
)
|
||||
self._buffers_allocated = True
|
||||
|
||||
def _ensure_initialized(self):
|
||||
@@ -321,6 +332,13 @@ class Nvfp4GroupedLinear:
|
||||
|
||||
x_fp4_grouped = x_fp4_flat.reshape(self.n_local_groups, num_tokens, self.group_in_features // 2)
|
||||
|
||||
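# Destination offsets for the scatter. NOTE(review): the tensor comparison in the `if` below is evaluated on the host, so it likely forces a GPU→CPU sync — confirm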
|
||||
# Build destination offsets and use index_put_ with pre-allocated indices
|
||||
group_offsets = self._group_offset_buf[:self.n_local_groups]
|
||||
if group_offsets[0] != 0 or (self.n_local_groups > 1 and group_offsets[1] != padded_rows_per_group):
|
||||
# Update offsets (only when padded_rows changes, e.g., prefill vs decode)
|
||||
group_offsets.copy_(torch.arange(self.n_local_groups, dtype=torch.int32, device=o.device) * padded_rows_per_group)
|
||||
# Scatter each group's x_fp4 into the padded buffer (one slice copy per group; offsets are Python ints)
|
||||
for g in range(self.n_local_groups):
|
||||
offset = g * padded_rows_per_group
|
||||
padded_x_fp4.view(torch.uint8)[offset:offset + num_tokens] = x_fp4_grouped[g].view(torch.uint8)
|
||||
@@ -336,9 +354,10 @@ class Nvfp4GroupedLinear:
|
||||
scale_a = assemble_scales_2d_side(all_x_sf)
|
||||
|
||||
# Expert offsets: cumulative [padded_T, 2*padded_T, ..., n_groups*padded_T]
|
||||
# GPU-only computation — no Python loop, no CPU→GPU sync
|
||||
expert_offsets = self._expert_offsets_buf
|
||||
for g in range(self.n_local_groups):
|
||||
expert_offsets[g] = (g + 1) * padded_rows_per_group
|
||||
# element-wise multiply: range * padded_rows → GPU tensor (no host sync)
|
||||
expert_offsets[:self.n_local_groups] = self._expert_offsets_range_buf * padded_rows_per_group
|
||||
|
||||
# Global scales — GPU-computed gsa already in _gsa_buf (no CPU sync)
|
||||
gsa = self._gsa_buf
|
||||
@@ -356,11 +375,18 @@ class Nvfp4GroupedLinear:
|
||||
|
||||
# Extract real outputs and reshape
|
||||
# GEMM output has the same layout as mat_a: groups-first with padding
|
||||
z = torch.empty(num_tokens, self.n_local_groups, self.o_lora_rank,
|
||||
dtype=torch.bfloat16, device=o.device)
|
||||
for g in range(self.n_local_groups):
|
||||
offset = g * padded_rows_per_group
|
||||
z[:, g, :] = out[offset:offset + num_tokens, :]
|
||||
# For CUDA graph capture (T=1 decode): use vectorized GPU gather — no Python loop.
|
||||
# For T>1 prefill: Python loop is OK (not graph-captured).
|
||||
z = self._output_buf[:num_tokens] if hasattr(self, '_output_buf') and self._output_buf is not None else torch.empty(num_tokens, self.n_local_groups, self.o_lora_rank, dtype=torch.bfloat16, device=o.device)
|
||||
if num_tokens == 1:
|
||||
# Vectorized: gather_indices = [0, padded_T, 2*padded_T, ...] — GPU-only
|
||||
gather_indices = self._expert_offsets_range_buf[:self.n_local_groups] * padded_rows_per_group - padded_rows_per_group
|
||||
z_flat = out[gather_indices] # (n_groups, o_rank) — GPU gather
|
||||
z[:, :, :] = z_flat.unsqueeze(0) # (1, n_groups, o_rank)
|
||||
else:
|
||||
for g in range(self.n_local_groups):
|
||||
offset = g * padded_rows_per_group
|
||||
z[:, g, :] = out[offset:offset + num_tokens, :]
|
||||
|
||||
return z
|
||||
|
||||
|
||||
@@ -1582,12 +1582,24 @@ def main():
|
||||
# Pre-allocate embedding output buffer — embed() returns a new tensor each call.
|
||||
# For graph capture, we'd copy into this buffer. For now, used as reference.
|
||||
dec_embed_buf = torch.zeros(1, H, dtype=torch.bfloat16, device='cuda:0')
|
||||
# Pre-allocate pinned CPU buffer for token ID transfer (graph-capturable)
|
||||
# Writing a Python int to a GPU tensor causes CPU→GPU sync. Instead:
|
||||
# 1. Write to pinned CPU buffer (no sync)
|
||||
# 2. copy_ to GPU buffer (graph-capturable; asynchronous w.r.t. the host only if non_blocking=True is passed)
|
||||
dec_tid_pinned = torch.zeros(1, dtype=torch.long, device='cpu').pin_memory()
|
||||
dec_tid32_pinned = torch.zeros(1, dtype=torch.int32, device='cpu').pin_memory()
|
||||
dec_pos_pinned = torch.zeros(1, dtype=torch.long, device='cpu').pin_memory()
|
||||
|
||||
for step in range(MAX_NEW_TOKENS):
|
||||
t1 = time.time()
|
||||
dec_tid_buf[0] = all_tokens[-1]
|
||||
dec_tid32_buf[0] = all_tokens[-1]
|
||||
dec_pos_buf[0] = len(all_tokens) - 1
|
||||
# Write token/position to pinned CPU buffers, then async copy to GPU
|
||||
# This avoids the CPU→GPU sync from dec_tid_buf[0] = python_int
|
||||
dec_tid_pinned[0] = all_tokens[-1]
|
||||
dec_tid_buf.copy_(dec_tid_pinned)
|
||||
dec_tid32_pinned[0] = all_tokens[-1]
|
||||
dec_tid32_buf.copy_(dec_tid32_pinned)
|
||||
dec_pos_pinned[0] = len(all_tokens) - 1
|
||||
dec_pos_buf.copy_(dec_pos_pinned)
|
||||
|
||||
t_e = time.perf_counter()
|
||||
X = mHCLayer.init_state(embed(dec_tid_buf), out_buf=dec_X_buf)
|
||||
|
||||
Reference in New Issue
Block a user