[Bugfix] Fix benchmark_fused_collective.py (#38082)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -25,6 +25,7 @@ import pandas as pd
|
||||
import torch # type: ignore
|
||||
import torch.distributed as dist # type: ignore
|
||||
|
||||
from vllm._custom_ops import create_fp4_output_tensors
|
||||
from vllm.config.vllm import CompilationConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed import (
|
||||
tensor_model_parallel_all_reduce,
|
||||
@@ -46,7 +47,7 @@ RMS_NORM_STATIC_FP8_QUANT_OP = torch.ops._C.rms_norm_static_fp8_quant
|
||||
FUSED_ADD_RMS_NORM_STATIC_FP8_QUANT_OP = (
|
||||
torch.ops._C.fused_add_rms_norm_static_fp8_quant
|
||||
)
|
||||
SCALED_FP4_QUANT_OP = torch.ops._C.scaled_fp4_quant
|
||||
SCALED_FP4_QUANT_OUT_OP = torch.ops._C.scaled_fp4_quant.out
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -334,13 +335,23 @@ class VllmFusedAllreduce:
|
||||
output_scale: torch.Tensor,
|
||||
):
|
||||
allreduce_out = tensor_model_parallel_all_reduce(input_tensor)
|
||||
rms_out = self.rms_norm(allreduce_out, residual)
|
||||
rms_output = self.rms_norm(allreduce_out, residual)
|
||||
if residual is None:
|
||||
rms_out = rms_output
|
||||
else:
|
||||
rms_out, residual_out = rms_output
|
||||
|
||||
SCALED_FP4_QUANT_OUT_OP(
|
||||
rms_out,
|
||||
input_global_scale,
|
||||
True,
|
||||
output=quant_out,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
|
||||
if residual is None:
|
||||
SCALED_FP4_QUANT_OP(quant_out, rms_out, output_scale, input_global_scale)
|
||||
return quant_out, output_scale
|
||||
else:
|
||||
rms_out, residual_out = rms_out
|
||||
SCALED_FP4_QUANT_OP(quant_out, rms_out, output_scale, input_global_scale)
|
||||
return quant_out, residual_out, output_scale
|
||||
|
||||
|
||||
@@ -362,8 +373,9 @@ def create_test_tensors(
|
||||
scale_fp4 = torch.tensor(1.0, dtype=torch.float32)
|
||||
quant_out_fp8 = torch.empty_like(input_tensor, dtype=FP8_DTYPE)
|
||||
# Pre-allocate FP4 output tensors (to avoid allocation overhead in benchmarks)
|
||||
fp4_quant_out = torch.empty((num_tokens, hidden_dim // 2), dtype=torch.uint8)
|
||||
fp4_output_scale = torch.empty((128, 4), dtype=torch.int32)
|
||||
fp4_quant_out, fp4_output_scale = create_fp4_output_tensors(
|
||||
num_tokens, hidden_dim, input_tensor.device, True
|
||||
)
|
||||
|
||||
return (
|
||||
input_tensor,
|
||||
|
||||
Reference in New Issue
Block a user