CUDA graph: Eliminate per-step allocations in graph-captured code paths
- gemm_runner.py: Add out= parameter to run_nvfp4_grouped_gemm and run_fused_swiglu_grouped_gemm to accept pre-allocated output buffers - quantize.py: Replace torch.zeros_like/torch.zeros with scalar 0.0 in torch.where() calls (graph-capturable, no memory allocation) - Both fixes prevent 'Disallowed operation during CUDA stream capture' errors during graph capture
This commit is contained in:
@@ -170,6 +170,7 @@ def run_nvfp4_grouped_gemm(
|
||||
global_scale_b=None, # (experts,) float32
|
||||
mma_tiler_mn=(128, 128),
|
||||
cluster_shape_mn=(1, 1),
|
||||
out=None, # pre-allocated output buffer for CUDA graph capture
|
||||
):
|
||||
"""Run the CuTeDSL NVFP4 scaled grouped GEMM.
|
||||
|
||||
@@ -184,7 +185,10 @@ def run_nvfp4_grouped_gemm(
|
||||
n_dim = mat_b.shape[2]
|
||||
tokens_sum = mat_a.shape[0]
|
||||
|
||||
out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=mat_a.device)
|
||||
if out is None:
|
||||
out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=mat_a.device)
|
||||
else:
|
||||
out.zero_()
|
||||
|
||||
# NVFP4-3: use 2-CTA UMMA for M>=256 (1.7-1.9× throughput at prefill)
|
||||
use_2cta = tokens_sum >= 256 and cluster_shape_mn[0] % 2 == 0
|
||||
@@ -424,7 +428,10 @@ def run_fused_swiglu_grouped_gemm(
|
||||
n_dim = mat_b.shape[2]
|
||||
tokens_sum = mat_a.shape[0]
|
||||
|
||||
out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=mat_a.device)
|
||||
if out is None:
|
||||
out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=mat_a.device)
|
||||
else:
|
||||
out.zero_()
|
||||
|
||||
# NVFP4-3: use 2-CTA UMMA for M>=256 (1.7-1.9× throughput at prefill)
|
||||
# At decode (M<256), 1-CTA is correct (2-CTA wastes hardware)
|
||||
|
||||
@@ -143,11 +143,10 @@ def quantize_activation_nvfp4(x_bf16, global_scale, block_size=SF_VEC_SIZE):
|
||||
block_amax = x_reshaped.abs().amax(dim=-1)
|
||||
# Detect zero blocks and underflow blocks (same threshold as quantize_to_nvfp4).
|
||||
zero_block = block_amax < (6.0 * 2.0 ** -9)
|
||||
x_reshaped = torch.where(zero_block.unsqueeze(-1),
|
||||
torch.zeros_like(x_reshaped), x_reshaped)
|
||||
x_reshaped = torch.where(zero_block.unsqueeze(-1), 0.0, x_reshaped)
|
||||
block_amax = block_amax.clamp(min=1e-8, max=6.0 * 448.0) # E4M3 max = 448
|
||||
block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn)
|
||||
block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale)
|
||||
block_scale = torch.where(zero_block, 0.0, block_scale)
|
||||
|
||||
block_sf_expanded = block_scale.float().unsqueeze(-1)
|
||||
x_scaled = x_reshaped / block_sf_expanded.clamp(min=1e-8)
|
||||
|
||||
Reference in New Issue
Block a user