Reapply [Attention] Refactor check_and_update_config (#35122)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
@@ -4,7 +4,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, replace
 from enum import Enum
-from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, get_args
+from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar

 import numpy as np
 import torch
@@ -144,15 +144,9 @@ class AttentionBackend(ABC):

     @classmethod
     def supports_block_size(cls, block_size: int | None) -> bool:
-        from vllm.config.cache import BlockSize
-
         if block_size is None:
             return True

-        valid_sizes = get_args(BlockSize)
-        if block_size not in valid_sizes:
-            return False
-
         supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes()
         if not supported_kernel_block_sizes:
             return True
@@ -167,6 +161,17 @@ class AttentionBackend(ABC):
                 return True
         return False

+    @classmethod
+    def get_preferred_block_size(cls, default_block_size: int) -> int:
+        supported_sizes = cls.get_supported_kernel_block_sizes()
+        if not supported_sizes:
+            return default_block_size
+
+        if cls.supports_block_size(default_block_size):
+            return default_block_size
+
+        return min(s.base if isinstance(s, MultipleOf) else s for s in supported_sizes)
+
     @classmethod
     def is_mla(cls) -> bool:
         return False
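As a rough, standalone sketch of how get_preferred_block_size picks a size: keep the default when there are no kernel constraints or when the default already satisfies them, otherwise fall back to the smallest size the kernels accept. The MultipleOf stub and the sample sizes below are illustrative assumptions, not vLLM code:

# Standalone sketch of the selection logic; MultipleOf is stubbed here and
# the sample sizes are made up for illustration.
from dataclasses import dataclass


@dataclass
class MultipleOf:
    base: int


def preferred_block_size(supported_sizes, default_block_size, default_is_supported):
    # No kernel constraints reported: keep the caller's default.
    if not supported_sizes:
        return default_block_size
    # The default already satisfies the kernel constraints.
    if default_is_supported:
        return default_block_size
    # Otherwise fall back to the smallest size the kernels accept.
    return min(s.base if isinstance(s, MultipleOf) else s for s in supported_sizes)


# A backend whose kernels want multiples of 64 or exactly 128 cannot use the
# default of 16, so the helper falls back to 64.
print(preferred_block_size([MultipleOf(64), 128], 16, default_is_supported=False))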
@@ -210,7 +215,7 @@ class AttentionBackend(ABC):
         head_size: int,
         dtype: torch.dtype,
         kv_cache_dtype: "CacheDType | None",
-        block_size: int,
+        block_size: int | None,
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
@@ -224,7 +229,7 @@ class AttentionBackend(ABC):
         head_size: int,
         dtype: torch.dtype,
         kv_cache_dtype: "CacheDType | None",
-        block_size: int,
+        block_size: int | None,
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
@@ -75,7 +75,7 @@ class FlashAttnMLABackend(MLACommonBackend):
         head_size: int,
         dtype: torch.dtype,
         kv_cache_dtype: CacheDType | None,
-        block_size: int,
+        block_size: int | None,
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
@@ -69,7 +69,7 @@ class FlashInferMLABackend(MLACommonBackend):
         head_size: int,
         dtype: torch.dtype,
         kv_cache_dtype: CacheDType | None,
-        block_size: int,
+        block_size: int | None,
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
@@ -106,7 +106,7 @@ class FlashInferMLASparseBackend(AttentionBackend):
         head_size: int,
         dtype: torch.dtype,
         kv_cache_dtype: CacheDType | None,
-        block_size: int,
+        block_size: int | None,
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
@@ -80,7 +80,7 @@ class FlashMLABackend(MLACommonBackend):
         head_size: int,
         dtype: torch.dtype,
         kv_cache_dtype: CacheDType | None,
-        block_size: int,
+        block_size: int | None,
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
@@ -49,7 +49,6 @@ def get_attn_backend(
     head_size: int,
     dtype: torch.dtype,
     kv_cache_dtype: str | None,
-    block_size: int | None,
     use_mla: bool = False,
     has_sink: bool = False,
     use_sparse: bool = False,
@@ -71,6 +70,12 @@ def get_attn_backend(

     vllm_config = get_current_vllm_config()

+    cache_config = vllm_config.cache_config
+    if cache_config is not None and cache_config.user_specified_block_size:
+        block_size = cache_config.block_size
+    else:
+        block_size = None
+
     attn_selector_config = AttentionSelectorConfig(
         head_size=head_size,
         dtype=dtype,
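A minimal sketch of how the selector now derives block_size from the cache config rather than taking it as a parameter: a concrete size is only forwarded when the user set one explicitly, otherwise None is passed so the chosen backend can pick its preferred size. The CacheConfig stand-in below is hypothetical; only the block_size and user_specified_block_size fields mirror the diff above:

# Standalone sketch; CacheConfig here is a stand-in, not the vLLM class.
from dataclasses import dataclass


@dataclass
class CacheConfig:
    block_size: int = 16
    user_specified_block_size: bool = False


def resolve_block_size(cache_config):
    # Only pass a concrete block size to backend selection when the user set
    # one explicitly; otherwise leave it as None so the selected backend can
    # choose its own preferred size later.
    if cache_config is not None and cache_config.user_specified_block_size:
        return cache_config.block_size
    return None


print(resolve_block_size(CacheConfig(block_size=32, user_specified_block_size=True)))  # 32
print(resolve_block_size(CacheConfig()))  # None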