Reapply [Attention] Refactor check_and_update_config (#35122)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Matthew Bonanni
2026-03-09 10:17:14 -04:00
committed by GitHub
parent 5578f2a4d3
commit 77a73458e3
32 changed files with 311 additions and 279 deletions

View File

@@ -323,11 +323,9 @@ class MLAAttention(nn.Module, AttentionLayerBase):
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
block_size = 16
calculate_kv_scales = False
self.quant_config = quant_config
@@ -336,7 +334,6 @@ class MLAAttention(nn.Module, AttentionLayerBase):
self.head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla=True,
use_sparse=use_sparse,
num_heads=self.num_heads,
@@ -449,17 +446,24 @@ class MLAAttention(nn.Module, AttentionLayerBase):
)
# Attributes for forward_impl method
self.chunked_prefill_workspace_size = (
MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
get_current_vllm_config()
)
)
self._vllm_config = get_current_vllm_config()
self._chunked_prefill_workspace_size: int | None = None
self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
static=True,
group_shape=GroupShape.PER_TENSOR,
compile_native=True,
)
@property
def chunked_prefill_workspace_size(self) -> int:
if self._chunked_prefill_workspace_size is None:
self._chunked_prefill_workspace_size = (
MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
self._vllm_config
)
)
return self._chunked_prefill_workspace_size
def forward(
self,
q: torch.Tensor,