Fix mypy warnings for files in vllm/v1/attention with TEMPORARY workaround (#31465)

Signed-off-by: Zhuohao Yang <zy242@cornell.edu>
Co-authored-by: Zhuohao Yang <zy242@cornell.edu>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Jack Yang
2026-01-06 23:08:47 -05:00
committed by GitHub
parent f09c5feb7c
commit 0a2c2dc3f1
18 changed files with 140 additions and 56 deletions
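
Editor's note: most of the hunks below apply the same two mypy patterns — values typed as `X | None` are narrowed with an explicit `assert ... is not None` or an `is not None` guard before use, and stored values are converted to the exact types the callee is annotated with. A minimal, runnable sketch of the narrowing idea; the function names here are illustrative, not vLLM's:

def pick_num_splits(max_cudagraph_size: int | None, num_tokens: int) -> int:
    # Guarding on "is not None" lets mypy narrow the Optional inside the
    # branch, so the comparison type-checks.
    if max_cudagraph_size is not None and num_tokens <= max_cudagraph_size:
        return 8
    return 0  # 0 means "use the kernel's own heuristic"

def require_fa_version(version: int | None) -> int:
    # assert-based narrowing: after the assert, mypy treats `version` as int.
    assert version is not None, "FlashAttention version not detected."
    return version

print(pick_num_splits(None, 16), require_fa_version(3))  # 0 3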

View File

@@ -167,7 +167,7 @@ class FlashAttentionBackend(AttentionBackend):
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: CacheDType | None,
block_size: int,
block_size: int | None,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
@@ -354,7 +354,11 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
aot_schedule = False
max_num_splits = 0 # 0 means use FA3's heuristics, not CG compatible
if self.use_full_cuda_graph and num_actual_tokens <= self.max_cudagraph_size:
if (
self.use_full_cuda_graph
and self.max_cudagraph_size is not None
and num_actual_tokens <= self.max_cudagraph_size
):
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
# usage, because the intermediate buffers of size [num_splits,
# num_heads, num_tokens, head_size] are allocated. Therefore,
@@ -599,6 +603,9 @@ class FlashAttentionImpl(AttentionImpl):
We use torch's .expand() to avoid duplicating values
"""
assert output is not None, "Output tensor must be provided."
assert self.vllm_flash_attn_version is not None, (
"FlashAttention version not detected."
)
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
@@ -697,6 +704,11 @@ class FlashAttentionImpl(AttentionImpl):
)
return output
else:
sliding_window_size = (
list(self.sliding_window)
if self.sliding_window is not None
else None
)
flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
@@ -709,7 +721,7 @@ class FlashAttentionImpl(AttentionImpl):
softmax_scale=self.scale,
causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
window_size=sliding_window_size,
block_table=block_table,
softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata,
@@ -764,12 +776,19 @@ class FlashAttentionImpl(AttentionImpl):
k_descale: torch.Tensor | None = None,
v_descale: torch.Tensor | None = None,
) -> torch.Tensor:
assert self.vllm_flash_attn_version is not None, (
"FlashAttention version not detected."
)
cu_seqlens_q = attn_metadata.query_start_loc
max_seqlen_q = attn_metadata.max_query_len
block_table = attn_metadata.block_table
query = query.contiguous()
query_across_dcp = get_dcp_group().all_gather(query, dim=1)
sliding_window_size = (
list(self.sliding_window) if self.sliding_window is not None else None
)
context_attn_out, context_lse = flash_attn_varlen_func(
q=query_across_dcp,
k=key_cache,
@@ -782,7 +801,7 @@ class FlashAttentionImpl(AttentionImpl):
softmax_scale=self.scale,
causal=False,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
window_size=sliding_window_size,
block_table=block_table,
softcap=self.logits_soft_cap,
return_softmax_lse=True,
@@ -813,7 +832,7 @@ class FlashAttentionImpl(AttentionImpl):
softmax_scale=self.scale,
causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
window_size=sliding_window_size,
softcap=self.logits_soft_cap,
return_softmax_lse=True,
fa_version=self.vllm_flash_attn_version,
@@ -850,6 +869,10 @@ class FlashAttentionImpl(AttentionImpl):
attn_metadata: Encoder attention metadata
layer: The attention layer
"""
assert self.vllm_flash_attn_version is not None, (
"FlashAttention version not detected."
)
# For encoder attention, process FP8 quantization if needed
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError(
@@ -868,6 +891,9 @@ class FlashAttentionImpl(AttentionImpl):
)
# Call flash attention directly on Q, K, V tensors
sliding_window_size = (
list(self.sliding_window) if self.sliding_window is not None else None
)
flash_attn_varlen_func(
q=query,
k=key,
@@ -880,7 +906,7 @@ class FlashAttentionImpl(AttentionImpl):
softmax_scale=self.scale,
causal=False, # Encoder attention is bidirectional
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
window_size=sliding_window_size,
softcap=self.logits_soft_cap,
fa_version=self.vllm_flash_attn_version,
q_descale=layer._q_scale.expand(descale_shape),
@@ -1020,7 +1046,7 @@ def cascade_attention(
max_seqlen_k=common_prefix_len,
softmax_scale=softmax_scale,
causal=False,
window_size=sliding_window,
window_size=list(sliding_window),
block_table=block_table[:1],
softcap=logits_soft_cap,
return_softmax_lse=True,
@@ -1048,7 +1074,7 @@ def cascade_attention(
max_seqlen_k=max_kv_len - common_prefix_len,
softmax_scale=softmax_scale,
causal=True,
window_size=sliding_window,
window_size=list(sliding_window),
block_table=block_table[:, num_common_kv_blocks:],
softcap=logits_soft_cap,
return_softmax_lse=True,
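
Editor's note: the repeated `sliding_window_size = list(self.sliding_window) if self.sliding_window is not None else None` blocks exist presumably because the attribute is stored as a tuple (or None) while the `window_size` parameter of the varlen kernel wrapper is annotated as a list. A small hedged sketch of that conversion, assuming a `(left, right)` window tuple; the helper name is hypothetical:

def as_window_arg(sliding_window: tuple[int, int] | None) -> list[int] | None:
    # Convert the stored (left, right) tuple into the list the kernel wrapper
    # expects, preserving None when no sliding window is configured.
    return list(sliding_window) if sliding_window is not None else None

print(as_window_arg((128, 0)))  # [128, 0]
print(as_window_arg(None))      # None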

View File

@@ -113,6 +113,9 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
We use torch's .expand() to avoid duplicating values
"""
assert output is not None, "Output tensor must be provided."
assert self.vllm_flash_attn_version is not None, (
"FlashAttention version not detected."
)
if output_scale is not None or output_block_scale is not None:
raise NotImplementedError(
@@ -214,6 +217,11 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
)
return output
else:
sliding_window_size = (
list(self.sliding_window)
if self.sliding_window is not None
else None
)
flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
@@ -226,7 +234,7 @@ class FlashAttentionDiffKVImpl(FlashAttentionImpl):
softmax_scale=self.scale,
causal=attn_metadata.causal,
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
window_size=sliding_window_size,
block_table=block_table,
softcap=self.logits_soft_cap,
scheduler_metadata=scheduler_metadata,

View File

@@ -530,11 +530,12 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
self._decode_wrappers_cudagraph: dict[
int, BatchDecodeWithPagedKVCacheWrapper
] = {}
self._decode_cudagraph_max_bs = min(
(1 + num_spec_tokens) * max_num_reqs,
self.compilation_config.max_cudagraph_capture_size,
)
self._decode_cudagraph_max_bs = (1 + num_spec_tokens) * max_num_reqs
if self.compilation_config.max_cudagraph_capture_size is not None:
self._decode_cudagraph_max_bs = min(
self._decode_cudagraph_max_bs,
self.compilation_config.max_cudagraph_capture_size,
)
try:
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
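
Editor's note: `compilation_config.max_cudagraph_capture_size` can be `None`, so the old unconditional `min(...)` no longer type-checks; the new code computes the uncapped batch size first and only clamps when a cap is set. A plain-int sketch of the same logic (function and argument names are illustrative):

def decode_cudagraph_max_bs(
    max_num_reqs: int,
    num_spec_tokens: int,
    max_cudagraph_capture_size: int | None,
) -> int:
    bs = (1 + num_spec_tokens) * max_num_reqs
    if max_cudagraph_capture_size is not None:
        bs = min(bs, max_cudagraph_capture_size)
    return bs

print(decode_cudagraph_max_bs(256, 1, 128))   # 128
print(decode_cudagraph_max_bs(256, 1, None))  # 512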

View File

@@ -215,7 +215,7 @@ def physical_to_logical_mapping(
)
# Only process valid blocks to avoid garbage values
num_blocks_per_seq = cdiv(seq_lens, block_size)
num_blocks_per_seq: torch.Tensor = cdiv(seq_lens, block_size)
mask = (
torch.arange(max_num_blocks, device=device)[None, :]
< num_blocks_per_seq[:, None]

View File

@@ -75,8 +75,10 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
self.compilation_config = vllm_config.compilation_config
self.speculative_config = vllm_config.speculative_config
self.kv_cache_spec = kv_cache_spec
if self.speculative_config:
self.num_spec = self.speculative_config.num_speculative_tokens
assert self.speculative_config.num_speculative_tokens is not None
self.num_spec: int = self.speculative_config.num_speculative_tokens
else:
self.num_spec = 0
self.use_spec_decode = self.num_spec > 0
@@ -85,10 +87,15 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
self.use_full_cuda_graph = (
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)
self.decode_cudagraph_max_bs = min(
self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1),
self.compilation_config.max_cudagraph_capture_size,
self.decode_cudagraph_max_bs = (
self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1)
)
if self.compilation_config.max_cudagraph_capture_size is not None:
self.decode_cudagraph_max_bs = min(
self.decode_cudagraph_max_bs,
self.compilation_config.max_cudagraph_capture_size,
)
self.spec_state_indices_tensor = torch.empty(
(self.decode_cudagraph_max_bs, self.num_spec + 1),

View File

@@ -123,10 +123,11 @@ class Mamba2AttentionMetadataBuilder(
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
assert self.chunk_size is not None, (
chunk_size = vllm_config.model_config.get_mamba_chunk_size()
assert chunk_size is not None, (
"chunk_size needs to be set in the model config for Mamba2 models"
)
self.chunk_size: int = chunk_size
def _compute_chunk_metadata(
self,
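
Editor's note: asserting on `self.chunk_size` after assigning an `int | None` to it does not change the attribute's inferred type in other methods, so downstream uses that need a plain `int` still fail mypy; the fix narrows a local variable first and stores it in an explicitly `int`-annotated attribute. A minimal sketch, where `Config` is a stand-in for the model config:

class Config:
    def get_mamba_chunk_size(self) -> int | None:
        return 256

class Builder:
    def __init__(self, config: Config) -> None:
        chunk_size = config.get_mamba_chunk_size()
        assert chunk_size is not None, (
            "chunk_size needs to be set in the model config for Mamba2 models"
        )
        # The explicit annotation keeps the attribute a plain int for mypy.
        self.chunk_size: int = chunk_size

print(Builder(Config()).chunk_size)  # 256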

View File

@@ -69,10 +69,12 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
assert isinstance(kv_cache_spec, MambaSpec)
self.compilation_config = vllm_config.compilation_config
self.decode_cudagraph_max_bs = min(
self.vllm_config.scheduler_config.max_num_seqs,
self.compilation_config.max_cudagraph_capture_size,
)
self.decode_cudagraph_max_bs = self.vllm_config.scheduler_config.max_num_seqs
if self.compilation_config.max_cudagraph_capture_size is not None:
self.decode_cudagraph_max_bs = min(
self.decode_cudagraph_max_bs,
self.compilation_config.max_cudagraph_capture_size,
)
if self.vllm_config.cache_config.enable_prefix_caching:
self.state_indices_tensor = torch.empty(
@@ -150,9 +152,13 @@ class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1
)
# -1 in case it's non-computed and causes later issues with indexing
block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0)
block_idx_last_computed_token = torch.clamp(
block_idx_last_computed_token, min=0
)
# -1 in the case we have a padded request (0 seq-len)
block_idx_last_scheduled_token = block_idx_last_scheduled_token.clamp(min=0)
block_idx_last_scheduled_token = torch.clamp(
block_idx_last_scheduled_token, min=0
)
return (
block_idx_last_computed_token,

View File

@@ -62,7 +62,7 @@ class AiterTritonMLAImpl(AiterMLAImpl):
k,
v,
softmax_scale=softmax_scale,
return_lse=return_softmax_lse,
return_softmax_lse=return_softmax_lse,
**kwargs,
)
# Transpose the LSE if Triton MHA is used:
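
Editor's note: the `return_lse=` to `return_softmax_lse=` change is a keyword-name fix that typing makes visible — once the wrapped attention function's signature is known to mypy, a misspelled keyword is reported at the call site rather than only surfacing at runtime. A toy illustration with an assumed signature:

def attn(q: float, k: float, v: float, *, return_softmax_lse: bool = False) -> float:
    # Stand-in for the real attention wrapper; only the keyword name matters
    # here, so the flag is ignored in this sketch.
    return q * k * v

print(attn(1.0, 2.0, 3.0, return_softmax_lse=True))  # 6.0
# attn(1.0, 2.0, 3.0, return_lse=True)  # mypy/runtime: unexpected keyword argument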

View File

@@ -202,6 +202,7 @@ from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionLayer,
AttentionMetadata,
MLAAttentionImpl,
)
from vllm.attention.backends.utils import get_mla_dims
@@ -251,13 +252,15 @@ class QueryLenSupport(Enum):
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
from vllm.vllm_flash_attn import ( # type: ignore[attr-defined]
flash_attn_varlen_func,
)
is_vllm_fa = True
except ImportError:
# For rocm use upstream flash attention
if current_platform.is_rocm():
from flash_attn import flash_attn_varlen_func
from flash_attn import flash_attn_varlen_func # type: ignore[no-redef]
is_vllm_fa = False
try:
@@ -386,7 +389,7 @@ D = TypeVar("D", bound=MLACommonDecodeMetadata)
@dataclass
class MLACommonMetadata(Generic[D]):
class MLACommonMetadata(AttentionMetadata, Generic[D]):
"""Metadata for MLACommon.
NOTE: Please read the comment at the top of the file before trying to
@@ -434,7 +437,7 @@ class MLACommonMetadata(Generic[D]):
M = TypeVar("M", bound=MLACommonMetadata)
A = TypeVar("A")
A = TypeVar("A", bound=AttentionMetadata)
def use_flashinfer_prefill() -> bool:
@@ -617,7 +620,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self._fi_prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = []
self._global_hyperparameters = infer_global_hyperparameters(
get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl)
get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl) # type: ignore[type-abstract]
)
if self._use_trtllm_ragged_prefill:
@@ -874,7 +877,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
)
# Note(qcs): The max local context lengths
# padded to `dcp_local_block_size`.
padded_local_context_lens_cpu = (
padded_local_context_lens_cpu: torch.Tensor = (
cdiv(
context_lens_cpu,
self.dcp_virtual_block_size,
@@ -1171,7 +1174,9 @@ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
)
def get_and_maybe_dequant_weights(layer: LinearBase):
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
if layer.quant_method is not None and not isinstance(
layer.quant_method, UnquantizedLinearMethod
):
# NOTE: This should only be used offline, since it's O(N^3)
eye = torch.eye(
layer.input_size_per_partition,
@@ -1327,12 +1332,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# We don't need to pad V if we are on a hopper system with FA3
device_capability = current_platform.get_device_capability()
self._pad_v = self.vllm_flash_attn_version is None or not (
self.vllm_flash_attn_version == 3
and current_platform.get_device_capability()[0] == 9
and device_capability is not None
and device_capability[0] == 9
)
self.dcp_world_size: int | None = None
self.dcp_world_size: int = -1
self.chunked_prefill_workspace_size = (
MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
@@ -1583,7 +1590,9 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
)
def get_and_maybe_dequant_weights(layer: LinearBase):
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
if layer.quant_method is not None and not isinstance(
layer.quant_method, UnquantizedLinearMethod
):
# NOTE: This should only be used offline, since it's O(N^3)
eye = torch.eye(
layer.input_size_per_partition,
@@ -1875,7 +1884,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
) -> None:
# TODO (zyongye): Prefill function here
assert attn_metadata.prefill is not None
assert self.dcp_world_size is not None
assert self.dcp_world_size != -1
has_context = attn_metadata.prefill.chunked_context is not None
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
@@ -1975,7 +1984,7 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
# same expert outputs.
return output.fill_(0)
if self.dcp_world_size is None:
if self.dcp_world_size == -1:
self.dcp_world_size = get_dcp_group().world_size
fp8_attention = self.kv_cache_dtype.startswith("fp8")
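
Editor's note: replacing `dcp_world_size: int | None = None` with an `int` initialised to `-1` keeps the attribute non-Optional, so the many call sites need no narrowing; the `== -1` check lazily fills it in from the DCP group on first use. A sketch of that sentinel pattern, where `get_dcp_world_size()` is a placeholder for the real distributed lookup:

def get_dcp_world_size() -> int:
    return 1  # placeholder for get_dcp_group().world_size

class Impl:
    def __init__(self) -> None:
        # -1 means "not initialised yet", so the attribute stays a plain int.
        self.dcp_world_size: int = -1

    def forward(self) -> int:
        if self.dcp_world_size == -1:
            self.dcp_world_size = get_dcp_world_size()
        assert self.dcp_world_size != -1
        return self.dcp_world_size

print(Impl().forward())  # 1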

View File

@@ -33,7 +33,10 @@ from vllm.v1.attention.backends.mla.common import (
)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata
from vllm.vllm_flash_attn import ( # type: ignore[attr-defined]
flash_attn_varlen_func,
get_scheduler_metadata,
)
logger = init_logger(__name__)
@@ -181,7 +184,11 @@ class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]
# For Flash Attention MLA + full cudagraph
max_num_splits = 0
if self.use_full_cuda_graph and num_decode_tokens <= self.max_cudagraph_size:
if (
self.use_full_cuda_graph
and self.max_cudagraph_size is not None
and num_decode_tokens <= self.max_cudagraph_size
):
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
# usage, because the intermediate buffers of size [num_splits,
# num_heads, num_tokens, head_size] are allocated. Therefore,

View File

@@ -10,6 +10,7 @@ from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionLayer,
AttentionMetadata,
MultipleOf,
)
from vllm.attention.backends.utils import get_mla_dims
@@ -124,7 +125,7 @@ class FlashMLASparseBackend(AttentionBackend):
@dataclass
class FlashMLASparseMetadata:
class FlashMLASparseMetadata(AttentionMetadata):
num_reqs: int
max_query_len: int
max_seq_len: int
@@ -718,7 +719,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
)
self.softmax_scale = scale
assert indexer is not None
self.topk_indices_buffer = indexer.topk_indices_buffer
self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
self.padding = 128 if current_platform.is_device_capability_family(100) else 64
if kv_cache_dtype == "fp8_ds_mla":
@@ -980,6 +981,7 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
q = q[:num_actual_toks, ...]
k_c_normed = k_c_normed[:num_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...]
assert self.topk_indices_buffer is not None
topk_indices = self.topk_indices_buffer[:num_actual_toks]
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

View File

@@ -236,7 +236,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
k=k,
v=v,
softmax_scale=softmax_scale,
return_lse=return_softmax_lse,
return_softmax_lse=return_softmax_lse,
**kwargs,
)
@@ -251,6 +251,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
) -> tuple[torch.Tensor, torch.Tensor | None]:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
assert attn_metadata.decode.max_qo_len is not None
if type(q) is tuple:
q = torch.cat(q, dim=-1)

View File

@@ -43,7 +43,7 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
return "ROCM_AITER_MLA_SPARSE"
@staticmethod
def get_metadata_cls() -> type[AttentionMetadata]:
def get_metadata_cls() -> type["ROCMAiterMLASparseMetadata"]:
return ROCMAiterMLASparseMetadata
@staticmethod
@@ -74,7 +74,7 @@ class ROCMAiterMLASparseBackend(AttentionBackend):
@dataclass
class ROCMAiterMLASparseMetadata:
class ROCMAiterMLASparseMetadata(AttentionMetadata):
num_reqs: int
max_query_len: int
max_seq_len: int
@@ -223,7 +223,7 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
)
self.softmax_scale = scale
assert indexer is not None
self.topk_indices_buffer = indexer.topk_indices_buffer
self.topk_indices_buffer: torch.Tensor | None = indexer.topk_indices_buffer
self.is_fp8bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
def _forward_bf16_kv(
@@ -294,6 +294,7 @@ class ROCMAiterMLASparseImpl(MLACommonBaseImpl[ROCMAiterMLASparseMetadata]):
# Convert from (N, B, L) to (B, N, L)
ql_nope = ql_nope.transpose(0, 1)
assert self.topk_indices_buffer is not None
topk_indices = self.topk_indices_buffer[:num_actual_toks]
topk_indices_global = triton_convert_req_index_to_global_index(

View File

@@ -155,7 +155,9 @@ class TreeAttentionMetadataBuilder(AttentionMetadataBuilder[TreeAttentionMetadat
self.block_size = kv_cache_spec.block_size
spec_config = vllm_config.speculative_config
spec_token_tree = (spec := spec_config) and spec.speculative_token_tree
spec_token_tree: str | None = None
if spec := spec_config:
spec_token_tree = spec.speculative_token_tree
tree_choices: list[tuple[int, ...]] = (
ast.literal_eval(spec_token_tree) if spec_token_tree is not None else [(0,)]
)
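
Editor's note: the old one-liner `spec_token_tree = (spec := spec_config) and spec.speculative_token_tree` left mypy unable to infer a clean `str | None` for the variable; the fix declares the type explicitly and assigns inside a plain `if`. A self-contained sketch, with `SpecConfig` standing in for the speculative config:

import ast
from dataclasses import dataclass

@dataclass
class SpecConfig:
    speculative_token_tree: str | None = None

def resolve_tree_choices(spec_config: SpecConfig | None) -> list[tuple[int, ...]]:
    spec_token_tree: str | None = None
    if spec := spec_config:
        spec_token_tree = spec.speculative_token_tree
    return (
        ast.literal_eval(spec_token_tree) if spec_token_tree is not None else [(0,)]
    )

print(resolve_tree_choices(None))                          # [(0,)]
print(resolve_tree_choices(SpecConfig("[(0,), (0, 0)]")))  # [(0,), (0, 0)]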

View File

@@ -469,6 +469,7 @@ def get_kv_cache_layout():
# Format specified by the code.
global _KV_CACHE_LAYOUT_OVERRIDE
cache_layout: Literal["NHD", "HND"] | None = None
if _KV_CACHE_LAYOUT_OVERRIDE is not None:
cache_layout = _KV_CACHE_LAYOUT_OVERRIDE
logger.info_once(
@@ -524,7 +525,11 @@ def get_per_layer_parameters(
to use during `plan`.
"""
layers = get_layers_from_vllm_config(vllm_config, AttentionLayerBase, layer_names)
layers = get_layers_from_vllm_config(
vllm_config,
AttentionLayerBase, # type: ignore[type-abstract]
layer_names,
)
per_layer_params: dict[str, PerLayerParameters] = {}
for key, layer in layers.items():
@@ -1125,7 +1130,7 @@ class KVSharingFastPrefillMetadata(Protocol):
def create_fast_prefill_custom_backend(
prefix: str,
underlying_attn_backend: AttentionBackend,
underlying_attn_backend: type[AttentionBackend],
) -> type[AttentionBackend]:
underlying_builder = underlying_attn_backend.get_builder_cls()
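
Editor's note: the `# type: ignore[type-abstract]` comments (possibly the "TEMPORARY workaround" referred to in the title) silence mypy's rule that an abstract class may not be passed where `type[T]` is expected, even though here the class object is only used for isinstance-style filtering. A minimal reproduction with made-up names:

import abc

class AttentionLayerBase(abc.ABC):
    @abc.abstractmethod
    def forward(self) -> None: ...

def count_layers(objs: list[object], layer_type: type[AttentionLayerBase]) -> int:
    # The class object is only used for isinstance checks, which is fine at
    # runtime even for an abstract base class.
    return sum(isinstance(o, layer_type) for o in objs)

# mypy: Only concrete class can be given where "type[AttentionLayerBase]"
# is expected  [type-abstract]
print(count_layers([], AttentionLayerBase))  # type: ignore[type-abstract]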