[0/N][Attention] Fix miscellaneous pre-commit issues (#31924)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -140,7 +140,7 @@ class StaticSinkAttention(Attention, CustomOp):
|
||||
head_size, dtype, kv_cache_dtype, block_size
|
||||
)
|
||||
attn_backend = create_static_sink_attention_backend(
|
||||
underlying_attn_backend,
|
||||
underlying_attn_backend, # type: ignore[arg-type]
|
||||
sink_len=sink_len,
|
||||
)
|
||||
Attention.__init__(
|
||||
|
||||
@@ -55,7 +55,7 @@ def is_flashmla_dense_supported() -> tuple[bool, str | None]:
|
||||
is_availble, maybe_reason = _is_flashmla_available()
|
||||
if not is_availble:
|
||||
return False, maybe_reason
|
||||
if current_platform.get_device_capability()[0] != 9:
|
||||
if not current_platform.is_device_capability_family(90):
|
||||
return False, "FlashMLA Dense is only supported on Hopper devices."
|
||||
return True, None
|
||||
|
||||
@@ -67,7 +67,10 @@ def is_flashmla_sparse_supported() -> tuple[bool, str | None]:
|
||||
is_availble, maybe_reason = _is_flashmla_available()
|
||||
if not is_availble:
|
||||
return False, maybe_reason
|
||||
if current_platform.get_device_capability()[0] not in (9, 10):
|
||||
if not (
|
||||
current_platform.is_device_capability_family(90)
|
||||
or current_platform.is_device_capability_family(100)
|
||||
):
|
||||
return (
|
||||
False,
|
||||
"FlashMLA Sparse is only supported on Hopper and Blackwell devices.",
|
||||
|
||||
@@ -7,9 +7,13 @@ import torch
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_cuda_alike():
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import _custom_ops
|
||||
|
||||
ops = _custom_ops
|
||||
elif current_platform.is_xpu():
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
from vllm._ipex_ops import ipex_ops
|
||||
|
||||
ops = ipex_ops
|
||||
|
||||
|
||||
class PagedAttention:
|
||||
|
||||
@@ -754,8 +754,8 @@ def context_attention_fwd(
|
||||
if current_platform.is_rocm():
|
||||
extra_kargs = {"kpack": 1, "waves_per_eu": 2}
|
||||
|
||||
grid = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"]))
|
||||
_fwd_kernel[grid](
|
||||
grid_fn = lambda META: (batch, head, triton.cdiv(max_input_len, META["BLOCK_M"]))
|
||||
_fwd_kernel[grid_fn](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
|
||||
@@ -37,9 +37,9 @@ def fp8_mqa_logits_torch(
|
||||
Returns:
|
||||
Logits tensor of shape [M, N], dtype `torch.float32`.
|
||||
"""
|
||||
kv, scale = kv
|
||||
seq_len_kv = kv.shape[0]
|
||||
k = kv.to(torch.bfloat16)
|
||||
k_fp8, scale = kv
|
||||
seq_len_kv = k_fp8.shape[0]
|
||||
k = k_fp8.to(torch.bfloat16)
|
||||
q = q.to(torch.bfloat16)
|
||||
|
||||
mask_lo = (
|
||||
|
||||
@@ -282,10 +282,7 @@ def _fwd_grouped_kernel_stage1(
|
||||
cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
|
||||
split_kv_id = tl.program_id(2)
|
||||
|
||||
if kv_group_num > BLOCK_H:
|
||||
VALID_BLOCK_H: tl.constexpr = BLOCK_H
|
||||
else:
|
||||
VALID_BLOCK_H: tl.constexpr = kv_group_num
|
||||
VALID_BLOCK_H: tl.constexpr = BLOCK_H if kv_group_num > BLOCK_H else kv_group_num
|
||||
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
|
||||
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
|
||||
mask_h = mask_h & (cur_head < q_head_num)
|
||||
|
||||
@@ -202,9 +202,9 @@ def _fwd_kernel(
|
||||
def get_block_size(dtype: torch.dtype) -> int:
|
||||
if dtype == torch.float32:
|
||||
return 32
|
||||
elif (
|
||||
current_platform.is_cuda_alike()
|
||||
) and current_platform.get_device_capability().major > 8:
|
||||
elif current_platform.is_cuda_alike() and current_platform.has_device_capability(
|
||||
80
|
||||
):
|
||||
return 128
|
||||
else:
|
||||
return 64
|
||||
|
||||
@@ -7,16 +7,23 @@ from vllm.platforms import current_platform
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import _custom_ops
|
||||
|
||||
ops = _custom_ops
|
||||
reshape_and_cache_flash = ops.reshape_and_cache_flash
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata
|
||||
elif current_platform.is_xpu():
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
from vllm.vllm_flash_attn import ( # type: ignore[attr-defined]
|
||||
flash_attn_varlen_func,
|
||||
get_scheduler_metadata,
|
||||
)
|
||||
|
||||
elif current_platform.is_xpu():
|
||||
from vllm._ipex_ops import ipex_ops
|
||||
|
||||
ops = ipex_ops
|
||||
reshape_and_cache_flash = ops.reshape_and_cache_flash
|
||||
flash_attn_varlen_func = ops.flash_attn_varlen_func
|
||||
get_scheduler_metadata = ops.get_scheduler_metadata
|
||||
|
||||
elif current_platform.is_rocm():
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_func # noqa: F401
|
||||
@@ -85,7 +92,7 @@ def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
|
||||
def flash_attn_supports_fp8() -> bool:
|
||||
return (
|
||||
get_flash_attn_version() == 3
|
||||
and current_platform.get_device_capability().major == 9
|
||||
and current_platform.is_device_capability_family(90)
|
||||
)
|
||||
|
||||
|
||||
@@ -105,10 +112,9 @@ def flash_attn_supports_mla():
|
||||
is_fa_version_supported,
|
||||
)
|
||||
|
||||
return (
|
||||
is_fa_version_supported(3)
|
||||
and current_platform.get_device_capability()[0] == 9
|
||||
)
|
||||
return is_fa_version_supported(
|
||||
3
|
||||
) and current_platform.is_device_capability_family(90)
|
||||
except (ImportError, AssertionError):
|
||||
pass
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user