[0/N][Attention] Fix miscellaneous pre-commit issues (#31924)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Matthew Bonanni
2026-01-07 20:15:17 -05:00
committed by GitHub
parent 5dcd7ef1f2
commit 0d7667419f
8 changed files with 36 additions and 26 deletions

View File

@@ -140,7 +140,7 @@ class StaticSinkAttention(Attention, CustomOp):
head_size, dtype, kv_cache_dtype, block_size
)
attn_backend = create_static_sink_attention_backend(
underlying_attn_backend,
underlying_attn_backend, # type: ignore[arg-type]
sink_len=sink_len,
)
Attention.__init__(

View File

@@ -55,7 +55,7 @@ def is_flashmla_dense_supported() -> tuple[bool, str | None]:
is_availble, maybe_reason = _is_flashmla_available()
if not is_availble:
return False, maybe_reason
if current_platform.get_device_capability()[0] != 9:
if not current_platform.is_device_capability_family(90):
return False, "FlashMLA Dense is only supported on Hopper devices."
return True, None
@@ -67,7 +67,10 @@ def is_flashmla_sparse_supported() -> tuple[bool, str | None]:
is_availble, maybe_reason = _is_flashmla_available()
if not is_availble:
return False, maybe_reason
if current_platform.get_device_capability()[0] not in (9, 10):
if not (
current_platform.is_device_capability_family(90)
or current_platform.is_device_capability_family(100)
):
return (
False,
"FlashMLA Sparse is only supported on Hopper and Blackwell devices.",

View File

@@ -7,9 +7,13 @@ import torch
from vllm.platforms import current_platform
if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
from vllm import _custom_ops
ops = _custom_ops
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops
from vllm._ipex_ops import ipex_ops
ops = ipex_ops
class PagedAttention:

View File

@@ -754,8 +754,8 @@ def context_attention_fwd(
if current_platform.is_rocm():
extra_kargs = {"kpack": 1, "waves_per_eu": 2}
grid = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"]))
_fwd_kernel[grid](
grid_fn = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"]))
_fwd_kernel[grid_fn](
q,
k,
v,

View File

@@ -37,9 +37,9 @@ def fp8_mqa_logits_torch(
Returns:
Logits tensor of shape [M, N], dtype `torch.float32`.
"""
kv, scale = kv
seq_len_kv = kv.shape[0]
k = kv.to(torch.bfloat16)
k_fp8, scale = kv
seq_len_kv = k_fp8.shape[0]
k = k_fp8.to(torch.bfloat16)
q = q.to(torch.bfloat16)
mask_lo = (

View File

@@ -282,10 +282,7 @@ def _fwd_grouped_kernel_stage1(
cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
split_kv_id = tl.program_id(2)
if kv_group_num > BLOCK_H:
VALID_BLOCK_H: tl.constexpr = BLOCK_H
else:
VALID_BLOCK_H: tl.constexpr = kv_group_num
VALID_BLOCK_H: tl.constexpr = BLOCK_H if kv_group_num > BLOCK_H else kv_group_num
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
mask_h = mask_h & (cur_head < q_head_num)

View File

@@ -202,9 +202,9 @@ def _fwd_kernel(
def get_block_size(dtype: torch.dtype) -> int:
if dtype == torch.float32:
return 32
elif (
current_platform.is_cuda_alike()
) and current_platform.get_device_capability().major > 8:
elif current_platform.is_cuda_alike() and current_platform.has_device_capability(
80
):
return 128
else:
return 64

View File

@@ -7,16 +7,23 @@ from vllm.platforms import current_platform
logger = init_logger(__name__)
if current_platform.is_cuda():
from vllm import _custom_ops as ops
from vllm import _custom_ops
ops = _custom_ops
reshape_and_cache_flash = ops.reshape_and_cache_flash
from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops
from vllm.vllm_flash_attn import ( # type: ignore[attr-defined]
flash_attn_varlen_func,
get_scheduler_metadata,
)
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops
ops = ipex_ops
reshape_and_cache_flash = ops.reshape_and_cache_flash
flash_attn_varlen_func = ops.flash_attn_varlen_func
get_scheduler_metadata = ops.get_scheduler_metadata
elif current_platform.is_rocm():
try:
from flash_attn import flash_attn_varlen_func # noqa: F401
@@ -85,7 +92,7 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
def flash_attn_supports_fp8() -> bool:
return (
get_flash_attn_version() == 3
and current_platform.get_device_capability().major == 9
and current_platform.is_device_capability_family(90)
)
@@ -105,10 +112,9 @@ def flash_attn_supports_mla():
is_fa_version_supported,
)
return (
is_fa_version_supported(3)
and current_platform.get_device_capability()[0] == 9
)
return is_fa_version_supported(
3
) and current_platform.is_device_capability_family(90)
except (ImportError, AssertionError):
pass
return False