Pre-allocate ALL GEMM output buffers for CUDA graph capture

Every run_nvfp4_grouped_gemm call must pass out= with a pre-allocated
buffer. During CUDA graph capture, torch.zeros() allocations are
forbidden — they cause 'cudaErrorStreamCaptureUnsupported' errors.

Added:
- shared_expert: _l2_out_buf for L2 GEMM
- shared_expert: pass out= for both L1 and L2 GEMM calls
- moe: _l2_out_buf for L2 GEMM
- moe: pass out= for unfused L1 GEMM (fused L1 already had it)
- moe: pass out= for L2 GEMM
- linear: _gemm_out_buf for all GEMM calls
- linear: pass out= for both run() and run_from_quantized() paths

grouped_linear already had _output_buf_padded — no changes needed.
This commit is contained in:
2026-06-04 02:41:59 +00:00
parent 676a0448c0
commit e7766254b7
3 changed files with 24 additions and 0 deletions

View File

@@ -65,6 +65,7 @@ class Nvfp4Linear:
self._padded_x_fp4_buf = None
self._expert_offsets_buf = None
self._gsa_buf = None
self._gemm_out_buf = None # pre-allocated GEMM output for graph capture
self._buffers_allocated = False
def finalize_weights(self):
@@ -134,6 +135,11 @@ class Nvfp4Linear:
self._scale_a_buf = torch.zeros(
max_padded_rows, max_padded_cols, dtype=torch.float16, device=self.device
).to(torch.float8_e4m3fn)
# Pre-allocated GEMM output buffer for graph capture
self._gemm_out_buf = torch.zeros(
max_padded_rows, self.out_features, dtype=torch.bfloat16, device=self.device
)
def _ensure_initialized(self):
if self._mat_b is None:
@@ -241,6 +247,7 @@ class Nvfp4Linear:
expert_offsets=expert_offsets,
global_scale_a=gsa,
global_scale_b=self._gsb,
out=self._gemm_out_buf,
)
return out[:num_tokens]
@@ -298,6 +305,7 @@ class Nvfp4Linear:
expert_offsets=expert_offsets,
global_scale_a=self._gsa_buf,
global_scale_b=self._gsb,
out=self._gemm_out_buf,
)
return out[:num_tokens]

View File

@@ -171,6 +171,12 @@ class Nvfp4MoE:
self.max_num_tokens * self.top_k, 2 * self.intermediate_size,
dtype=torch.bfloat16, device=self.device
)
# Pre-allocated L2 GEMM output — avoids torch.zeros() in run_nvfp4_grouped_gemm
# Shape: (max_tokens * top_k, hidden_size) — down projection
self._l2_out_buf = torch.zeros(
self.max_num_tokens * self.top_k, self.hidden_size,
dtype=torch.bfloat16, device=self.device
)
# Pre-allocated tokens-per-expert buffer — replaces torch.bincount
# (bincount produces data-dependent shapes, breaks CUDA graph capture)
@@ -678,6 +684,7 @@ class Nvfp4MoE:
scale_a=l1_scale_a, scale_b=self._l1_scale_b,
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
out=self._l1_out_buf,
)
l1_out_real = l1_out[padded_dst]
l1_deil = deinterleave_l1_weights(l1_out_real.unsqueeze(0).contiguous())[0]
@@ -715,6 +722,7 @@ class Nvfp4MoE:
scale_a=l2_scale_a, scale_b=self._l2_scale_b,
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
global_scale_a=l2_gsa, global_scale_b=self._l2_gsb,
out=self._l2_out_buf,
)
l2_out_real = l2_out[padded_dst]

View File

@@ -189,6 +189,12 @@ class Nvfp4SharedExpert:
max_rows, 2 * self.intermediate_size,
dtype=torch.bfloat16, device=self.device
)
# Pre-allocated L2 output buffer for graph capture
# L2 produces hidden_size BF16 columns (down projection)
self._l2_out_buf = torch.zeros(
max_rows, self.hidden_size,
dtype=torch.bfloat16, device=self.device
)
# Expert offsets for num_groups=1: just [num_tokens_padded]
# The GEMM expects expert_offsets as (num_experts,) cumulative offsets
@@ -347,6 +353,7 @@ class Nvfp4SharedExpert:
expert_offsets=expert_offsets,
global_scale_a=gsa,
global_scale_b=self._l1_gsb,
out=self._l1_out_buf,
)
# Extract real token outputs
@@ -397,6 +404,7 @@ class Nvfp4SharedExpert:
expert_offsets=expert_offsets,
global_scale_a=gsa,
global_scale_b=self._l2_gsb,
out=self._l2_out_buf,
)
return out[:num_tokens]