Reapply [Attention] Refactor check_and_update_config (#35122)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
@@ -221,11 +221,9 @@ class Attention(nn.Module, AttentionLayerBase):
         vllm_config = get_current_vllm_config()
         if cache_config is not None:
             kv_cache_dtype = cache_config.cache_dtype
-            block_size = cache_config.block_size
             calculate_kv_scales = cache_config.calculate_kv_scales
         else:
             kv_cache_dtype = "auto"
-            block_size = 16
             calculate_kv_scales = False

         # llm-compressor mdls need to set cache_dtype to "fp8" manually.
@@ -275,7 +273,6 @@ class Attention(nn.Module, AttentionLayerBase):
             head_size,
             dtype,
             kv_cache_dtype,
-            block_size,
             use_mla=False,
             has_sink=self.has_sink,
             use_mm_prefix=self.use_mm_prefix,
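
The recurring change in this commit is that block_size is no longer snapshotted in each layer's __init__ and no longer passed to the backend-selection call. A hypothetical, minimal sketch of the resulting pattern (stand-in classes, not the vLLM source): the layer keeps a handle to the cache config and reads block_size where it is actually needed.

    from dataclasses import dataclass


    @dataclass
    class CacheConfig:
        cache_dtype: str = "auto"
        block_size: int | None = None  # may still be unset when the layer is built


    class Layer:
        def __init__(self, cache_config: CacheConfig | None):
            # Only dtype-like settings are resolved eagerly; block_size is not.
            self.kv_cache_dtype = cache_config.cache_dtype if cache_config else "auto"
            self._cache_config = cache_config

        def kv_cache_block_size(self) -> int:
            # Read the (possibly updated) value at the point of use.
            cfg = self._cache_config
            return cfg.block_size if cfg and cfg.block_size is not None else 16


    layer = Layer(CacheConfig())
    print(layer.kv_cache_block_size())  # 16 until the engine populates the config
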
@@ -30,9 +30,8 @@ from vllm.v1.kv_cache_interface import (
 def create_chunked_local_attention_backend(
     underlying_attn_backend: AttentionBackend,
     attention_chunk_size: int,
-    block_size: int,
 ) -> type[AttentionBackend]:
-    prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
+    prefix = f"ChunkedLocalAttention_{attention_chunk_size}_"

     underlying_builder = underlying_attn_backend.get_builder_cls()
     assert issubclass(underlying_builder, AttentionMetadataBuilder)
@@ -55,7 +54,9 @@ def create_chunked_local_attention_backend(
            fast_build: bool = False,
        ):
            cm, make_virtual_batches_block_table = make_local_attention_virtual_batches(
-                attention_chunk_size, common_attn_metadata, block_size
+                attention_chunk_size,
+                common_attn_metadata,
+                self.kv_cache_spec.block_size,
            )
            metadata = super().build(common_prefix_len, cm, fast_build)
            metadata.make_virtual_batches_block_table = make_virtual_batches_block_table
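
The builder hunk above swaps a block_size captured at backend-creation time for self.kv_cache_spec.block_size read inside build(). A small hedged sketch of that pattern with stand-in classes (make_local_attention_virtual_batches is replaced by a plain dict here):

    from dataclasses import dataclass


    @dataclass
    class KVCacheSpec:
        block_size: int


    class BaseBuilder:
        def __init__(self, kv_cache_spec: KVCacheSpec):
            self.kv_cache_spec = kv_cache_spec

        def build(self, common_attn_metadata):
            return common_attn_metadata


    class ChunkedLocalBuilder(BaseBuilder):
        def __init__(self, kv_cache_spec: KVCacheSpec, attention_chunk_size: int):
            super().__init__(kv_cache_spec)
            self.attention_chunk_size = attention_chunk_size

        def build(self, common_attn_metadata):
            # The block size is looked up on the spec at call time, so an
            # override applied after model loading is picked up automatically.
            block_size = self.kv_cache_spec.block_size
            virtual = {
                "meta": common_attn_metadata,
                "chunk": self.attention_chunk_size,
                "block": block_size,
            }
            return super().build(virtual)


    builder = ChunkedLocalBuilder(KVCacheSpec(block_size=16), attention_chunk_size=1024)
    builder.kv_cache_spec.block_size = 32  # e.g. bumped by a later config update
    print(builder.build({"seq_lens": [8, 8]}))  # uses 32, not the construction-time 16
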
@@ -94,16 +95,12 @@ class ChunkedLocalAttention(Attention):
         dtype = torch.get_default_dtype()
         if cache_config is not None:
             kv_cache_dtype = cache_config.cache_dtype
-            block_size = cache_config.block_size
         else:
             kv_cache_dtype = "auto"
-            block_size = 16

-        underlying_attn_backend = get_attn_backend(
-            head_size, dtype, kv_cache_dtype, block_size
-        )
+        underlying_attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype)
         attn_backend = create_chunked_local_attention_backend(
-            underlying_attn_backend, attention_chunk_size, block_size
+            underlying_attn_backend, attention_chunk_size
         )

         super().__init__(
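
Because the prefix string above drops {block_size}_, the wrapper backend's identity now depends only on the chunk size. A hypothetical illustration of how such a prefixed wrapper class can be built; the use of type() and the stand-in backend are assumptions, not the vLLM implementation:

    class DummyBackend:
        """Stand-in for an AttentionBackend; only the name hook matters here."""

        @staticmethod
        def get_name() -> str:
            return "DUMMY"


    def create_chunked_backend(underlying, attention_chunk_size: int):
        # The generated name encodes only the chunk size; block_size no longer
        # needs to be part of the wrapper's identity.
        name = f"ChunkedLocalAttention_{attention_chunk_size}_" + underlying.get_name()
        return type(name, (underlying,), {"get_name": staticmethod(lambda: name)})


    wrapped = create_chunked_backend(DummyBackend, attention_chunk_size=1024)
    print(wrapped.__name__, wrapped.get_name())  # ChunkedLocalAttention_1024_DUMMY
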
@@ -188,10 +188,8 @@ class CrossAttention(Attention):

         if cache_config is not None:
             kv_cache_dtype = cache_config.cache_dtype
-            block_size = cache_config.block_size
         else:
             kv_cache_dtype = "auto"
-            block_size = 16

         if attn_type is not None:
             assert attn_type == AttentionType.ENCODER_DECODER, (
@@ -202,7 +200,6 @@ class CrossAttention(Attention):
             head_size,
             dtype,
             kv_cache_dtype,
-            block_size,
             attn_type=AttentionType.ENCODER_DECODER,
         )
         attn_backend = create_cross_attention_backend(underlying_attn_backend)
@@ -66,16 +66,13 @@ class EncoderOnlyAttention(Attention):

         if cache_config is not None:
             kv_cache_dtype = cache_config.cache_dtype
-            block_size = cache_config.block_size
         else:
             kv_cache_dtype = "auto"
-            block_size = 16

         underlying_attn_backend = get_attn_backend(
             head_size,
             dtype,
             kv_cache_dtype,
-            block_size,
             attn_type=AttentionType.ENCODER_ONLY,
         )

@@ -323,11 +323,9 @@ class MLAAttention(nn.Module, AttentionLayerBase):

         if cache_config is not None:
             kv_cache_dtype = cache_config.cache_dtype
-            block_size = cache_config.block_size
             calculate_kv_scales = cache_config.calculate_kv_scales
         else:
             kv_cache_dtype = "auto"
-            block_size = 16
             calculate_kv_scales = False
         self.quant_config = quant_config

@@ -336,7 +334,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             self.head_size,
             dtype,
             kv_cache_dtype,
-            block_size,
             use_mla=True,
             use_sparse=use_sparse,
             num_heads=self.num_heads,
@@ -449,17 +446,24 @@ class MLAAttention(nn.Module, AttentionLayerBase):
         )

         # Attributes for forward_impl method
-        self.chunked_prefill_workspace_size = (
-            MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
-                get_current_vllm_config()
-            )
-        )
+        self._vllm_config = get_current_vllm_config()
+        self._chunked_prefill_workspace_size: int | None = None
         self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
             static=True,
             group_shape=GroupShape.PER_TENSOR,
             compile_native=True,
         )

+    @property
+    def chunked_prefill_workspace_size(self) -> int:
+        if self._chunked_prefill_workspace_size is None:
+            self._chunked_prefill_workspace_size = (
+                MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
+                    self._vllm_config
+                )
+            )
+        return self._chunked_prefill_workspace_size
+
     def forward(
         self,
         q: torch.Tensor,
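
The hunk above replaces an eager call to MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size() with two attributes and a cached property. A minimal sketch of that laziness with a stand-in class and a made-up size formula:

    class MLAWorkspace:
        """Stand-in showing the deferred computation; not the vLLM MLAAttention."""

        def __init__(self, vllm_config: dict):
            self._vllm_config = vllm_config
            self._chunked_prefill_workspace_size: int | None = None

        @property
        def chunked_prefill_workspace_size(self) -> int:
            # Computed on first access and cached, so __init__ no longer depends
            # on the config being fully resolved at construction time.
            if self._chunked_prefill_workspace_size is None:
                self._chunked_prefill_workspace_size = (
                    8 * self._vllm_config["max_num_batched_tokens"]  # placeholder formula
                )
            return self._chunked_prefill_workspace_size


    ws = MLAWorkspace({"max_num_batched_tokens": 8192})
    print(ws.chunked_prefill_workspace_size)  # first access computes, later ones reuse
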
@@ -126,17 +126,13 @@ class StaticSinkAttention(Attention, CustomOp):

         if cache_config is not None:
             kv_cache_dtype = cache_config.cache_dtype
-            block_size = cache_config.block_size
         else:
             kv_cache_dtype = "auto"
-            block_size = 16

         if attn_backend is not None:
             underlying_attn_backend = attn_backend
         else:
-            underlying_attn_backend = get_attn_backend(
-                head_size, dtype, kv_cache_dtype, block_size
-            )
+            underlying_attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype)
         attn_backend = create_static_sink_attention_backend(
             underlying_attn_backend, # type: ignore[arg-type]
             sink_len=sink_len,
@@ -153,7 +149,6 @@ class StaticSinkAttention(Attention, CustomOp):
         CustomOp.__init__(self)

         self.sink_len = sink_len
-        self.block_size = block_size
         self.sink_populated = False
         self.sink_key = None
         self.sink_value = None
@@ -212,12 +207,12 @@ class StaticSinkAttention(Attention, CustomOp):

     def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
         # Block size may get updated after model loading, refresh it
-        block_size = vllm_config.cache_config.block_size
+        self.block_size = vllm_config.cache_config.block_size
         # Should not be called for enc-dec or encoder-only attention.
         assert self.attn_type == AttentionType.DECODER

         return SinkFullAttentionSpec(
-            block_size=block_size,
+            block_size=self.block_size,
             num_kv_heads=self.num_kv_heads,
             head_size=self.head_size,
             head_size_v=self.head_size_v,
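
Together with the earlier removal of self.block_size = block_size from __init__, the hunk above means the layer's block size is pinned only when the KV cache spec is requested. A hedged sketch with stand-in config classes:

    from dataclasses import dataclass


    @dataclass
    class CacheConfig:
        block_size: int = 16


    @dataclass
    class VllmConfig:
        cache_config: CacheConfig


    class SinkLayer:
        """Stand-in for StaticSinkAttention; only the refresh logic is shown."""

        def __init__(self):
            self.block_size: int | None = None  # no longer fixed in __init__

        def get_kv_cache_spec(self, vllm_config: VllmConfig) -> dict:
            # Block size may get updated after model loading, so re-read it here.
            self.block_size = vllm_config.cache_config.block_size
            return {"block_size": self.block_size}


    cfg = VllmConfig(CacheConfig())
    layer = SinkLayer()
    cfg.cache_config.block_size = 352  # e.g. enlarged to fit a mamba page
    print(layer.get_kv_cache_spec(cfg))  # {'block_size': 352}
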
@@ -217,10 +217,9 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig):
             mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
         )

-        # override attention block size if either (a) the
-        # user has not set it or (b) the user has set it
-        # too small.
-        if cache_config.block_size is None or cache_config.block_size < attn_block_size:
+        # override attention block size if it is too small,
+        # even if the user has explicitly set it
+        if cache_config.block_size < attn_block_size:
             cache_config.block_size = attn_block_size
             logger.info(
                 "Setting attention block size to %d tokens "
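
A small worked example of the override above, with made-up numbers: suppose aligning to the mamba page size requires an attention block of 352 tokens. A user-supplied 16 is now overridden even though it was set explicitly, while a user-supplied 512 is kept.

    def pick_block_size(user_block_size: int, attn_block_size: int) -> int:
        # After this commit cache_config.block_size always has a value, so only
        # the "too small" case triggers the override.
        if user_block_size < attn_block_size:
            return attn_block_size
        return user_block_size


    print(pick_block_size(16, 352))   # 352: overridden even though explicitly set
    print(pick_block_size(512, 352))  # 512: large enough, the user's choice is kept
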
@@ -290,16 +290,13 @@ class WhisperCausalAttentionWithBlockPooling(Attention):

         if cache_config is not None:
             kv_cache_dtype = cache_config.cache_dtype
-            block_size = cache_config.block_size
         else:
             kv_cache_dtype = "auto"
-            block_size = 16

         underlying_attn_backend = get_attn_backend(
             head_size,
             dtype,
             kv_cache_dtype,
-            block_size,
             attn_type=attn_type,
         )
         attn_backend = create_whisper_attention_backend_with_block_pooling(