CUDA graph: Eliminate per-step allocations in graph-captured code paths

- gemm_runner.py: Add out= parameter to run_nvfp4_grouped_gemm and
  run_fused_swiglu_grouped_gemm to accept pre-allocated output buffers
- quantize.py: Replace torch.zeros_like/torch.zeros with scalar 0.0 in
  torch.where() calls (graph-capturable, no memory allocation)
- Both fixes prevent 'Disallowed operation during CUDA stream capture'
  errors during graph capture
This commit is contained in:
2026-06-03 21:30:24 +00:00
parent 91c370360a
commit 188ecae47f
2 changed files with 11 additions and 5 deletions

View File

@@ -170,6 +170,7 @@ def run_nvfp4_grouped_gemm(
global_scale_b=None, # (experts,) float32
mma_tiler_mn=(128, 128),
cluster_shape_mn=(1, 1),
out=None, # pre-allocated output buffer for CUDA graph capture
):
"""Run the CuTeDSL NVFP4 scaled grouped GEMM.
@@ -184,7 +185,10 @@ def run_nvfp4_grouped_gemm(
n_dim = mat_b.shape[2]
tokens_sum = mat_a.shape[0]
out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=mat_a.device)
if out is None:
out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=mat_a.device)
else:
out.zero_()
# NVFP4-3: use 2-CTA UMMA for M>=256 (1.7-1.9× throughput at prefill)
use_2cta = tokens_sum >= 256 and cluster_shape_mn[0] % 2 == 0
@@ -424,7 +428,10 @@ def run_fused_swiglu_grouped_gemm(
n_dim = mat_b.shape[2]
tokens_sum = mat_a.shape[0]
out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=mat_a.device)
if out is None:
out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=mat_a.device)
else:
out.zero_()
# NVFP4-3: use 2-CTA UMMA for M>=256 (1.7-1.9× throughput at prefill)
# At decode (M<256), 1-CTA is correct (2-CTA wastes hardware)

View File

@@ -143,11 +143,10 @@ def quantize_activation_nvfp4(x_bf16, global_scale, block_size=SF_VEC_SIZE):
block_amax = x_reshaped.abs().amax(dim=-1)
# Detect zero blocks and underflow blocks (same threshold as quantize_to_nvfp4).
zero_block = block_amax < (6.0 * 2.0 ** -9)
x_reshaped = torch.where(zero_block.unsqueeze(-1),
torch.zeros_like(x_reshaped), x_reshaped)
x_reshaped = torch.where(zero_block.unsqueeze(-1), 0.0, x_reshaped)
block_amax = block_amax.clamp(min=1e-8, max=6.0 * 448.0) # E4M3 max = 448
block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)
block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale)
block_scale = torch.where(zero_block, 0.0, block_scale)
block_sf_expanded = block_scale.float().unsqueeze(-1)
x_scaled = x_reshaped / block_sf_expanded.clamp(min=1e-8)