[Kernel] Refactor FlashInfer allreduce for mnnvl backend (#34109)
Signed-off-by: hjjq <50634613+hjjq@users.noreply.github.com>
Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Wei Zhao <51183510+wzhao18@users.noreply.github.com>
 vllm/envs.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)
@@ -168,6 +168,7 @@ if TYPE_CHECKING:
     VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "masked_gemm"] = (
         "latency"
     )
+    VLLM_FLASHINFER_ALLREDUCE_BACKEND: Literal["auto", "trtllm", "mnnvl"] = "auto"
     VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE: int = 394 * 1024 * 1024
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
@@ -206,6 +207,7 @@ if TYPE_CHECKING:
     VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False
     VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True
+    VLLM_ALLREDUCE_USE_FLASHINFER: bool = False
     VLLM_TUNED_CONFIG_FOLDER: str | None = None
     VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS: set[str] = set()
     VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT: bool = False
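The declarations in the two hunks above live in a `TYPE_CHECKING` block: they exist only so static type checkers see typed module attributes, while actual values are resolved lazily from the `environment_variables` dict further down in the file. A minimal sketch of that pattern follows; the names `environment_variables` and the module-level `__getattr__` mirror the file's own structure, but the simplified body is an assumption, not the exact vLLM implementation:

```python
import os
from typing import TYPE_CHECKING, Any, Callable, Literal

if TYPE_CHECKING:
    # Static type checkers see typed attributes with their defaults...
    VLLM_ALLREDUCE_USE_FLASHINFER: bool = False
    VLLM_FLASHINFER_ALLREDUCE_BACKEND: Literal["auto", "trtllm", "mnnvl"] = "auto"

# ...while at runtime each attribute lookup goes through a dict of lazy parsers.
environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_ALLREDUCE_USE_FLASHINFER": lambda: bool(
        int(os.getenv("VLLM_ALLREDUCE_USE_FLASHINFER", "0"))
    ),
}

def __getattr__(name: str) -> Any:
    # Module-level __getattr__ (PEP 562): re-reads the env var on every access.
    if name in environment_variables:
        return environment_variables[name]()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```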
@@ -1290,6 +1292,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
         "latency",
         ["throughput", "latency", "masked_gemm"],
     ),
+    # FlashInfer fused allreduce backend. "auto" defaults to "mnnvl", which
+    # generally performs the same as or better than "trtllm", but "mnnvl"
+    # does not support fusing the allreduce with quantization.
+    "VLLM_FLASHINFER_ALLREDUCE_BACKEND": env_with_choices(
+        "VLLM_FLASHINFER_ALLREDUCE_BACKEND",
+        "auto",
+        ["auto", "trtllm", "mnnvl"],
+    ),
     # Control the workspace buffer size for the FlashInfer backend.
     "VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE": lambda: int(
         os.getenv("VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE", str(394 * 1024 * 1024))
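The comment's constraint ("mnnvl" cannot fuse the allreduce with quantization) implies that "auto" must be resolved per call site, falling back to "trtllm" when a quantization fusion is requested. A hedged sketch of that policy is below; it is not the PR's actual dispatch code, and `resolve_allreduce_backend` and the `fuse_quant` flag are hypothetical names:

```python
import vllm.envs as envs

def resolve_allreduce_backend(fuse_quant: bool) -> str:
    # Hypothetical helper illustrating the documented policy:
    # "auto" prefers "mnnvl", but mnnvl cannot fuse allreduce with
    # quantization, so those cases fall back to "trtllm".
    backend = envs.VLLM_FLASHINFER_ALLREDUCE_BACKEND
    if backend == "auto":
        return "trtllm" if fuse_quant else "mnnvl"
    if backend == "mnnvl" and fuse_quant:
        raise ValueError(
            "mnnvl backend does not support fusing allreduce with quantization"
        )
    return backend
```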
@@ -1448,6 +1458,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_ALLREDUCE_USE_SYMM_MEM": lambda: bool(
         int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1"))
     ),
+    # Whether to use FlashInfer allreduce
+    "VLLM_ALLREDUCE_USE_FLASHINFER": lambda: bool(
+        int(os.getenv("VLLM_ALLREDUCE_USE_FLASHINFER", "0"))
+    ),
     # Experimental: use this to enable MCP tool calling for non harmony models
     "VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT": lambda: bool(
         int(os.getenv("VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", "0"))
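Note the parsing convention used by this flag: `bool(int(os.getenv(..., "0")))` accepts only integer strings, so `VLLM_ALLREDUCE_USE_FLASHINFER=1` enables the path while a value like `true` raises `ValueError` rather than being coerced. A quick self-contained illustration of those semantics:

```python
import os

# Integer strings work: "0" disables, any nonzero value enables.
os.environ["VLLM_ALLREDUCE_USE_FLASHINFER"] = "1"
assert bool(int(os.getenv("VLLM_ALLREDUCE_USE_FLASHINFER", "0"))) is True

# Non-integer strings are rejected by int(), not treated as truthy.
os.environ["VLLM_ALLREDUCE_USE_FLASHINFER"] = "true"
try:
    bool(int(os.getenv("VLLM_ALLREDUCE_USE_FLASHINFER", "0")))
except ValueError:
    pass  # "true" is not a valid value for this flag
```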