[Misc] Add override for allreduce fusion thresholds (#23639)
Signed-off-by: Julien Lin <jullin@nvidia.com>
This commit is contained in:
@@ -10,6 +10,7 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
|
|||||||
from torch._inductor.pattern_matcher import PatternMatcherPass
|
from torch._inductor.pattern_matcher import PatternMatcherPass
|
||||||
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
|
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
|
||||||
from vllm.distributed.parallel_state import (
|
from vllm.distributed.parallel_state import (
|
||||||
@@ -401,6 +402,18 @@ if flashinfer_comm is not None:
|
|||||||
6: MiB // 2, # 512KB
|
6: MiB // 2, # 512KB
|
||||||
8: MiB // 2, # 512KB
|
8: MiB // 2, # 512KB
|
||||||
}
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
_FI_MAX_SIZES.update({
|
||||||
|
int(k): int(float(v) * MiB)
|
||||||
|
for k, v in
|
||||||
|
envs.VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB.items()
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(
|
||||||
|
"Failed to parse VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB: "
|
||||||
|
+ str(e)) from e
|
||||||
|
|
||||||
# opt for a more conservative default value
|
# opt for a more conservative default value
|
||||||
# when world size is not in _FI_MAX_SIZES
|
# when world size is not in _FI_MAX_SIZES
|
||||||
_DEFAULT_FI_MAX_SIZE = MiB // 2
|
_DEFAULT_FI_MAX_SIZE = MiB // 2
|
||||||
|
|||||||
11
vllm/envs.py
11
vllm/envs.py
@@ -2,6 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
@@ -1046,6 +1047,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
|||||||
"VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE":
|
"VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE":
|
||||||
lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")),
|
lambda: int(os.getenv("VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE", "163840")),
|
||||||
|
|
||||||
|
# Specifies the thresholds of the communicated tensor sizes under which
|
||||||
|
# vllm should use flashinfer fused allreduce. The variable should be a
|
||||||
|
# JSON with the following format:
|
||||||
|
# { "<world size>": <max size in MB> } (JSON object keys must be quoted)
|
||||||
|
# Unspecified world sizes will fall back to
|
||||||
|
# { 2: 64, 4: 1, <everything else>: 0.5 }
|
||||||
|
"VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB":
|
||||||
|
lambda: json.loads(os.getenv(
|
||||||
|
"VLLM_FLASHINFER_ALLREDUCE_FUSION_THRESHOLDS_MB", "{}")),
|
||||||
|
|
||||||
# MoE routing strategy selector.
|
# MoE routing strategy selector.
|
||||||
# See `RoutingSimulator.get_available_strategies()` for available
|
# See `RoutingSimulator.get_available_strategies()` for available
|
||||||
# strategies.
|
# strategies.
|
||||||
|
|||||||
Reference in New Issue
Block a user