[1/N][Platform] Cleanup useless function (#26982)
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -9,6 +9,7 @@ Note: these tests will only pass on L4 GPU.
|
||||
import pytest
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm.attention.utils.fa_utils import flash_attn_supports_fp8
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import STR_BACKEND_ENV_VAR
|
||||
|
||||
@@ -69,8 +70,10 @@ def test_models(
|
||||
if kv_cache_dtype == "fp8_e5m2" and current_platform.is_rocm():
|
||||
pytest.skip(f"{kv_cache_dtype} is currently not supported on ROCm/HIP.")
|
||||
|
||||
if not current_platform.is_kv_cache_dtype_supported(kv_cache_dtype, None):
|
||||
pytest.skip(f"{kv_cache_dtype} is not supported on this platform.")
|
||||
if not flash_attn_supports_fp8():
|
||||
pytest.skip(
|
||||
f"{kv_cache_dtype} is not supported on this GPU type with {backend} attention."
|
||||
)
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("TOKENIZERS_PARALLELISM", "true")
|
||||
|
||||
@@ -356,10 +356,6 @@ def test_compressed_tensors_fp8(vllm_runner):
|
||||
assert output
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_kv_cache_dtype_supported("fp8", None),
|
||||
reason="FP8 KV cache is not supported on this device.",
|
||||
)
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
|
||||
)
|
||||
|
||||
@@ -23,7 +23,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
_Backend = None
|
||||
|
||||
@@ -457,49 +457,6 @@ class CudaPlatformBase(Platform):
|
||||
def device_count(cls) -> int:
|
||||
return cuda_device_count_stateless()
|
||||
|
||||
@classmethod
def is_kv_cache_dtype_supported(
    cls, kv_cache_dtype: str, model_config: "ModelConfig"
) -> bool:
    """Report whether this CUDA platform supports the given kv-cache dtype.

    The answer depends on which attention backend would be selected
    (``VLLM_ATTENTION_BACKEND`` if set, otherwise a per-path default) and
    on whether the requested dtype is an fp8 variant.
    """
    wants_fp8 = kv_cache_dtype.startswith("fp8")
    backend = envs.VLLM_ATTENTION_BACKEND

    if model_config is not None and model_config.use_mla:
        # No backend chosen explicitly: default to CutlassMLA for
        # blackwell, FlashMLA otherwise.
        if backend is None:
            backend = "CUTLASS_MLA" if cls.is_device_capability(100) else "FLASHMLA"

        # Only FlashMLA and CUTLASS_MLA support fp8
        if backend in ["FLASHMLA", "CUTLASS_MLA", "FLASHINFER_MLA"]:
            return True
        return not wants_fp8

    # Non-MLA path: default to FlashAttention.
    if backend is None:
        backend = "FLASH_ATTN"

    # All Blackwell backends support fp8
    if cls.is_device_capability(100):
        return True
    if backend == "FLASH_ATTN":
        if not wants_fp8:
            return True
        from vllm.attention.utils.fa_utils import flash_attn_supports_fp8

        return flash_attn_supports_fp8()
    if backend == "FLASHINFER":
        return True
    if backend == "TRITON_ATTN":
        return cls.supports_fp8()
    # Any other backend: not supported.
    return False
|
||||
|
||||
@classmethod
|
||||
def check_if_supports_dtype(cls, dtype: torch.dtype):
|
||||
if dtype == torch.bfloat16: # noqa: SIM102
|
||||
|
||||
@@ -7,28 +7,23 @@ import platform
|
||||
import random
|
||||
import sys
|
||||
from datetime import timedelta
|
||||
from platform import uname
|
||||
from typing import TYPE_CHECKING, Any, NamedTuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.distributed import PrefixStore, ProcessGroup
|
||||
|
||||
from vllm.inputs import ProcessorInputs, PromptType
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.distributed import PrefixStore, ProcessGroup
|
||||
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import ProcessorInputs, PromptType
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
else:
|
||||
_Backend = object
|
||||
ModelConfig = object
|
||||
VllmConfig = object
|
||||
PoolingParams = object
|
||||
SamplingParams = object
|
||||
FlexibleArgumentParser = object
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -36,7 +31,7 @@ logger = init_logger(__name__)
|
||||
|
||||
def in_wsl() -> bool:
    """Return True when running under Windows Subsystem for Linux."""
    # Reference: https://github.com/microsoft/WSL/issues/4071
    # WSL kernels identify themselves with "microsoft" in the uname fields.
    return "microsoft" in " ".join(platform.uname()).lower()
