diff --git a/vllm/attention/layers/static_sink_attention.py b/vllm/attention/layers/static_sink_attention.py
index e5ed16ec1..13be65d8b 100644
--- a/vllm/attention/layers/static_sink_attention.py
+++ b/vllm/attention/layers/static_sink_attention.py
@@ -140,7 +140,7 @@ class StaticSinkAttention(Attention, CustomOp):
             head_size, dtype, kv_cache_dtype, block_size
         )
         attn_backend = create_static_sink_attention_backend(
-            underlying_attn_backend,
+            underlying_attn_backend,  # type: ignore[arg-type]
             sink_len=sink_len,
         )
         Attention.__init__(
diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py
index d8ab0b909..068b99937 100644
--- a/vllm/attention/ops/flashmla.py
+++ b/vllm/attention/ops/flashmla.py
@@ -55,7 +55,7 @@ def is_flashmla_dense_supported() -> tuple[bool, str | None]:
     is_availble, maybe_reason = _is_flashmla_available()
     if not is_availble:
         return False, maybe_reason
-    if current_platform.get_device_capability()[0] != 9:
+    if not current_platform.is_device_capability_family(90):
         return False, "FlashMLA Dense is only supported on Hopper devices."
     return True, None
 
@@ -67,7 +67,10 @@ def is_flashmla_sparse_supported() -> tuple[bool, str | None]:
     is_availble, maybe_reason = _is_flashmla_available()
     if not is_availble:
         return False, maybe_reason
-    if current_platform.get_device_capability()[0] not in (9, 10):
+    if not (
+        current_platform.is_device_capability_family(90)
+        or current_platform.is_device_capability_family(100)
+    ):
         return (
             False,
             "FlashMLA Sparse is only supported on Hopper and Blackwell devices.",
diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py
index 4aa4bcf5b..280629548 100644
--- a/vllm/attention/ops/paged_attn.py
+++ b/vllm/attention/ops/paged_attn.py
@@ -7,9 +7,13 @@ import torch
 from vllm.platforms import current_platform
 
 if current_platform.is_cuda_alike():
-    from vllm import _custom_ops as ops
+    from vllm import _custom_ops
+
+    ops = _custom_ops
 elif current_platform.is_xpu():
-    from vllm._ipex_ops import ipex_ops as ops
+    from vllm._ipex_ops import ipex_ops
+
+    ops = ipex_ops
 
 
 class PagedAttention:
diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py
index f101d5c4a..5a507a779 100644
--- a/vllm/attention/ops/prefix_prefill.py
+++ b/vllm/attention/ops/prefix_prefill.py
@@ -754,8 +754,8 @@ def context_attention_fwd(
     if current_platform.is_rocm():
         extra_kargs = {"kpack": 1, "waves_per_eu": 2}
 
-    grid = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"]))
-    _fwd_kernel[grid](
+    grid_fn = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"]))
+    _fwd_kernel[grid_fn](
         q,
         k,
         v,
diff --git a/vllm/attention/ops/rocm_aiter_mla_sparse.py b/vllm/attention/ops/rocm_aiter_mla_sparse.py
index 080e92ecc..1e89d48db 100644
--- a/vllm/attention/ops/rocm_aiter_mla_sparse.py
+++ b/vllm/attention/ops/rocm_aiter_mla_sparse.py
@@ -37,9 +37,9 @@ def fp8_mqa_logits_torch(
     Returns:
         Logits tensor of shape [M, N], dtype `torch.float32`.
     """
-    kv, scale = kv
-    seq_len_kv = kv.shape[0]
-    k = kv.to(torch.bfloat16)
+    k_fp8, scale = kv
+    seq_len_kv = k_fp8.shape[0]
+    k = k_fp8.to(torch.bfloat16)
     q = q.to(torch.bfloat16)
 
     mask_lo = (
diff --git a/vllm/attention/ops/triton_decode_attention.py b/vllm/attention/ops/triton_decode_attention.py
index aebc2e63c..1ed9698c5 100644
--- a/vllm/attention/ops/triton_decode_attention.py
+++ b/vllm/attention/ops/triton_decode_attention.py
@@ -282,10 +282,7 @@ def _fwd_grouped_kernel_stage1(
     cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
     split_kv_id = tl.program_id(2)
 
-    if kv_group_num > BLOCK_H:
-        VALID_BLOCK_H: tl.constexpr = BLOCK_H
-    else:
-        VALID_BLOCK_H: tl.constexpr = kv_group_num
+    VALID_BLOCK_H: tl.constexpr = BLOCK_H if kv_group_num > BLOCK_H else kv_group_num
     cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
     mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
     mask_h = mask_h & (cur_head < q_head_num)
diff --git a/vllm/attention/ops/triton_prefill_attention.py b/vllm/attention/ops/triton_prefill_attention.py
index ae7332830..c593698f1 100644
--- a/vllm/attention/ops/triton_prefill_attention.py
+++ b/vllm/attention/ops/triton_prefill_attention.py
@@ -202,9 +202,9 @@ def _fwd_kernel(
 def get_block_size(dtype: torch.dtype) -> int:
     if dtype == torch.float32:
         return 32
-    elif (
-        current_platform.is_cuda_alike()
-    ) and current_platform.get_device_capability().major > 8:
+    elif current_platform.is_cuda_alike() and current_platform.has_device_capability(
+        90  # capability >= 9.0, i.e. the same cutoff as the previous "major > 8" test
+    ):
         return 128
     else:
         return 64
diff --git a/vllm/attention/utils/fa_utils.py b/vllm/attention/utils/fa_utils.py
index e38c88f48..189bf3d4f 100644
--- a/vllm/attention/utils/fa_utils.py
+++ b/vllm/attention/utils/fa_utils.py
@@ -7,16 +7,23 @@ from vllm.platforms import current_platform
 logger = init_logger(__name__)
 
 if current_platform.is_cuda():
-    from vllm import _custom_ops as ops
+    from vllm import _custom_ops
 
+    ops = _custom_ops
     reshape_and_cache_flash = ops.reshape_and_cache_flash
-    from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata
-elif current_platform.is_xpu():
-    from vllm._ipex_ops import ipex_ops as ops
+    from vllm.vllm_flash_attn import (  # type: ignore[attr-defined]
+        flash_attn_varlen_func,
+        get_scheduler_metadata,
+    )
 
+elif current_platform.is_xpu():
+    from vllm._ipex_ops import ipex_ops
+
+    ops = ipex_ops
     reshape_and_cache_flash = ops.reshape_and_cache_flash
     flash_attn_varlen_func = ops.flash_attn_varlen_func
     get_scheduler_metadata = ops.get_scheduler_metadata
+
 elif current_platform.is_rocm():
     try:
         from flash_attn import flash_attn_varlen_func  # noqa: F401
@@ -85,7 +92,7 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
 def flash_attn_supports_fp8() -> bool:
     return (
         get_flash_attn_version() == 3
-        and current_platform.get_device_capability().major == 9
+        and current_platform.is_device_capability_family(90)
     )
 
 
@@ -105,10 +112,9 @@ def flash_attn_supports_mla():
                 is_fa_version_supported,
             )
 
-            return (
-                is_fa_version_supported(3)
-                and current_platform.get_device_capability()[0] == 9
-            )
+            return is_fa_version_supported(
+                3
+            ) and current_platform.is_device_capability_family(90)
         except (ImportError, AssertionError):
             pass
     return False