diff --git a/tools/pre_commit/mypy.py b/tools/pre_commit/mypy.py
index 3f7e0a069..cb98a856c 100755
--- a/tools/pre_commit/mypy.py
+++ b/tools/pre_commit/mypy.py
@@ -41,6 +41,7 @@ FILES = [
     "vllm/usage",
     "vllm/utils",
     "vllm/worker",
+    "vllm/v1/attention",
     "vllm/v1/core",
     "vllm/v1/engine",
     "vllm/v1/executor",
@@ -60,7 +61,6 @@ SEPARATE_GROUPS = [
     "vllm/lora",
     "vllm/model_executor",
     # v1 related
-    "vllm/v1/attention",
     "vllm/v1/kv_offload",
     "vllm/v1/spec_decode",
     "vllm/v1/structured_output",
diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 025ede1eb..e51a48cef 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from abc import ABC, abstractmethod
+from enum import Enum
 from typing import TYPE_CHECKING, ClassVar, Generic, Protocol, TypeVar, get_args

 import torch
@@ -14,7 +15,7 @@ if TYPE_CHECKING:
     from vllm.v1.attention.backends.utils import KVCacheLayoutType


-class AttentionType:
+class AttentionType(str, Enum):
     """
     Attention type.
     Use string to be compatible with `torch.compile`.
@@ -193,7 +194,7 @@ class AttentionBackend(ABC):
         head_size: int,
         dtype: torch.dtype,
         kv_cache_dtype: "CacheDType | None",
-        block_size: int | None,
+        block_size: int,
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
@@ -207,7 +208,7 @@
         head_size: int,
         dtype: torch.dtype,
         kv_cache_dtype: "CacheDType | None",
-        block_size: int | None,
+        block_size: int,
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
@@ -290,6 +291,11 @@ class AttentionLayer(Protocol):


 class AttentionImpl(ABC, Generic[T]):
+    # Required attributes that all impls should have
+    num_heads: int
+    head_size: int
+    scale: float
+
     # Whether the attention impl can return the softmax lse for decode.
     # Some features like decode context parallelism require the softmax lse.
     can_return_lse_for_decode: bool = False
diff --git a/vllm/model_executor/layers/attention_layer_base.py b/vllm/model_executor/layers/attention_layer_base.py
index a60cf7871..24809ccb0 100644
--- a/vllm/model_executor/layers/attention_layer_base.py
+++ b/vllm/model_executor/layers/attention_layer_base.py
@@ -4,7 +4,7 @@

 from abc import ABC, abstractmethod

-from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.abstract import AttentionBackend, AttentionImpl
 from vllm.config import VllmConfig
 from vllm.v1.kv_cache_interface import KVCacheSpec

@@ -18,6 +18,8 @@ class AttentionLayerBase(ABC):
     from different layer types.
     """

+    impl: "AttentionImpl"
+
     @abstractmethod
     def get_attn_backend(self) -> type[AttentionBackend]:
         """Get the attention backend class for this layer."""
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 3445e998d..24390605a 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -167,7 +167,7 @@ class FlashAttentionBackend(AttentionBackend):
         head_size: int,
         dtype: torch.dtype,
         kv_cache_dtype: CacheDType | None,
-        block_size: int,
+        block_size: int | None,
         use_mla: bool,
         has_sink: bool,
         use_sparse: bool,
@@ -354,7 +354,11 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
             aot_schedule = False

         max_num_splits = 0  # 0 means use FA3's heuristics, not CG compatible
-        if self.use_full_cuda_graph and num_actual_tokens <= self.max_cudagraph_size:
+        if (
+            self.use_full_cuda_graph
+            and self.max_cudagraph_size is not None
+            and num_actual_tokens <= self.max_cudagraph_size
+        ):
             # NOTE(woosuk): Setting num_splits > 1 may increase the memory
             # usage, because the intermediate buffers of size [num_splits,
             # num_heads, num_tokens, head_size] are allocated. Therefore,
@@ -599,6 +603,9 @@ class FlashAttentionImpl(AttentionImpl):
         We use torch's .expand() to avoid duplicating values
         """
         assert output is not None, "Output tensor must be provided."
+        assert self.vllm_flash_attn_version is not None, (
+            "FlashAttention version not detected."
+        )

         if output_scale is not None or output_block_scale is not None:
             raise NotImplementedError(
@@ -697,6 +704,11 @@ class FlashAttentionImpl(AttentionImpl):
             )
             return output
         else:
+            sliding_window_size = (
+                list(self.sliding_window)
+                if self.sliding_window is not None
+                else None
+            )
             flash_attn_varlen_func(
                 q=query[:num_actual_tokens],
                 k=key_cache,
@@ -709,7 +721,7 @@ class FlashAttentionImpl(AttentionImpl):
                 softmax_scale=self.scale,
                 causal=attn_metadata.causal,
                 alibi_slopes=self.alibi_slopes,
-                window_size=self.sliding_window,
+                window_size=sliding_window_size,
                 block_table=block_table,
                 softcap=self.logits_soft_cap,
                 scheduler_metadata=scheduler_metadata,
@@ -764,12 +776,19 @@ class FlashAttentionImpl(AttentionImpl):
         k_descale: torch.Tensor | None = None,
         v_descale: torch.Tensor | None = None,
     ) -> torch.Tensor:
+        assert self.vllm_flash_attn_version is not None, (
+            "FlashAttention version not detected."
+        )
+
         cu_seqlens_q = attn_metadata.query_start_loc
         max_seqlen_q = attn_metadata.max_query_len
         block_table = attn_metadata.block_table

         query = query.contiguous()
         query_across_dcp = get_dcp_group().all_gather(query, dim=1)
+        sliding_window_size = (
+            list(self.sliding_window) if self.sliding_window is not None else None
+        )
         context_attn_out, context_lse = flash_attn_varlen_func(
             q=query_across_dcp,
             k=key_cache,
@@ -782,7 +801,7 @@ class FlashAttentionImpl(AttentionImpl):
             softmax_scale=self.scale,
             causal=False,
             alibi_slopes=self.alibi_slopes,
-            window_size=self.sliding_window,
+            window_size=sliding_window_size,
             block_table=block_table,
             softcap=self.logits_soft_cap,
             return_softmax_lse=True,
@@ -813,7 +832,7 @@ class FlashAttentionImpl(AttentionImpl):
             softmax_scale=self.scale,
             causal=attn_metadata.causal,
             alibi_slopes=self.alibi_slopes,
-            window_size=self.sliding_window,
+            window_size=sliding_window_size,
             softcap=self.logits_soft_cap,
             return_softmax_lse=True,
             fa_version=self.vllm_flash_attn_version,
@@ -850,6 +869,10 @@ class FlashAttentionImpl(AttentionImpl):
             attn_metadata: Encoder attention metadata
             layer: The attention layer
         """
+        assert self.vllm_flash_attn_version is not None, (
+            "FlashAttention version not detected."
+        )
+
         # For encoder attention, process FP8 quantization if needed
         if self.kv_cache_dtype.startswith("fp8"):
             raise NotImplementedError(
@@ -868,6 +891,9 @@ class FlashAttentionImpl(AttentionImpl):
         )

         # Call flash attention directly on Q, K, V tensors
+        sliding_window_size = (
+            list(self.sliding_window) if self.sliding_window is not None else None
+        )
         flash_attn_varlen_func(
             q=query,
             k=key,
@@ -880,7 +906,7 @@ class FlashAttentionImpl(AttentionImpl):
             softmax_scale=self.scale,
             causal=False,  # Encoder attention is bidirectional
             alibi_slopes=self.alibi_slopes,
-            window_size=self.sliding_window,
+            window_size=sliding_window_size,
             softcap=self.logits_soft_cap,
             fa_version=self.vllm_flash_attn_version,
             q_descale=layer._q_scale.expand(descale_shape),
@@ -1020,7 +1046,7 @@ def cascade_attention(
         max_seqlen_k=common_prefix_len,
         softmax_scale=softmax_scale,
         causal=False,
-        window_size=sliding_window,
+        window_size=list(sliding_window),
         block_table=block_table[:1],
         softcap=logits_soft_cap,
         return_softmax_lse=True,
@@ -1048,7 +1074,7 @@
         max_seqlen_k=max_kv_len - common_prefix_len,
         softmax_scale=softmax_scale,
         causal=True,
-        window_size=sliding_window,
+        window_size=list(sliding_window),
         block_table=block_table[:, num_common_kv_blocks:],
         softcap=logits_soft_cap,
         return_softmax_lse=True,
diff --git a/vllm/v1/attention/backends/flash_attn_diffkv.py b/vllm/v1/attention/backends/flash_attn_diffkv.py
index 2e36740bd..ebbc4a02c 100644
--- a/vllm/v1/attention/backends/flash_attn_diffkv.py
+++ b/vllm/v1/attention/backends/flash_attn_diffkv.py
@@ -113,6 +113,9 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
         We use torch's .expand() to avoid duplicating values
         """
         assert output is not None, "Output tensor must be provided."
+        assert self.vllm_flash_attn_version is not None, (
+            "FlashAttention version not detected."
+        )

         if output_scale is not None or output_block_scale is not None:
             raise NotImplementedError(
@@ -214,6 +217,11 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
             )
             return output
         else:
+            sliding_window_size = (
+                list(self.sliding_window)
+                if self.sliding_window is not None
+                else None
+            )
             flash_attn_varlen_func(
                 q=query[:num_actual_tokens],
                 k=key_cache,
@@ -226,7 +234,7 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
                 softmax_scale=self.scale,
                 causal=attn_metadata.causal,
                 alibi_slopes=self.alibi_slopes,
-                window_size=self.sliding_window,
+                window_size=sliding_window_size,
                 block_table=block_table,
                 softcap=self.logits_soft_cap,
                 scheduler_metadata=scheduler_metadata,
diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py
index 7ef157384..0bdf396d8 100755
--- a/vllm/v1/attention/backends/flashinfer.py
+++ b/vllm/v1/attention/backends/flashinfer.py
@@ -530,11 +530,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
         self._decode_wrappers_cudagraph: dict[
             int, BatchDecodeWithPagedKVCacheWrapper
         ] = {}
-        self._decode_cudagraph_max_bs = min(
-            (1 + num_spec_tokens) * max_num_reqs,
-            self.compilation_config.max_cudagraph_capture_size,
-        )
-
+        self._decode_cudagraph_max_bs = (1 + num_spec_tokens) * max_num_reqs
+        if self.compilation_config.max_cudagraph_capture_size is not None:
+            self._decode_cudagraph_max_bs = min(
+                self._decode_cudagraph_max_bs,
+                self.compilation_config.max_cudagraph_capture_size,
+            )
         try:
             self.dcp_world_size = get_dcp_group().world_size
             self.dcp_rank = get_dcp_group().rank_in_group
diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py
index a151a437a..ad99a6dad 100644
--- a/vllm/v1/attention/backends/flex_attention.py
+++ b/vllm/v1/attention/backends/flex_attention.py
@@ -215,7 +215,7 @@ def physical_to_logical_mapping(
     )

     # Only process valid blocks to avoid garbage values
-    num_blocks_per_seq = cdiv(seq_lens, block_size)
+    num_blocks_per_seq: torch.Tensor = cdiv(seq_lens, block_size)
     mask = (
         torch.arange(max_num_blocks, device=device)[None, :]
         < num_blocks_per_seq[:, None]
diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py
index b2bbbe1c5..79636ecab 100644
--- a/vllm/v1/attention/backends/gdn_attn.py
+++ b/vllm/v1/attention/backends/gdn_attn.py
@@ -75,8 +75,10 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
         self.compilation_config = vllm_config.compilation_config
         self.speculative_config = vllm_config.speculative_config
         self.kv_cache_spec = kv_cache_spec
+
         if self.speculative_config:
-            self.num_spec = self.speculative_config.num_speculative_tokens
+            assert self.speculative_config.num_speculative_tokens is not None
+            self.num_spec: int = self.speculative_config.num_speculative_tokens
         else:
             self.num_spec = 0
         self.use_spec_decode = self.num_spec > 0
@@ -85,10 +87,15 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
         self.use_full_cuda_graph = (
             self.compilation_config.cudagraph_mode.has_full_cudagraphs()
         )
-        self.decode_cudagraph_max_bs = min(
-            self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1),
-            self.compilation_config.max_cudagraph_capture_size,
+
+        self.decode_cudagraph_max_bs = (
+            self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1)
         )
+        if self.compilation_config.max_cudagraph_capture_size is not None:
+            self.decode_cudagraph_max_bs = min(
+                self.decode_cudagraph_max_bs,
+                self.compilation_config.max_cudagraph_capture_size,
+            )

         self.spec_state_indices_tensor = torch.empty(
             (self.decode_cudagraph_max_bs, self.num_spec + 1),
diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py
index 460091161..74925a86e 100644
--- a/vllm/v1/attention/backends/mamba2_attn.py
+++ b/vllm/v1/attention/backends/mamba2_attn.py
@@ -123,10 +123,11 @@ class Mamba2AttentionMetadataBuilder(
         device: torch.device,
     ):
         super().__init__(kv_cache_spec, layer_names, vllm_config, device)
-        self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
-        assert self.chunk_size is not None, (
+        chunk_size = vllm_config.model_config.get_mamba_chunk_size()
+        assert chunk_size is not None, (
             "chunk_size needs to be set in the model config for Mamba2 models"
         )
+        self.chunk_size: int = chunk_size

     def _compute_chunk_metadata(
         self,
diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
index dd7b96e98..2d4335664 100644
--- a/vllm/v1/attention/backends/mamba_attn.py
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -69,10 +69,12 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
         assert isinstance(kv_cache_spec, MambaSpec)
         self.compilation_config = vllm_config.compilation_config

-        self.decode_cudagraph_max_bs = min(
-            self.vllm_config.scheduler_config.max_num_seqs,
-            self.compilation_config.max_cudagraph_capture_size,
-        )
+        self.decode_cudagraph_max_bs = self.vllm_config.scheduler_config.max_num_seqs
+        if self.compilation_config.max_cudagraph_capture_size is not None:
+            self.decode_cudagraph_max_bs = min(
+                self.decode_cudagraph_max_bs,
+                self.compilation_config.max_cudagraph_capture_size,
+            )

         if self.vllm_config.cache_config.enable_prefix_caching:
             self.state_indices_tensor = torch.empty(
@@ -150,9 +152,13 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
             cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1
         )
         # -1 in case it's non-computed and causes later issues with indexing
-        block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0)
+        block_idx_last_computed_token = torch.clamp(
+            block_idx_last_computed_token, min=0
+        )
         # -1 in the case we have a padded request (0 seq-len)
-        block_idx_last_scheduled_token = block_idx_last_scheduled_token.clamp(min=0)
+        block_idx_last_scheduled_token = torch.clamp(
+            block_idx_last_scheduled_token, min=0
+        )

         return (
             block_idx_last_computed_token,
diff --git a/vllm/v1/attention/backends/mla/aiter_triton_mla.py b/vllm/v1/attention/backends/mla/aiter_triton_mla.py
index 8a92152a0..96bc0480d 100644
--- a/vllm/v1/attention/backends/mla/aiter_triton_mla.py
+++ b/vllm/v1/attention/backends/mla/aiter_triton_mla.py
@@ -62,7 +62,7 @@ class AiterTritonMLAImpl(AiterMLAImpl):
             k,
             v,
             softmax_scale=softmax_scale,
-            return_lse=return_softmax_lse,
+            return_softmax_lse=return_softmax_lse,
             **kwargs,
         )
         # Transpose the LSE if Triton MHA is used:
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index a47a2282f..4805bf2ee 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -202,6 +202,7 @@ from vllm._aiter_ops import rocm_aiter_ops
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionLayer,
+    AttentionMetadata,
     MLAAttentionImpl,
 )
 from vllm.attention.backends.utils import get_mla_dims
@@ -251,13 +252,15 @@ class QueryLenSupport(Enum):

 try:
-    from vllm.vllm_flash_attn import flash_attn_varlen_func
+    from vllm.vllm_flash_attn import (  # type: ignore[attr-defined]
+        flash_attn_varlen_func,
+    )

     is_vllm_fa = True
 except ImportError:
     # For rocm use upstream flash attention
     if current_platform.is_rocm():
-        from flash_attn import flash_attn_varlen_func
+        from flash_attn import flash_attn_varlen_func  # type: ignore[no-redef]

         is_vllm_fa = False

 try:
@@ -386,7 +389,7 @@ D = TypeVar("D", bound=MLACommonDecodeMetadata)


 @dataclass
-class MLACommonMetadata(Generic[D]):
+class MLACommonMetadata(AttentionMetadata, Generic[D]):
     """Metadata for MLACommon.

     NOTE: Please read the comment at the top of the file before trying to
@@ -434,7 +437,7 @@ class MLACommonMetadata(Generic[D]):


 M = TypeVar("M", bound=MLACommonMetadata)
-A = TypeVar("A")
+A = TypeVar("A", bound=AttentionMetadata)


 def use_flashinfer_prefill() -> bool:
@@ -617,7 +620,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
            self._fi_prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = []

            self._global_hyperparameters = infer_global_hyperparameters(
-                get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl)
+                get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl)  # type: ignore[type-abstract]
            )

         if self._use_trtllm_ragged_prefill:
@@ -874,7 +877,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
             )
             # Note(qcs): The max local context lengths
             # padded to `dcp_local_block_size`.
-            padded_local_context_lens_cpu = (
+            padded_local_context_lens_cpu: torch.Tensor = (
                 cdiv(
                     context_lens_cpu,
                     self.dcp_virtual_block_size,
@@ -1171,7 +1174,9 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
         )

         def get_and_maybe_dequant_weights(layer: LinearBase):
-            if not isinstance(layer.quant_method, UnquantizedLinearMethod):
+            if layer.quant_method is not None and not isinstance(
+                layer.quant_method, UnquantizedLinearMethod
+            ):
                 # NOTE: This should only be used offline, since it's O(N^3)
                 eye = torch.eye(
                     layer.input_size_per_partition,
@@ -1327,12 +1332,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         # v with 0s to match the qk head dim for attention backends that do
         # not support different headdims
         # We don't need to pad V if we are on a hopper system with FA3
+        device_capability = current_platform.get_device_capability()
         self._pad_v = self.vllm_flash_attn_version is None or not (
             self.vllm_flash_attn_version == 3
-            and current_platform.get_device_capability()[0] == 9
+            and device_capability is not None
+            and device_capability[0] == 9
         )

-        self.dcp_world_size: int | None = None
+        self.dcp_world_size: int = -1

         self.chunked_prefill_workspace_size = (
             MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
@@ -1583,7 +1590,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
         )

         def get_and_maybe_dequant_weights(layer: LinearBase):
-            if not isinstance(layer.quant_method, UnquantizedLinearMethod):
+            if layer.quant_method is not None and not isinstance(
+                layer.quant_method, UnquantizedLinearMethod
+            ):
                 # NOTE: This should only be used offline, since it's O(N^3)
                 eye = torch.eye(
                     layer.input_size_per_partition,
@@ -1875,7 +1884,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
     ) -> None:
         # TODO (zyongye): Prefill function here
         assert attn_metadata.prefill is not None
-        assert self.dcp_world_size is not None
+        assert self.dcp_world_size != -1

         has_context = attn_metadata.prefill.chunked_context is not None
         kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
@@ -1975,7 +1984,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
             # same expert outputs.
             return output.fill_(0)

-        if self.dcp_world_size is None:
+        if self.dcp_world_size == -1:
             self.dcp_world_size = get_dcp_group().world_size

         fp8_attention = self.kv_cache_dtype.startswith("fp8")
diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py
index b4a68f472..915b51c25 100644
--- a/vllm/v1/attention/backends/mla/flashattn_mla.py
+++ b/vllm/v1/attention/backends/mla/flashattn_mla.py
@@ -33,7 +33,10 @@ from vllm.v1.attention.backends.mla.common import (
 )
 from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.kv_cache_interface import AttentionSpec
-from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata
+from vllm.vllm_flash_attn import (  # type: ignore[attr-defined]
+    flash_attn_varlen_func,
+    get_scheduler_metadata,
+)

 logger = init_logger(__name__)

@@ -181,7 +184,11 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]

         # For Flash Attention MLA + full cudagraph
         max_num_splits = 0
-        if self.use_full_cuda_graph and num_decode_tokens <= self.max_cudagraph_size:
+        if (
+            self.use_full_cuda_graph
+            and self.max_cudagraph_size is not None
+            and num_decode_tokens <= self.max_cudagraph_size
+        ):
             # NOTE(woosuk): Setting num_splits > 1 may increase the memory
             # usage, because the intermediate buffers of size [num_splits,
             # num_heads, num_tokens, head_size] are allocated. Therefore,
diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py
index 112253896..64cca2888 100644
--- a/vllm/v1/attention/backends/mla/flashmla_sparse.py
+++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py
@@ -10,6 +10,7 @@ from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionLayer,
+    AttentionMetadata,
     MultipleOf,
 )
 from vllm.attention.backends.utils import get_mla_dims
@@ -124,7 +125,7 @@ class FlashMLASparseBackend(AttentionBackend):


 @dataclass
-class FlashMLASparseMetadata:
+class FlashMLASparseMetadata(AttentionMetadata):
     num_reqs: int
     max_query_len: int
     max_seq_len: int
@@ -718,7 +719,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
         )
         self.softmax_scale = scale
         assert indexer is not None
-        self.topk_indices_buffer = indexer.topk_indices_buffer
+        self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
         self.padding = 128 if current_platform.is_device_capability_family(100) else 64

         if kv_cache_dtype == "fp8_ds_mla":
@@ -980,6 +981,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
         q = q[:num_actual_toks, ...]
         k_c_normed = k_c_normed[:num_actual_toks, ...]
         k_pe = k_pe[:num_actual_toks, ...]
+        assert self.topk_indices_buffer is not None
         topk_indices = self.topk_indices_buffer[:num_actual_toks]

         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
index c6e3f92dc..71be5b171 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py
@@ -236,7 +236,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
             k=k,
             v=v,
             softmax_scale=softmax_scale,
-            return_lse=return_softmax_lse,
+            return_softmax_lse=return_softmax_lse,
             **kwargs,
         )

