diff --git a/dsv4/ops/gemm_runner.py b/dsv4/ops/gemm_runner.py index e0711628..2b4f615c 100644 --- a/dsv4/ops/gemm_runner.py +++ b/dsv4/ops/gemm_runner.py @@ -170,6 +170,7 @@ def run_nvfp4_grouped_gemm( global_scale_b=None, # (experts,) float32 mma_tiler_mn=(128, 128), cluster_shape_mn=(1, 1), + out=None, # pre-allocated output buffer for CUDA graph capture ): """Run the CuTeDSL NVFP4 scaled grouped GEMM. @@ -184,7 +185,10 @@ def run_nvfp4_grouped_gemm( n_dim = mat_b.shape[2] tokens_sum = mat_a.shape[0] - out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=mat_a.device) + if out is None: + out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=mat_a.device) + else: + out.zero_() # NVFP4-3: use 2-CTA UMMA for M>=256 (1.7-1.9× throughput at prefill) use_2cta = tokens_sum >= 256 and cluster_shape_mn[0] % 2 == 0 @@ -424,7 +428,10 @@ def run_fused_swiglu_grouped_gemm( n_dim = mat_b.shape[2] tokens_sum = mat_a.shape[0] - out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=mat_a.device) + if out is None: + out = torch.zeros(tokens_sum, n_dim, dtype=torch.bfloat16, device=mat_a.device) + else: + out.zero_() # NVFP4-3: use 2-CTA UMMA for M>=256 (1.7-1.9× throughput at prefill) # At decode (M<256), 1-CTA is correct (2-CTA wastes hardware) diff --git a/dsv4/ops/quantize.py b/dsv4/ops/quantize.py index c720b873..3a189555 100644 --- a/dsv4/ops/quantize.py +++ b/dsv4/ops/quantize.py @@ -143,11 +143,10 @@ def quantize_activation_nvfp4(x_bf16, global_scale, block_size=SF_VEC_SIZE): block_amax = x_reshaped.abs().amax(dim=-1) # Detect zero blocks and underflow blocks (same threshold as quantize_to_nvfp4). zero_block = block_amax < (6.0 * 2.0 ** -9) - x_reshaped = torch.where(zero_block.unsqueeze(-1), - torch.zeros_like(x_reshaped), x_reshaped) + x_reshaped = torch.where(zero_block.unsqueeze(-1), 0.0, x_reshaped) block_amax = block_amax.clamp(min=1e-8, max=6.0 * 448.0) # E4M3 max = 448 block_scale = (block_amax / 6.0).to(torch.float8_e4m3fn) - block_scale = torch.where(zero_block, torch.zeros_like(block_scale), block_scale) + block_scale = torch.where(zero_block, 0.0, block_scale) block_sf_expanded = block_scale.float().unsqueeze(-1) x_scaled = x_reshaped / block_sf_expanded.clamp(min=1e-8)