[Bugfix] Fix benchmark_fused_collective crash on CustomOp init (#34665)
Signed-off-by: Mayank Ketkar <mketkar@zoox.com>
Signed-off-by: Mayank Ketkar <mayket04@gmail.com>
Co-authored-by: Mayank Ketkar <mketkar@zoox.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -408,18 +408,18 @@ def run_benchmarks(
|
|||||||
|
|
||||||
rms_eps = 1e-6
|
rms_eps = 1e-6
|
||||||
results = {}
|
results = {}
|
||||||
vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
|
|
||||||
use_oneshot_options = [False] if no_oneshot else [True, False]
|
use_oneshot_options = [False] if no_oneshot else [True, False]
|
||||||
|
|
||||||
# Create RMSNorm and QuantFP8 layers once for native benchmarks
|
|
||||||
|
|
||||||
if "none" in quant_modes:
|
if "none" in quant_modes:
|
||||||
# Standard AllReduce + RMSNorm
|
# Standard AllReduce + RMSNorm
|
||||||
|
# Re-create VllmFusedAllreduce per config so CustomOp binds the
|
||||||
|
# correct forward method (native vs custom kernel).
|
||||||
for custom_op in ["-rms_norm", "+rms_norm"]:
|
for custom_op in ["-rms_norm", "+rms_norm"]:
|
||||||
with set_current_vllm_config(
|
with set_current_vllm_config(
|
||||||
VllmConfig(compilation_config=CompilationConfig(custom_ops=[custom_op]))
|
VllmConfig(compilation_config=CompilationConfig(custom_ops=[custom_op]))
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
|
vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
|
||||||
suffix = (
|
suffix = (
|
||||||
"_custom_rms_norm" if "+" in custom_op else "_native_rms_norm"
|
"_custom_rms_norm" if "+" in custom_op else "_native_rms_norm"
|
||||||
)
|
)
|
||||||
@@ -438,6 +438,7 @@ def run_benchmarks(
|
|||||||
VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"]))
|
VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"]))
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
|
vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
|
||||||
standard_allreduce_rmsnorm_native_compiled = torch.compile(
|
standard_allreduce_rmsnorm_native_compiled = torch.compile(
|
||||||
vllm_fused_allreduce.allreduce_rmsnorm,
|
vllm_fused_allreduce.allreduce_rmsnorm,
|
||||||
fullgraph=True,
|
fullgraph=True,
|
||||||
@@ -482,7 +483,7 @@ def run_benchmarks(
|
|||||||
"_custom_rms_norm" if "+" in rms_norm_custom_op else "_native_rms_norm"
|
"_custom_rms_norm" if "+" in rms_norm_custom_op else "_native_rms_norm"
|
||||||
)
|
)
|
||||||
for quant_fp8_custom_op in ["-quant_fp8", "+quant_fp8"]:
|
for quant_fp8_custom_op in ["-quant_fp8", "+quant_fp8"]:
|
||||||
suffix += (
|
op_suffix = suffix + (
|
||||||
"_custom_quant_fp8"
|
"_custom_quant_fp8"
|
||||||
if "+" in quant_fp8_custom_op
|
if "+" in quant_fp8_custom_op
|
||||||
else "_native_quant_fp8"
|
else "_native_quant_fp8"
|
||||||
@@ -495,16 +496,17 @@ def run_benchmarks(
|
|||||||
)
|
)
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
|
vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
|
||||||
time_ms = benchmark_operation(
|
time_ms = benchmark_operation(
|
||||||
vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant,
|
vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant,
|
||||||
input_tensor,
|
input_tensor,
|
||||||
residual=residual,
|
residual=residual,
|
||||||
scale_factor=scale_fp8,
|
scale_factor=scale_fp8,
|
||||||
)
|
)
|
||||||
results[f"standard_allreduce{suffix}"] = time_ms
|
results[f"standard_allreduce{op_suffix}"] = time_ms
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Standard AllReduce+RMSNorm+FP8 failed: %s", e)
|
logger.error("Standard AllReduce+RMSNorm+FP8 failed: %s", e)
|
||||||
results[f"standard_allreduce{suffix}"] = float("inf")
|
results[f"standard_allreduce{op_suffix}"] = float("inf")
|
||||||
|
|
||||||
# Standard AllReduce + RMSNorm + FP8 Quant Native Compiled
|
# Standard AllReduce + RMSNorm + FP8 Quant Native Compiled
|
||||||
with set_current_vllm_config(
|
with set_current_vllm_config(
|
||||||
@@ -515,6 +517,7 @@ def run_benchmarks(
|
|||||||
)
|
)
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
|
vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
|
||||||
standard_allreduce_rmsnorm_fp8_quant_native_compiled = torch.compile(
|
standard_allreduce_rmsnorm_fp8_quant_native_compiled = torch.compile(
|
||||||
vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant,
|
vllm_fused_allreduce.allreduce_rmsnorm_fp8_quant,
|
||||||
fullgraph=True,
|
fullgraph=True,
|
||||||
@@ -580,6 +583,7 @@ def run_benchmarks(
|
|||||||
)
|
)
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
|
vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
|
||||||
time_ms = benchmark_operation(
|
time_ms = benchmark_operation(
|
||||||
vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant,
|
vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant,
|
||||||
input_tensor,
|
input_tensor,
|
||||||
@@ -598,6 +602,7 @@ def run_benchmarks(
|
|||||||
VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"]))
|
VllmConfig(compilation_config=CompilationConfig(custom_ops=["-rms_norm"]))
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
|
vllm_fused_allreduce = VllmFusedAllreduce(hidden_dim, dtype)
|
||||||
standard_allreduce_rmsnorm_fp4_quant_native_compiled = torch.compile(
|
standard_allreduce_rmsnorm_fp4_quant_native_compiled = torch.compile(
|
||||||
vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant,
|
vllm_fused_allreduce.allreduce_rmsnorm_fp4_quant,
|
||||||
fullgraph=True,
|
fullgraph=True,
|
||||||
|
|||||||
Reference in New Issue
Block a user