@@ -251,6 +251,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
         assert kv_c_and_k_pe_cache.numel() > 0
         assert attn_metadata.decode is not None
+        assert attn_metadata.decode.max_qo_len is not None

         if type(q) is tuple:
             q = torch.cat(q, dim=-1)
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
index c0e7f0e38..a461a2155 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
@@ -43,7 +43,7 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
         return "ROCM_AITER_MLA_SPARSE"

     @staticmethod
-    def get_metadata_cls() -> type[AttentionMetadata]:
+    def get_metadata_cls() -> type["ROCMAiterMLASparseMetadata"]:
         return ROCMAiterMLASparseMetadata

     @staticmethod
@@ -74,7 +74,7 @@ class ROCMAiterMLASparseBackend(AttentionBackend):


 @dataclass
-class ROCMAiterMLASparseMetadata:
+class ROCMAiterMLASparseMetadata(AttentionMetadata):
     num_reqs: int
     max_query_len: int
     max_seq_len: int
@@ -223,7 +223,7 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
         )
         self.softmax_scale = scale
         assert indexer is not None
-        self.topk_indices_buffer = indexer.topk_indices_buffer
+        self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
         self.is_fp8bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()

     def _forward_bf16_kv(
@@ -294,6 +294,7 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):

         # Convert from (N, B, L) to (B, N, L)
         ql_nope = ql_nope.transpose(0, 1)
