[Feature] Extend batch invariant torch.compile to B200 (#27856)
Signed-off-by: PaulZhang12 <paulzhan@fb.com>
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import functools
|
||||
import os
|
||||
from collections import namedtuple
|
||||
from collections.abc import Callable
|
||||
@@ -11,6 +10,7 @@ import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
@@ -737,11 +737,28 @@ def enable_batch_invariant_mode():
|
||||
|
||||
_batch_invariant_MODE = True
|
||||
_batch_invariant_LIB = torch.library.Library("aten", "IMPL")
|
||||
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
|
||||
_batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
|
||||
_batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
|
||||
_batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
|
||||
_batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
|
||||
|
||||
# Batch invariant matmuls are no longer needed after cublas overrides
|
||||
if not is_torch_equal_or_newer("2.10.0.dev"):
|
||||
if current_platform.is_device_capability(100):
|
||||
# For PyTorch 2.9, B200 uses GEMV for bs=1
|
||||
# Requires https://github.com/pytorch/pytorch/pull/166735
|
||||
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "CUDA")
|
||||
_batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "CUDA")
|
||||
_batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "CUDA")
|
||||
_batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "CUDA")
|
||||
else:
|
||||
# The only source of batch variance on Hopper is split-k, which can be disabled through the
|
||||
# cuBLAS workspace config
|
||||
_original_cublas_workspace_cfg = os.environ.get(
|
||||
"CUBLAS_WORKSPACE_CONFIG", None
|
||||
)
|
||||
_original_cublaslt_workspace_size = os.environ.get(
|
||||
"CUBLASLT_WORKSPACE_SIZE", None
|
||||
)
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
|
||||
os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
|
||||
|
||||
_batch_invariant_LIB.impl(
|
||||
"aten::_log_softmax", _log_softmax_batch_invariant, "CUDA"
|
||||
)
|
||||
@@ -750,6 +767,7 @@ def enable_batch_invariant_mode():
|
||||
_batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
|
||||
|
||||
# Also monkeypatch torch.bmm directly as a fallback
|
||||
_batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
|
||||
_original_torch_bmm = torch.bmm
|
||||
torch.bmm = bmm_batch_invariant
|
||||
|
||||
@@ -771,14 +789,6 @@ def enable_batch_invariant_mode():
|
||||
)
|
||||
torch.backends.cuda.preferred_blas_library(backend="cublaslt")
|
||||
|
||||
if not is_torch_equal_or_newer("2.10.0.dev"):
|
||||
_original_cublas_workspace_cfg = os.environ.get("CUBLAS_WORKSPACE_CONFIG", None)
|
||||
_original_cublaslt_workspace_size = os.environ.get(
|
||||
"CUBLASLT_WORKSPACE_SIZE", None
|
||||
)
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
|
||||
os.environ["CUBLASLT_WORKSPACE_SIZE"] = "1"
|
||||
|
||||
|
||||
def disable_batch_invariant_mode():
|
||||
global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
|
||||
@@ -847,7 +857,6 @@ def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
|
||||
return AttentionBlockSize(block_m=16, block_n=16)
|
||||
|
||||
|
||||
@functools.cache
|
||||
def vllm_is_batch_invariant():
|
||||
env_key = "VLLM_BATCH_INVARIANT"
|
||||
is_overridden = False
|
||||
|
||||
Reference in New Issue
Block a user