|
||||
|
||||
|
||||
class PlatformEnum(enum.Enum):
|
||||
@@ -178,7 +173,8 @@ class Platform:
|
||||
import vllm._moe_C # noqa: F401
|
||||
|
||||
@classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
    """Return the attention backend to use for ViT-style attention.

    The base implementation always selects torch SDPA; platform
    subclasses may choose something device-specific.
    """
    # Import _Backend here to avoid circular import.
    from vllm.attention.backends.registry import _Backend

    return _Backend.TORCH_SDPA
||||
@@ -186,7 +182,7 @@ class Platform:
|
||||
@classmethod
|
||||
def get_attn_backend_cls(
|
||||
cls,
|
||||
selected_backend: _Backend,
|
||||
selected_backend: "_Backend",
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str | None,
|
||||
@@ -317,7 +313,7 @@ class Platform:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
|
||||
"""
|
||||
Check and update the configuration for the current platform.
|
||||
|
||||
@@ -498,9 +494,9 @@ class Platform:
|
||||
@classmethod
|
||||
def validate_request(
    cls,
    prompt: "PromptType",
    params: "SamplingParams | PoolingParams",
    processed_inputs: "ProcessorInputs",
) -> None:
    """Raises if this request is unsupported on this platform"""
    # Base implementation: accept every request (no-op); platform
    # subclasses raise for inputs they cannot serve.
|
||||
|
||||
@@ -543,25 +539,16 @@ class Platform:
|
||||
def stateless_init_device_torch_dist_pg(
    cls,
    backend: str,
    prefix_store: "PrefixStore",
    group_rank: int,
    group_size: int,
    timeout: timedelta,
) -> "ProcessGroup":
    """
    Init platform-specific torch distributed process group.

    Raises:
        NotImplementedError: always, in this base implementation;
            platforms that support stateless init must override it.
    """
    raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def is_kv_cache_dtype_supported(
    cls, kv_cache_dtype: str, model_config: "ModelConfig"
) -> bool:
    """
    Returns if the kv_cache_dtype is supported by the current platform.

    Base platforms report no support; device platforms (CUDA, ROCm,
    TPU, XPU) override this with their own checks.
    """
    return False
|
||||
|
||||
@classmethod
|
||||
def check_if_supports_dtype(cls, dtype: torch.dtype):
|
||||
"""
|
||||
|
||||
@@ -15,7 +15,7 @@ from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
_Backend = None
|
||||
|
||||
@@ -474,12 +474,6 @@ class RocmPlatform(Platform):
|
||||
def device_count(cls) -> int:
|
||||
return cuda_device_count_stateless()
|
||||
|
||||
@classmethod
|
||||
def is_kv_cache_dtype_supported(
    cls,
    kv_cache_dtype: str,
    model_config: "ModelConfig",
) -> bool:
    """ROCm accepts every kv-cache dtype, so always report support."""
    return True
|
||||
|
||||
@classmethod
|
||||
def check_if_supports_dtype(cls, dtype: torch.dtype):
|
||||
if dtype == torch.bfloat16: # noqa: SIM102
|
||||
|
||||
@@ -222,12 +222,6 @@ class TpuPlatform(Platform):
|
||||
):
|
||||
raise ValueError("Torch XLA does not support per-request seed.")
|
||||
|
||||
@classmethod
|
||||
def is_kv_cache_dtype_supported(
    cls,
    kv_cache_dtype: str,
    model_config: "ModelConfig",
) -> bool:
    """TPU reports support for every kv-cache dtype."""
    return True
|
||||
|
||||
@classmethod
|
||||
@torch.compile(backend="openxla")
|
||||
def insert_blocks_to_device(
|
||||
|
||||
@@ -86,22 +86,6 @@ class XPUPlatform(Platform):
|
||||
logger.info("Using Flash Attention backend on V1 engine.")
|
||||
return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
|
||||
|
||||
@classmethod
|
||||
def is_kv_cache_dtype_supported(
    cls, kv_cache_dtype: str, model_config: "ModelConfig"
) -> bool:
    """
    Check if the kv_cache_dtype is supported.

    XPU only supports an fp8 kv cache when the triton attention
    backend has been explicitly selected via the environment.
    """
    triton_selected = (
        envs.is_set("VLLM_ATTENTION_BACKEND")
        and envs.VLLM_ATTENTION_BACKEND == "TRITON_ATTN"
    )
    if not triton_selected:
        return False
    return kv_cache_dtype in ["fp8_e4m3", "fp8_e5m2", "fp8"]
|
||||
|
||||
@classmethod
|
||||
def set_device(cls, device: torch.device) -> None:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user