diff --git a/dsv4/layers/linear.py b/dsv4/layers/linear.py index 82f12aaa..ba3edc9f 100644 --- a/dsv4/layers/linear.py +++ b/dsv4/layers/linear.py @@ -65,6 +65,7 @@ class Nvfp4Linear: self._padded_x_fp4_buf = None self._expert_offsets_buf = None self._gsa_buf = None + self._gemm_out_buf = None # pre-allocated GEMM output for graph capture self._buffers_allocated = False def finalize_weights(self): @@ -134,6 +135,11 @@ class Nvfp4Linear: self._scale_a_buf = torch.zeros( max_padded_rows, max_padded_cols, dtype=torch.float16, device=self.device ).to(torch.float8_e4m3fn) + + # Pre-allocated GEMM output buffer for graph capture + self._gemm_out_buf = torch.zeros( + max_padded_rows, self.out_features, dtype=torch.bfloat16, device=self.device + ) def _ensure_initialized(self): if self._mat_b is None: @@ -241,6 +247,7 @@ class Nvfp4Linear: expert_offsets=expert_offsets, global_scale_a=gsa, global_scale_b=self._gsb, + out=self._gemm_out_buf, ) return out[:num_tokens] @@ -298,6 +305,7 @@ class Nvfp4Linear: expert_offsets=expert_offsets, global_scale_a=self._gsa_buf, global_scale_b=self._gsb, + out=self._gemm_out_buf, ) return out[:num_tokens] diff --git a/dsv4/layers/moe.py b/dsv4/layers/moe.py index 8743f40c..59c48369 100644 --- a/dsv4/layers/moe.py +++ b/dsv4/layers/moe.py @@ -171,6 +171,12 @@ class Nvfp4MoE: self.max_num_tokens * self.top_k, 2 * self.intermediate_size, dtype=torch.bfloat16, device=self.device ) + # Pre-allocated L2 GEMM output — avoids torch.zeros() in run_nvfp4_grouped_gemm + # Shape: (max_tokens * top_k, hidden_size) — down projection + self._l2_out_buf = torch.zeros( + self.max_num_tokens * self.top_k, self.hidden_size, + dtype=torch.bfloat16, device=self.device + ) # Pre-allocated tokens-per-expert buffer — replaces torch.bincount # (bincount produces data-dependent shapes, breaks CUDA graph capture) @@ -678,6 +684,7 @@ class Nvfp4MoE: scale_a=l1_scale_a, scale_b=self._l1_scale_b, expert_offsets=padded_expert_offsets[1:self.num_experts + 1], global_scale_a=l1_gsa, global_scale_b=self._l1_gsb, + out=self._l1_out_buf, ) l1_out_real = l1_out[padded_dst] l1_deil = deinterleave_l1_weights(l1_out_real.unsqueeze(0).contiguous())[0] @@ -715,6 +722,7 @@ class Nvfp4MoE: scale_a=l2_scale_a, scale_b=self._l2_scale_b, expert_offsets=padded_expert_offsets[1:self.num_experts + 1], global_scale_a=l2_gsa, global_scale_b=self._l2_gsb, + out=self._l2_out_buf, ) l2_out_real = l2_out[padded_dst] diff --git a/dsv4/layers/shared_expert.py b/dsv4/layers/shared_expert.py index 7a02e58c..c3d36ed3 100644 --- a/dsv4/layers/shared_expert.py +++ b/dsv4/layers/shared_expert.py @@ -189,6 +189,12 @@ class Nvfp4SharedExpert: max_rows, 2 * self.intermediate_size, dtype=torch.bfloat16, device=self.device ) + # Pre-allocated L2 output buffer for graph capture + # L2 produces hidden_size BF16 columns (down projection) + self._l2_out_buf = torch.zeros( + max_rows, self.hidden_size, + dtype=torch.bfloat16, device=self.device + ) # Expert offsets for num_groups=1: just [num_tokens_padded] # The GEMM expects expert_offsets as (num_experts,) cumulative offsets @@ -347,6 +353,7 @@ class Nvfp4SharedExpert: expert_offsets=expert_offsets, global_scale_a=gsa, global_scale_b=self._l1_gsb, + out=self._l1_out_buf, ) # Extract real token outputs @@ -397,6 +404,7 @@ class Nvfp4SharedExpert: expert_offsets=expert_offsets, global_scale_a=gsa, global_scale_b=self._l2_gsb, + out=self._l2_out_buf, ) return out[:num_tokens]