Reapply [Attention] Refactor check_and_update_config (#35122)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Author: Matthew Bonanni
Date: 2026-03-09 10:17:14 -04:00
Committed by: GitHub
Parent: 5578f2a4d3
Commit: 77a73458e3

32 changed files with 311 additions and 279 deletions


@@ -4,7 +4,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, replace
 from enum import Enum
-from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, get_args
+from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar
 
 import numpy as np
 import torch
@@ -144,15 +144,9 @@ class AttentionBackend(ABC):
     @classmethod
     def supports_block_size(cls, block_size: int | None) -> bool:
-        from vllm.config.cache import BlockSize
         if block_size is None:
             return True
-        valid_sizes = get_args(BlockSize)
-        if block_size not in valid_sizes:
-            return False
         supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes()
         if not supported_kernel_block_sizes:
             return True
@@ -167,6 +161,17 @@ class AttentionBackend(ABC):
                 return True
         return False
 
+    @classmethod
+    def get_preferred_block_size(cls, default_block_size: int) -> int:
+        supported_sizes = cls.get_supported_kernel_block_sizes()
+        if not supported_sizes:
+            return default_block_size
+        if cls.supports_block_size(default_block_size):
+            return default_block_size
+        return min(s.base if isinstance(s, MultipleOf) else s for s in supported_sizes)
+
     @classmethod
     def is_mla(cls) -> bool:
         return False
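
Below is a minimal, self-contained sketch of the behavior the refactored supports_block_size and the new get_preferred_block_size classmethods implement. ToyBackend, its supported sizes, and the simplified MultipleOf stand-in are illustrative assumptions, not vLLM's real classes:

from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:
    # Stand-in for the MultipleOf marker used above: any multiple of `base` is supported.
    base: int


class ToyBackend:
    # Hypothetical backend that supports block sizes 64, 128, or any multiple of 32.
    @classmethod
    def get_supported_kernel_block_sizes(cls) -> list[int | MultipleOf]:
        return [64, 128, MultipleOf(32)]

    @classmethod
    def supports_block_size(cls, block_size: int | None) -> bool:
        if block_size is None:
            return True
        for supported in cls.get_supported_kernel_block_sizes():
            if isinstance(supported, MultipleOf):
                if block_size % supported.base == 0:
                    return True
            elif block_size == supported:
                return True
        return False

    @classmethod
    def get_preferred_block_size(cls, default_block_size: int) -> int:
        supported_sizes = cls.get_supported_kernel_block_sizes()
        if not supported_sizes:
            return default_block_size
        if cls.supports_block_size(default_block_size):
            return default_block_size
        # Fall back to the smallest supported size (the base for MultipleOf entries).
        return min(s.base if isinstance(s, MultipleOf) else s for s in supported_sizes)


print(ToyBackend.get_preferred_block_size(16))   # 16 unsupported -> falls back to 32
print(ToyBackend.get_preferred_block_size(128))  # 128 supported -> kept as-is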
@@ -210,7 +215,7 @@ class AttentionBackend(ABC):
         head_size: int,
         dtype: torch.dtype,
         kv_cache_dtype: "CacheDType | None",
-        block_size: int,
+        block_size: int | None,
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
@@ -224,7 +229,7 @@ class AttentionBackend(ABC):
         head_size: int,
         dtype: torch.dtype,
         kv_cache_dtype: "CacheDType | None",
-        block_size: int,
+        block_size: int | None,
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,

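The block_size parameter in the backend-capability signatures above (and in the MLA backend overrides below) changes from int to int | None: None now means the user did not pin a block size, so it must never by itself disqualify a backend. The helper below is hypothetical and not part of vLLM; it only illustrates that convention:

def block_size_incompatibility(block_size: int | None, supported: list[int]) -> str | None:
    # Only an explicitly requested block size can disqualify a backend;
    # None (no user preference) is always acceptable.
    if block_size is not None and supported and block_size not in supported:
        return f"block size {block_size} not in supported sizes {supported}"
    return None


print(block_size_incompatibility(None, [64, 128]))  # None -> no objection
print(block_size_incompatibility(48, [64, 128]))    # reports an incompatibility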

@@ -75,7 +75,7 @@ class FlashAttnMLABackend(MLACommonBackend):
         head_size: int,
         dtype: torch.dtype,
         kv_cache_dtype: CacheDType | None,
-        block_size: int,
+        block_size: int | None,
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,


@@ -69,7 +69,7 @@ class FlashInferMLABackend(MLACommonBackend):
         head_size: int,
         dtype: torch.dtype,
         kv_cache_dtype: CacheDType | None,
-        block_size: int,
+        block_size: int | None,
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,


@@ -106,7 +106,7 @@ class FlashInferMLASparseBackend(AttentionBackend):
         head_size: int,
         dtype: torch.dtype,
         kv_cache_dtype: CacheDType | None,
-        block_size: int,
+        block_size: int | None,
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,


@@ -80,7 +80,7 @@ class FlashMLABackend(MLACommonBackend):
         head_size: int,
         dtype: torch.dtype,
         kv_cache_dtype: CacheDType | None,
-        block_size: int,
+        block_size: int | None,
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,


@@ -49,7 +49,6 @@ def get_attn_backend(
     head_size: int,
     dtype: torch.dtype,
     kv_cache_dtype: str | None,
-    block_size: int | None,
     use_mla: bool = False,
     has_sink: bool = False,
     use_sparse: bool = False,
@@ -71,6 +70,12 @@ def get_attn_backend(
     vllm_config = get_current_vllm_config()
     cache_config = vllm_config.cache_config
+    if cache_config is not None and cache_config.user_specified_block_size:
+        block_size = cache_config.block_size
+    else:
+        block_size = None
     attn_selector_config = AttentionSelectorConfig(
         head_size=head_size,
         dtype=dtype,
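
For context, here is a stand-alone sketch of the block-size resolution the hunk above adds to the attention selector. CacheConfig and resolve_block_size are stand-ins for illustration, not vLLM's actual classes; the real code reads these fields off the current vllm_config:

from dataclasses import dataclass


@dataclass
class CacheConfig:
    # Stand-ins for the two cache-config fields the selector consults.
    block_size: int = 16
    user_specified_block_size: bool = False


def resolve_block_size(cache_config: CacheConfig | None) -> int | None:
    # Only a block size the user explicitly requested constrains backend
    # selection; otherwise None lets each backend pick its preferred size
    # (see get_preferred_block_size above).
    if cache_config is not None and cache_config.user_specified_block_size:
        return cache_config.block_size
    return None


print(resolve_block_size(CacheConfig(block_size=32, user_specified_block_size=True)))   # 32
print(resolve_block_size(CacheConfig(block_size=16, user_specified_block_size=False)))  # None
print(resolve_block_size(None))                                                         # None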