From 2bfbdca23c60536f43850c9bec07fcc2b8f8e810 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 26 Mar 2026 14:51:00 +0800 Subject: [PATCH] [Bugfix] Fix benchmark_fused_collective.py (#38082) Signed-off-by: Jee Jee Li --- .../kernels/benchmark_fused_collective.py | 26 ++++++++++++++----- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py index 05b842d7e..36cbd715f 100644 --- a/benchmarks/kernels/benchmark_fused_collective.py +++ b/benchmarks/kernels/benchmark_fused_collective.py @@ -25,6 +25,7 @@ import pandas as pd import torch # type: ignore import torch.distributed as dist # type: ignore +from vllm._custom_ops import create_fp4_output_tensors from vllm.config.vllm import CompilationConfig, VllmConfig, set_current_vllm_config from vllm.distributed import ( tensor_model_parallel_all_reduce, @@ -46,7 +47,7 @@ RMS_NORM_STATIC_FP8_QUANT_OP = torch.ops._C.rms_norm_static_fp8_quant FUSED_ADD_RMS_NORM_STATIC_FP8_QUANT_OP = ( torch.ops._C.fused_add_rms_norm_static_fp8_quant ) -SCALED_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant +SCALED_FP4_QUANT_OUT_OP = torch.ops._C.scaled_fp4_quant.out logger = init_logger(__name__) @@ -334,13 +335,23 @@ class VllmFusedAllreduce: output_scale: torch.Tensor, ): allreduce_out = tensor_model_parallel_all_reduce(input_tensor) - rms_out = self.rms_norm(allreduce_out, residual) + rms_output = self.rms_norm(allreduce_out, residual) + if residual is None: + rms_out = rms_output + else: + rms_out, residual_out = rms_output + + SCALED_FP4_QUANT_OUT_OP( + rms_out, + input_global_scale, + True, + output=quant_out, + output_scale=output_scale, + ) + if residual is None: - SCALED_FP4_QUANT_OP(quant_out, rms_out, output_scale, input_global_scale) return quant_out, output_scale else: - rms_out, residual_out = rms_out - SCALED_FP4_QUANT_OP(quant_out, rms_out, output_scale, input_global_scale) return quant_out, residual_out, output_scale @@ -362,8 +373,9 @@ def create_test_tensors( scale_fp4 = torch.tensor(1.0, dtype=torch.float32) quant_out_fp8 = torch.empty_like(input_tensor, dtype=FP8_DTYPE) # Pre-allocate FP4 output tensors (to avoid allocation overhead in benchmarks) - fp4_quant_out = torch.empty((num_tokens, hidden_dim // 2), dtype=torch.uint8) - fp4_output_scale = torch.empty((128, 4), dtype=torch.int32) + fp4_quant_out, fp4_output_scale = create_fp4_output_tensors( + num_tokens, hidden_dim, input_tensor.device, True + ) return ( input_tensor,