CUDA graph: Pre-allocate L1 GEMM output buffers in MoE and SharedExpert

Pass out= parameter to run_fused_swiglu_grouped_gemm to avoid per-step
torch.zeros() allocation during CUDA graph capture.
This commit is contained in:
2026-06-03 23:17:43 +00:00
parent 56b816a54f
commit a468f72a0e
2 changed files with 19 additions and 0 deletions

View File

@@ -90,6 +90,7 @@ class Nvfp4MoE:
self._padded_x_sf_buf_l2 = None
self._l1_gsa_buf = None
self._l2_gsa_buf = None
self._l1_out_buf = None # pre-allocated L1 GEMM output for graph capture
self._output_buf = None
self._row_indices_buf = None
self._padded_hidden_buf = None
@@ -164,6 +165,13 @@ class Nvfp4MoE:
self._l1_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device)
self._l2_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device)
# Pre-allocated L1 GEMM output — avoids torch.zeros() in run_fused_swiglu_grouped_gemm
# Shape: (max_tokens * top_k, intermediate_size) — max possible L1 output
self._l1_out_buf = torch.zeros(
self.max_num_tokens * self.top_k, self.intermediate_size,
dtype=torch.bfloat16, device=self.device
)
# Pre-allocated tokens-per-expert buffer — replaces torch.bincount
# (bincount produces data-dependent shapes, breaks CUDA graph capture)
self._tokens_per_expert_buf = torch.zeros(self.num_experts, dtype=torch.int32, device=self.device)
@@ -648,6 +656,7 @@ class Nvfp4MoE:
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
swiglu_limit=self._swiglu_limit if self._swiglu_limit is not None else 0.0,
out=self._l1_out_buf,
)
l1_out_real = l1_out[padded_dst]
# Fused deinterleave + amax + quantize: zero CPU syncs.

View File

@@ -91,6 +91,9 @@ class Nvfp4SharedExpert:
self._l1_activation_global_scale = 1.0 / (6.0 * 448.0)
self._l2_activation_global_scale = 1.0 / (6.0 * 448.0)
# Pre-allocated L1 GEMM output for graph capture
self._l1_out_buf = None
# Pre-allocated cudagraph buffers (set in _allocate_buffers)
self._padded_x_fp4_buf_l1 = None
self._padded_x_sf_buf_l1 = None
@@ -179,6 +182,12 @@ class Nvfp4SharedExpert:
# Global scale buffers
self._l1_gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device)
self._l2_gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device)
# Pre-allocated L1 output buffer for graph capture
self._l1_out_buf = torch.zeros(
max_rows, self.intermediate_size,
dtype=torch.bfloat16, device=self.device
)
# Expert offsets for num_groups=1: just [num_tokens_padded]
# The GEMM expects expert_offsets as (num_experts,) cumulative offsets
@@ -290,6 +299,7 @@ class Nvfp4SharedExpert:
global_scale_a=gsa,
global_scale_b=self._l1_gsb,
swiglu_limit=self.swiglu_limit if self.swiglu_limit is not None else 0.0,
out=self._l1_out_buf,
)
l1_out_real = l1_out[:num_tokens] # (num_tokens, 2*intermediate) BF16, interleaved [silu(gate), silu(gate)*up]
# Deinterleave to separate gate and up, then take up half (SwiGLU result)