[PERF] Allreduce fusion. Support torch native matching. Tuning of the thresholds (#24248)

Signed-off-by: Luka Govedič <lgovedic@redhat.com>
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: ilmarkov <markovilya197@gmail.com>
Co-authored-by: Luka Govedič <lgovedic@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
Ilya Markov
2025-11-11 00:33:11 +01:00
committed by GitHub
parent 021143561f
commit d17ecc6b19
6 changed files with 1284 additions and 83 deletions

View File

@@ -9,7 +9,6 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
from torch._inductor.pattern_matcher import PatternMatcherPass
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
@@ -450,34 +449,41 @@ class AsyncTPPass(VllmPatternMatcherPass):
logger.debug("Replaced %s patterns", self.matched_count)
# Max size of the input tensor per world size per device capability
# to use flashinfer fused allreduce
FI_ALLREDUCE_FUSION_MAX_SIZE_MB: dict[int, dict[int, float]] = {
90: {
2: 64, # 64MB
4: 2, # 2MB
8: 0.5, # 0.5MB
},
100: {
2: 64, # 64MB
4: 32, # 32MB
8: 1, # 1MB
},
}
# Max size of the input tensor per world size per device capability
# to use flashinfer one shot fused allreduce
# OneShot max size is at most 64MB / world size (FlashInfer restriction)
_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB: dict[int, dict[int, float]] = {
90: {
2: 32, # 32MB
4: 2, # 2MB
8: 0.5, # 0.5MB
},
100: {
2: 32, # 32MB
4: 4, # 4MB
8: 1, # 1MB
},
}
if flashinfer_comm is not None:
_FI_WORKSPACE_TENSOR = None
MiB = 1024 * 1024
# Max size of the input tensor per world size
# to use flashinfer fused allreduce
_FI_MAX_SIZES = {
2: 64 * MiB, # 64MB
4: MiB, # 1MB
6: MiB // 2, # 512KB
8: MiB // 2, # 512KB
}
try:
_FI_MAX_SIZES.update(
{
int(k): int(float(v) * MiB)
for k, v in envs.VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB.items()
}
)
except Exception as e:
raise ValueError(
"Failed to parse VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB: " + str(e)
) from e
# opt for a more conservative default value
# when world size is not in _FI_MAX_SIZES
_DEFAULT_FI_MAX_SIZE = MiB // 2
def call_trtllm_fused_allreduce_norm(
allreduce_in: torch.Tensor,
@@ -491,7 +497,6 @@ if flashinfer_comm is not None:
fp32_acc: bool,
max_token_num: int,
pattern_code: int,
fuse_rms_quant: bool,
norm_out: torch.Tensor | None = None,
quant_out: torch.Tensor | None = None,
scale_out: torch.Tensor | None = None,
@@ -500,12 +505,20 @@ if flashinfer_comm is not None:
num_tokens, hidden_size = allreduce_in.shape
element_size = allreduce_in.element_size()
current_tensor_size = num_tokens * hidden_size * element_size
max_fusion_size = max_token_num * hidden_size * element_size
use_flashinfer = current_tensor_size <= min(
_FI_MAX_SIZES.get(world_size, _DEFAULT_FI_MAX_SIZE),
max_fusion_size,
)
if use_flashinfer:
if num_tokens <= max_token_num:
device_capability = current_platform.get_device_capability().to_int()
# Get one shot input size limit for the current world size
# for the current device capability
max_one_shot_size_mb = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
device_capability, {}
).get(world_size, None)
# Use one shot if no max size for one shot is specified
use_oneshot = (
max_one_shot_size_mb is None
or current_tensor_size <= max_one_shot_size_mb * MiB
)
assert _FI_WORKSPACE_TENSOR is not None, (
"Flashinfer must be enabled when using flashinfer"
)
@@ -532,7 +545,7 @@ if flashinfer_comm is not None:
hidden_dim=allreduce_in.shape[-1],
workspace_ptrs=_FI_WORKSPACE_TENSOR,
launch_with_pdl=launch_with_pdl,
use_oneshot=True,
use_oneshot=use_oneshot,
trigger_completion_at_end=trigger_completion_at_end,
fp32_acc=fp32_acc,
pattern_code=pattern_code,
@@ -545,7 +558,7 @@ if flashinfer_comm is not None:
)
else:
allreduce_out = tensor_model_parallel_all_reduce(allreduce_in)
if scale_factor is not None and scale_out is None and fuse_rms_quant:
if scale_factor is not None and scale_out is None:
# Do fused rms norm static fp8 quant fused op
if norm_out is None:
torch.ops._C.fused_add_rms_norm_static_fp8_quant(
@@ -568,15 +581,10 @@ if flashinfer_comm is not None:
norm_out = allreduce_out
else:
torch.ops._C.rms_norm(norm_out, allreduce_out, rms_gamma, rms_eps)
if scale_factor is not None:
if scale_out is not None:
torch.ops._C.scaled_fp4_quant(
quant_out, norm_out, scale_out, scale_factor
)
else:
torch.ops._C.static_scaled_fp8_quant(
quant_out, norm_out, scale_factor
)
if scale_factor is not None and scale_out is not None:
torch.ops._C.scaled_fp4_quant(
quant_out, norm_out, scale_out, scale_factor
)
if scale_factor is None or norm_out is not None:
# we need to return allreduce output
# in cases of non quant fused AR + RMS norm
@@ -595,7 +603,6 @@ if flashinfer_comm is not None:
fp32_acc: bool,
max_token_num: int,
pattern_code: int,
fuse_rms_quant: bool,
norm_out: torch.Tensor | None = None,
quant_out: torch.Tensor | None = None,
scale_out: torch.Tensor | None = None,
@@ -629,7 +636,6 @@ class FlashInferFusedAllReduceParams:
world_size: int,
use_fp32_lamport: bool = False,
max_token_num: int = 1024,
fuse_rms_quant: bool = False,
):
self.rank = rank
self.world_size = world_size
@@ -637,9 +643,7 @@ class FlashInferFusedAllReduceParams:
self.trigger_completion_at_end = True
self.launch_with_pdl = True
self.fp32_acc = True
self.use_oneshot = False
self.max_token_num = max_token_num
self.fuse_rms_quant = fuse_rms_quant
def get_trtllm_fused_allreduce_kwargs(self):
return {
@@ -649,7 +653,6 @@ class FlashInferFusedAllReduceParams:
"trigger_completion_at_end": self.trigger_completion_at_end,
"fp32_acc": self.fp32_acc,
"max_token_num": self.max_token_num,
"fuse_rms_quant": self.fuse_rms_quant,
}
@@ -1119,23 +1122,35 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
"skipping allreduce fusion pass"
)
return
# Check if the world size is supported
if self.tp_size not in _FI_MAX_SIZES:
max_size = config.compilation_config.pass_config.flashinfer_max_size(
self.tp_size
)
if max_size is None:
# Flashinfer doesn't support current world size
logger.warning(
"Flashinfer allreduce fusion is not supported for world size %s",
self.tp_size,
)
return
max_num_token = min(
_FI_MAX_SIZES.get(self.tp_size, _DEFAULT_FI_MAX_SIZE)
// (self.hidden_dim * self.tp_size * (4 if use_fp32_lamport else 2)),
config.compilation_config.pass_config.fi_allreduce_fusion_max_token_num,
element_size = 4 if use_fp32_lamport else 2
self.max_token_num = max_size // (self.hidden_dim * element_size)
# take the min to save workspace size and we'll never use more
# than max_num_batched_tokens anyways
self.max_token_num = min(
self.max_token_num, config.scheduler_config.max_num_batched_tokens
)
logger.debug_once(
f"Flashinfer max size: {max_size // (1024 * 1024)} MB,"
"Maximal number of tokens used by "
f"Flashinfer Allreduce Fusion: {self.max_token_num}",
scope="global",
)
self.ipc_handles, workspace_tensor = (
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
tp_rank=rank,
tp_size=self.tp_size,
max_token_num=max_num_token,
max_token_num=self.max_token_num,
hidden_dim=self.hidden_dim,
group=self.group,
use_fp32_lamport=use_fp32_lamport,
@@ -1148,10 +1163,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
rank=rank,
world_size=self.tp_size,
use_fp32_lamport=use_fp32_lamport,
max_token_num=max_num_token,
# fuse rms norm static fp8 quant fused op
# in fallback path, when we don't use flashinfer
fuse_rms_quant=config.compilation_config.pass_config.enable_fusion,
max_token_num=self.max_token_num,
)
self.register_patterns()