+        assert self.topk_indices_buffer is not None
         topk_indices = self.topk_indices_buffer[:num_actual_toks]

         topk_indices_global = triton_convert_req_index_to_global_index(
diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py
index 523f759e0..5e3c436f8 100644
--- a/vllm/v1/attention/backends/tree_attn.py
+++ b/vllm/v1/attention/backends/tree_attn.py
@@ -155,7 +155,9 @@ class TreeAttentionMetadataBuilder(AttentionMetadataBuilder[TreeAttentionMetadat
         self.block_size = kv_cache_spec.block_size

         spec_config = vllm_config.speculative_config
-        spec_token_tree = (spec := spec_config) and spec.speculative_token_tree
+        spec_token_tree: str | None = None
+        if spec := spec_config:
+            spec_token_tree = spec.speculative_token_tree
         tree_choices: list[tuple[int, ...]] = (
             ast.literal_eval(spec_token_tree) if spec_token_tree is not None else [(0,)]
         )
diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py
index 3cbdafe14..cc33b3319 100644
--- a/vllm/v1/attention/backends/utils.py
+++ b/vllm/v1/attention/backends/utils.py
@@ -469,6 +469,7 @@ def get_kv_cache_layout():
     # Format specified by the code.

     global _KV_CACHE_LAYOUT_OVERRIDE
+    cache_layout: Literal["NHD", "HND"] | None = None
     if _KV_CACHE_LAYOUT_OVERRIDE is not None:
         cache_layout = _KV_CACHE_LAYOUT_OVERRIDE
         logger.info_once(
@@ -524,7 +525,11 @@ def get_per_layer_parameters(
     to use during `plan`.
     """

-    layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase, layer_names)
+    layers = get_layers_from_vllm_config(
+        vllm_config,
+        AttentionLayerBase,  # type: ignore[type-abstract]
+        layer_names,
+    )
     per_layer_params: dict[str, PerLayerParameters] = {}

     for key, layer in layers.items():
@@ -1125,7 +1130,7 @@ class KVSharingFastPrefillMetadata(Protocol):


 def create_fast_prefill_custom_backend(
     prefix: str,
-    underlying_attn_backend: AttentionBackend,
+    underlying_attn_backend: type[AttentionBackend],
 ) -> type[AttentionBackend]:
     underlying_builder = underlying_attn_backend.get_builder_cls()