[Bugfix] Gracefully disable AllReduceFusionPass on GPUs without multicast support (#35085)
Signed-off-by: haosdent <haosdent@gmail.com>
This commit is contained in:
@@ -729,14 +729,26 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
|
||||
scope="global",
|
||||
)
|
||||
|
||||
self.workspace = flashinfer_comm.create_allreduce_fusion_workspace(
|
||||
backend="trtllm",
|
||||
world_size=self.tp_size,
|
||||
rank=rank,
|
||||
max_token_num=self.max_token_num,
|
||||
hidden_dim=self.hidden_dim,
|
||||
dtype=self.model_dtype,
|
||||
)
|
||||
try:
|
||||
self.workspace = flashinfer_comm.create_allreduce_fusion_workspace(
|
||||
backend="trtllm",
|
||||
world_size=self.tp_size,
|
||||
rank=rank,
|
||||
max_token_num=self.max_token_num,
|
||||
hidden_dim=self.hidden_dim,
|
||||
dtype=self.model_dtype,
|
||||
)
|
||||
except RuntimeError as e:
|
||||
if "multicast" not in str(e).lower():
|
||||
raise
|
||||
logger.warning_once(
|
||||
"AllReduce fusion pass is disabled: flashinfer workspace "
|
||||
"creation failed: %s. This is expected on GPUs without "
|
||||
"NVSwitch (e.g., NVLink bridge-only or PCIe topologies). "
|
||||
"Falling back to non-fused allreduce.",
|
||||
str(e),
|
||||
)
|
||||
return
|
||||
|
||||
global _FI_WORKSPACE
|
||||
_FI_WORKSPACE = self.workspace
|
||||
|
||||
Reference in New Issue
Block a user