CUDA graph: Pre-allocate L1 GEMM output buffers in MoE and SharedExpert
Pass out= parameter to run_fused_swiglu_grouped_gemm to avoid per-step torch.zeros() allocation during CUDA graph capture.
This commit is contained in:
@@ -90,6 +90,7 @@ class Nvfp4MoE:
|
||||
self._padded_x_sf_buf_l2 = None
|
||||
self._l1_gsa_buf = None
|
||||
self._l2_gsa_buf = None
|
||||
self._l1_out_buf = None # pre-allocated L1 GEMM output for graph capture
|
||||
self._output_buf = None
|
||||
self._row_indices_buf = None
|
||||
self._padded_hidden_buf = None
|
||||
@@ -164,6 +165,13 @@ class Nvfp4MoE:
|
||||
self._l1_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device)
|
||||
self._l2_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device)
|
||||
|
||||
# Pre-allocated L1 GEMM output — avoids torch.zeros() in run_fused_swiglu_grouped_gemm
|
||||
# Shape: (max_tokens * top_k, intermediate_size) — max possible L1 output
|
||||
self._l1_out_buf = torch.zeros(
|
||||
self.max_num_tokens * self.top_k, self.intermediate_size,
|
||||
dtype=torch.bfloat16, device=self.device
|
||||
)
|
||||
|
||||
# Pre-allocated tokens-per-expert buffer — replaces torch.bincount
|
||||
# (bincount produces data-dependent shapes, breaks CUDA graph capture)
|
||||
self._tokens_per_expert_buf = torch.zeros(self.num_experts, dtype=torch.int32, device=self.device)
|
||||
@@ -648,6 +656,7 @@ class Nvfp4MoE:
|
||||
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
|
||||
global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
|
||||
swiglu_limit=self._swiglu_limit if self._swiglu_limit is not None else 0.0,
|
||||
out=self._l1_out_buf,
|
||||
)
|
||||
l1_out_real = l1_out[padded_dst]
|
||||
# Fused deinterleave + amax + quantize: zero CPU syncs.
|
||||
|
||||
@@ -91,6 +91,9 @@ class Nvfp4SharedExpert:
|
||||
self._l1_activation_global_scale = 1.0 / (6.0 * 448.0)
|
||||
self._l2_activation_global_scale = 1.0 / (6.0 * 448.0)
|
||||
|
||||
# Pre-allocated L1 GEMM output for graph capture
|
||||
self._l1_out_buf = None
|
||||
|
||||
# Pre-allocated cudagraph buffers (set in _allocate_buffers)
|
||||
self._padded_x_fp4_buf_l1 = None
|
||||
self._padded_x_sf_buf_l1 = None
|
||||
@@ -179,6 +182,12 @@ class Nvfp4SharedExpert:
|
||||
# Global scale buffers
|
||||
self._l1_gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device)
|
||||
self._l2_gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device)
|
||||
|
||||
# Pre-allocated L1 output buffer for graph capture
|
||||
self._l1_out_buf = torch.zeros(
|
||||
max_rows, self.intermediate_size,
|
||||
dtype=torch.bfloat16, device=self.device
|
||||
)
|
||||
|
||||
# Expert offsets for num_groups=1: just [num_tokens_padded]
|
||||
# The GEMM expects expert_offsets as (num_experts,) cumulative offsets
|
||||
@@ -290,6 +299,7 @@ class Nvfp4SharedExpert:
|
||||
global_scale_a=gsa,
|
||||
global_scale_b=self._l1_gsb,
|
||||
swiglu_limit=self.swiglu_limit if self.swiglu_limit is not None else 0.0,
|
||||
out=self._l1_out_buf,
|
||||
)
|
||||
l1_out_real = l1_out[:num_tokens] # (num_tokens, 2*intermediate) BF16, interleaved [silu(gate), silu(gate)*up]
|
||||
# Deinterleave to separate gate and up, then take up half (SwiGLU result)
|
||||
|
||||
Reference in New Issue
Block a user