Reapply [Attention] Refactor check_and_update_config (#35122)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Matthew Bonanni
2026-03-09 10:17:14 -04:00
committed by GitHub
parent 5578f2a4d3
commit 77a73458e3
32 changed files with 311 additions and 279 deletions

View File

@@ -221,11 +221,9 @@ class Attention(nn.Module, AttentionLayerBase):
vllm_config = get_current_vllm_config()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
block_size = 16
calculate_kv_scales = False
# llm-compressor models need to set cache_dtype to "fp8" manually.
@@ -275,7 +273,6 @@ class Attention(nn.Module, AttentionLayerBase):
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla=False,
has_sink=self.has_sink,
use_mm_prefix=self.use_mm_prefix,

View File

@@ -30,9 +30,8 @@ from vllm.v1.kv_cache_interface import (
def create_chunked_local_attention_backend(
underlying_attn_backend: AttentionBackend,
attention_chunk_size: int,
block_size: int,
) -> type[AttentionBackend]:
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_"
underlying_builder = underlying_attn_backend.get_builder_cls()
assert issubclass(underlying_builder, AttentionMetadataBuilder)
@@ -55,7 +54,9 @@ def create_chunked_local_attention_backend(
fast_build: bool = False,
):
cm, make_virtual_batches_block_table = make_local_attention_virtual_batches(
attention_chunk_size, common_attn_metadata, block_size
attention_chunk_size,
common_attn_metadata,
self.kv_cache_spec.block_size,
)
metadata = super().build(common_prefix_len, cm, fast_build)
metadata.make_virtual_batches_block_table = make_virtual_batches_block_table
@@ -94,16 +95,12 @@ class ChunkedLocalAttention(Attention):
dtype = torch.get_default_dtype()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
underlying_attn_backend = get_attn_backend(
head_size, dtype, kv_cache_dtype, block_size
)
underlying_attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype)
attn_backend = create_chunked_local_attention_backend(
underlying_attn_backend, attention_chunk_size, block_size
underlying_attn_backend, attention_chunk_size
)
super().__init__(

View File

@@ -188,10 +188,8 @@ class CrossAttention(Attention):
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
if attn_type is not None:
assert attn_type == AttentionType.ENCODER_DECODER, (
@@ -202,7 +200,6 @@ class CrossAttention(Attention):
head_size,
dtype,
kv_cache_dtype,
block_size,
attn_type=AttentionType.ENCODER_DECODER,
)
attn_backend = create_cross_attention_backend(underlying_attn_backend)

View File

@@ -66,16 +66,13 @@ class EncoderOnlyAttention(Attention):
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
underlying_attn_backend = get_attn_backend(
head_size,
dtype,
kv_cache_dtype,
block_size,
attn_type=AttentionType.ENCODER_ONLY,
)

View File

@@ -323,11 +323,9 @@ class MLAAttention(nn.Module, AttentionLayerBase):
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
block_size = 16
calculate_kv_scales = False
self.quant_config = quant_config
@@ -336,7 +334,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
self.head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla=True,
use_sparse=use_sparse,
num_heads=self.num_heads,
@@ -449,17 +446,24 @@ class MLAAttention(nn.Module, AttentionLayerBase):
)
# Attributes for forward_impl method
self.chunked_prefill_workspace_size = (
MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
get_current_vllm_config()
)
)
self._vllm_config = get_current_vllm_config()
self._chunked_prefill_workspace_size: int | None = None
self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
static=True,
group_shape=GroupShape.PER_TENSOR,
compile_native=True,
)
@property
def chunked_prefill_workspace_size(self) -> int:
if self._chunked_prefill_workspace_size is None:
self._chunked_prefill_workspace_size = (
MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
self._vllm_config
)
)
return self._chunked_prefill_workspace_size
def forward(
self,
q: torch.Tensor,

View File

@@ -126,17 +126,13 @@ class StaticSinkAttention(Attention, CustomOp):
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
if attn_backend is not None:
underlying_attn_backend = attn_backend
else:
underlying_attn_backend = get_attn_backend(
head_size, dtype, kv_cache_dtype, block_size
)
underlying_attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype)
attn_backend = create_static_sink_attention_backend(
underlying_attn_backend, # type: ignore[arg-type]
sink_len=sink_len,
@@ -153,7 +149,6 @@ class StaticSinkAttention(Attention, CustomOp):
CustomOp.__init__(self)
self.sink_len = sink_len
self.block_size = block_size
self.sink_populated = False
self.sink_key = None
self.sink_value = None
@@ -212,12 +207,12 @@ class StaticSinkAttention(Attention, CustomOp):
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
# Block size may get updated after model loading, refresh it
block_size = vllm_config.cache_config.block_size
self.block_size = vllm_config.cache_config.block_size
# Should not be called for enc-dec or encoder-only attention.
assert self.attn_type == AttentionType.DECODER
return SinkFullAttentionSpec(
block_size=block_size,
block_size=self.block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
head_size_v=self.head_size_v,

View File

@@ -217,10 +217,9 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
)
# override attention block size if either (a) the
# user has not set it or (b) the user has set it
# too small.
if cache_config.block_size is None or cache_config.block_size < attn_block_size:
# override attention block size if it is too small,
# even if the user has explicitly set it
if cache_config.block_size < attn_block_size:
cache_config.block_size = attn_block_size
logger.info(
"Setting attention block size to %d tokens "

View File

@@ -290,16 +290,13 @@ class WhisperCausalAttentionWithBlockPooling(Attention):
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
underlying_attn_backend = get_attn_backend(
head_size,
dtype,
kv_cache_dtype,
block_size,
attn_type=attn_type,
)
attn_backend = create_whisper_attention_backend_with_block_pooling(