From 188ecae47fe470ac83a17355994cd8de014516a7 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 3 Jun 2026 21:30:24 +0000 Subject: [PATCH] CUDA graph: Eliminate per-step allocations in graph-captured code paths - gemm_runner.py: Add out= parameter to run_nvfp4_grouped_gemm and run_fused_swiglu_grouped_gemm to accept pre-allocated output buffers - quantize.py: Replace torch.zeros_like/torch.zeros with scalar 0.0 in torch.where() calls (graph-capturable, no memory allocation) - Both fixes prevent 'Disallowed operation during CUDA stream capture' errors during graph capture --- dsv4/ops/gemm_runner.py | 11 +++++++++-- dsv4/ops/quantize.py | 5 ++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/dsv4/ops/gemm_runner.py b/dsv4/ops/gemm_runner.py index e0711628..2b4f615c 100644 --- a/dsv4/ops/gemm_runner.py +++ b/dsv4/ops/gemm_runner.py @@ -170,6 +170,7 @@ def run_nvfp4_grouped_gemm( global_scale_b=None, # (experts,) float32 mma_tiler_mn=(128, 128), cluster_shape_mn=(1, 1), + out=None, # pre-allocated output buffer for CUDA graph capture ): """Run the CuTeDSL NVFP4 scaled grouped GEMM. @@ -184,7 +185,10 @@ def run_nvfp4_grouped_gemm( n_dim = mat_b.shape[2] tokens_sum = mat_a.shape[0] - out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=mat_a.device) + if out is None: + out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=mat_a.device) + else: + out.zero_() # NVFP4-3: use 2-CTA UMMA for M>=256 (1.7-1.9× throughput at prefill) use_2cta = tokens_sum >= 256 and cluster_shape_mn[0] % 2 == 0 @@ -424,7 +428,10 @@ def run_fused_swiglu_grouped_gemm( n_dim = mat_b.shape[2] tokens_sum = mat_a.shape[0] - out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=mat_a.device) + if out is None: + out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=mat_a.device) + else: + out.zero_() # NVFP4-3: use 2-CTA UMMA for M>=256 (1.7-1.9× throughput at prefill) # At decode (M<256), 1-CTA is correct (2-CTA wastes hardware) diff --git a/dsv4/ops/quantize.py b/dsv4/ops/quantize.py index c720b873..3a189555 100644 --- a/dsv4/ops/quantize.py +++ b/dsv4/ops/quantize.py @@ -143,11 +143,10 @@ def quantize_activation_nvfp4(x_bf16, global_scale, block_size=SF_VEC_SIZE): block_amax = x_reshaped.abs().amax(dim=-1) # Detect zero blocks and underflow blocks (same threshold as quantize_to_nvfp4). zero_block = block_amax < (6.0 * 2.0 ** -9) - x_reshaped = torch.where(zero_block.unsqueeze(-1), - torch.zeros_like(x_reshaped), x_reshaped) + x_reshaped = torch.where(zero_block.unsqueeze(-1), 0.0, x_reshaped) block_amax = block_amax.clamp(min=1e-8, max=6.0 * 448.0) # E4M3 max = 448 block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn) - block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale) + block_scale = torch.where(zero_block, 0.0, block_scale) block_sf_expanded = block_scale.float().unsqueeze(-1) x_scaled = x_reshaped / block_sf_expanded.clamp(min=1e-8)