diff --git a/dsv4/layers/moe.py b/dsv4/layers/moe.py index 3d1b33fb..1744e767 100644 --- a/dsv4/layers/moe.py +++ b/dsv4/layers/moe.py @@ -90,6 +90,7 @@ class Nvfp4MoE: self._padded_x_sf_buf_l2 = None self._l1_gsa_buf = None self._l2_gsa_buf = None + self._l1_out_buf = None # pre-allocated L1 GEMM output for graph capture self._output_buf = None self._row_indices_buf = None self._padded_hidden_buf = None @@ -164,6 +165,13 @@ class Nvfp4MoE: self._l1_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device) self._l2_gsa_buf = torch.zeros(self.num_experts, dtype=torch.float32, device=self.device) + # Pre-allocated L1 GEMM output — avoids torch.zeros() in run_fused_swiglu_grouped_gemm + # Shape: (max_tokens * top_k, intermediate_size) — max possible L1 output + self._l1_out_buf = torch.zeros( + self.max_num_tokens * self.top_k, self.intermediate_size, + dtype=torch.bfloat16, device=self.device + ) + # Pre-allocated tokens-per-expert buffer — replaces torch.bincount # (bincount produces data-dependent shapes, breaks CUDA graph capture) self._tokens_per_expert_buf = torch.zeros(self.num_experts, dtype=torch.int32, device=self.device) @@ -648,6 +656,7 @@ class Nvfp4MoE: expert_offsets=padded_expert_offsets[1:self.num_experts + 1], global_scale_a=l1_gsa, global_scale_b=self._l1_gsb, swiglu_limit=self._swiglu_limit if self._swiglu_limit is not None else 0.0, + out=self._l1_out_buf, ) l1_out_real = l1_out[padded_dst] # Fused deinterleave + amax + quantize: zero CPU syncs. diff --git a/dsv4/layers/shared_expert.py b/dsv4/layers/shared_expert.py index cb936673..4d32a548 100644 --- a/dsv4/layers/shared_expert.py +++ b/dsv4/layers/shared_expert.py @@ -91,6 +91,9 @@ class Nvfp4SharedExpert: self._l1_activation_global_scale = 1.0 / (6.0 * 448.0) self._l2_activation_global_scale = 1.0 / (6.0 * 448.0) + # Pre-allocated L1 GEMM output for graph capture + self._l1_out_buf = None + # Pre-allocated cudagraph buffers (set in _allocate_buffers) self._padded_x_fp4_buf_l1 = None self._padded_x_sf_buf_l1 = None @@ -179,6 +182,12 @@ class Nvfp4SharedExpert: # Global scale buffers self._l1_gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device) self._l2_gsa_buf = torch.zeros(1, dtype=torch.float32, device=self.device) + + # Pre-allocated L1 output buffer for graph capture + self._l1_out_buf = torch.zeros( + max_rows, self.intermediate_size, + dtype=torch.bfloat16, device=self.device + ) # Expert offsets for num_groups=1: just [num_tokens_padded] # The GEMM expects expert_offsets as (num_experts,) cumulative offsets @@ -290,6 +299,7 @@ class Nvfp4SharedExpert: global_scale_a=gsa, global_scale_b=self._l1_gsb, swiglu_limit=self.swiglu_limit if self.swiglu_limit is not None else 0.0, + out=self._l1_out_buf, ) l1_out_real = l1_out[:num_tokens] # (num_tokens, 2*intermediate) BF16, interleaved [silu(gate), silu(gate)*up] # Deinterleave to separate gate and up, then take up half (SwiGLU result)