[Bugfix] Gracefully disable AllReduceFusionPass on GPUs without multicast support (#35085)

Signed-off-by: haosdent <haosdent@gmail.com>
This commit is contained in:
haosdent
2026-02-25 23:31:45 +08:00
committed by GitHub
parent d72b0be33c
commit 0788ff0a15

View File

@@ -729,14 +729,26 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
             scope="global",
         )
-        self.workspace = flashinfer_comm.create_allreduce_fusion_workspace(
-            backend="trtllm",
-            world_size=self.tp_size,
-            rank=rank,
-            max_token_num=self.max_token_num,
-            hidden_dim=self.hidden_dim,
-            dtype=self.model_dtype,
-        )
+        try:
+            self.workspace = flashinfer_comm.create_allreduce_fusion_workspace(
+                backend="trtllm",
+                world_size=self.tp_size,
+                rank=rank,
+                max_token_num=self.max_token_num,
+                hidden_dim=self.hidden_dim,
+                dtype=self.model_dtype,
+            )
+        except RuntimeError as e:
+            if "multicast" not in str(e).lower():
+                raise
+            logger.warning_once(
+                "AllReduce fusion pass is disabled: flashinfer workspace "
+                "creation failed: %s. This is expected on GPUs without "
+                "NVSwitch (e.g., NVLink bridge-only or PCIe topologies). "
+                "Falling back to non-fused allreduce.",
+                str(e),
+            )
+            return
         global _FI_WORKSPACE
         _FI_WORKSPACE = self.workspace