Pre-allocate ALL GEMM output buffers for CUDA graph capture
Every run_nvfp4_grouped_gemm call must pass out= with a pre-allocated buffer. During CUDA graph capture, torch.zeros() allocations are forbidden — they cause 'cudaErrorStreamCaptureUnsupported' errors.

Added:
- shared_expert: _l2_out_buf for L2 GEMM
- shared_expert: pass out= for both L1 and L2 GEMM calls
- moe: _l2_out_buf for L2 GEMM
- moe: pass out= for unfused L1 GEMM (fused L1 already had it)
- moe: pass out= for L2 GEMM
- linear: _gemm_out_buf for all GEMM calls
- linear: pass out= for both run() and run_from_quantized() paths

grouped_linear already had _output_buf_padded — no changes needed.
This commit is contained in:
@@ -65,6 +65,7 @@ class Nvfp4Linear:
|
||||
self._padded_x_fp4_buf = None
|
||||
self._expert_offsets_buf = None
|
||||
self._gsa_buf = None
|
||||
self._gemm_out_buf = None # pre-allocated GEMM output for graph capture
|
||||
self._buffers_allocated = False
|
||||
|
||||
def finalize_weights(self):
|
||||
@@ -134,6 +135,11 @@ class Nvfp4Linear:
|
||||
self._scale_a_buf = torch.zeros(
|
||||
max_padded_rows, max_padded_cols, dtype=torch.float16, device=self.device
|
||||
).to(torch.float8_e4m3fn)
|
||||
|
||||
# Pre-allocated GEMM output buffer for graph capture
|
||||
self._gemm_out_buf = torch.zeros(
|
||||
max_padded_rows, self.out_features, dtype=torch.bfloat16, device=self.device
|
||||
)
|
||||
|
||||
def _ensure_initialized(self):
|
||||
if self._mat_b is None:
|
||||
@@ -241,6 +247,7 @@ class Nvfp4Linear:
|
||||
expert_offsets=expert_offsets,
|
||||
global_scale_a=gsa,
|
||||
global_scale_b=self._gsb,
|
||||
out=self._gemm_out_buf,
|
||||
)
|
||||
|
||||
return out[:num_tokens]
|
||||
@@ -298,6 +305,7 @@ class Nvfp4Linear:
|
||||
expert_offsets=expert_offsets,
|
||||
global_scale_a=self._gsa_buf,
|
||||
global_scale_b=self._gsb,
|
||||
out=self._gemm_out_buf,
|
||||
)
|
||||
|
||||
return out[:num_tokens]
|
||||
|
||||
@@ -171,6 +171,12 @@ class Nvfp4MoE:
|
||||
self.max_num_tokens * self.top_k, 2 * self.intermediate_size,
|
||||
dtype=torch.bfloat16, device=self.device
|
||||
)
|
||||
# Pre-allocated L2 GEMM output — avoids torch.zeros() in run_nvfp4_grouped_gemm
|
||||
# Shape: (max_tokens * top_k, hidden_size) — down projection
|
||||
self._l2_out_buf = torch.zeros(
|
||||
self.max_num_tokens * self.top_k, self.hidden_size,
|
||||
dtype=torch.bfloat16, device=self.device
|
||||
)
|
||||
|
||||
# Pre-allocated tokens-per-expert buffer — replaces torch.bincount
|
||||
# (bincount produces data-dependent shapes, breaks CUDA graph capture)
|
||||
@@ -678,6 +684,7 @@ class Nvfp4MoE:
|
||||
scale_a=l1_scale_a, scale_b=self._l1_scale_b,
|
||||
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
|
||||
global_scale_a=l1_gsa, global_scale_b=self._l1_gsb,
|
||||
out=self._l1_out_buf,
|
||||
)
|
||||
l1_out_real = l1_out[padded_dst]
|
||||
l1_deil = deinterleave_l1_weights(l1_out_real.unsqueeze(0).contiguous())[0]
|
||||
@@ -715,6 +722,7 @@ class Nvfp4MoE:
|
||||
scale_a=l2_scale_a, scale_b=self._l2_scale_b,
|
||||
expert_offsets=padded_expert_offsets[1:self.num_experts + 1],
|
||||
global_scale_a=l2_gsa, global_scale_b=self._l2_gsb,
|
||||
out=self._l2_out_buf,
|
||||
)
|
||||
|
||||
l2_out_real = l2_out[padded_dst]
|
||||
|
||||
@@ -189,6 +189,12 @@ class Nvfp4SharedExpert:
|
||||
max_rows, 2 * self.intermediate_size,
|
||||
dtype=torch.bfloat16, device=self.device
|
||||
)
|
||||
# Pre-allocated L2 output buffer for graph capture
|
||||
# L2 produces hidden_size BF16 columns (down projection)
|
||||
self._l2_out_buf = torch.zeros(
|
||||
max_rows, self.hidden_size,
|
||||
dtype=torch.bfloat16, device=self.device
|
||||
)
|
||||
|
||||
# Expert offsets for num_groups=1: just [num_tokens_padded]
|
||||
# The GEMM expects expert_offsets as (num_experts,) cumulative offsets
|
||||
@@ -347,6 +353,7 @@ class Nvfp4SharedExpert:
|
||||
expert_offsets=expert_offsets,
|
||||
global_scale_a=gsa,
|
||||
global_scale_b=self._l1_gsb,
|
||||
out=self._l1_out_buf,
|
||||
)
|
||||
|
||||
# Extract real token outputs
|
||||
@@ -397,6 +404,7 @@ class Nvfp4SharedExpert:
|
||||
expert_offsets=expert_offsets,
|
||||
global_scale_a=gsa,
|
||||
global_scale_b=self._l2_gsb,
|
||||
out=self._l2_out_buf,
|
||||
)
|
||||
|
||||
return out[:num_tokens]
|
||||
|
||||
Reference in New Issue
Block a user