[CI] Fix mypy for vllm/v1/ops (#39219)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-04-08 23:06:34 -04:00
committed by GitHub
parent 2a49284c8a
commit aec18492d0
4 changed files with 28 additions and 25 deletions

View File

@@ -36,8 +36,6 @@ SEPARATE_GROUPS = [
EXCLUDE = [
"vllm/model_executor/models",
"vllm/model_executor/layers/fla/ops",
# Ignore triton kernels in ops.
"vllm/v1/attention/ops",
# TODO: Remove these entries after fixing mypy errors.
"vllm/benchmarks",
]

View File

@@ -4,6 +4,8 @@
# The kernels in this file are adapted from LightLLM's context_attention_fwd:
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
from typing import Any
import torch
from vllm.platforms import current_platform
@@ -780,7 +782,7 @@ def context_attention_fwd(
return
max_seq_len = 0 if max_seq_len is None else max_seq_len
extra_kargs = {}
extra_kargs: dict[str, Any] = {}
if current_platform.is_rocm():
extra_kargs = {}

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import importlib
from importlib.util import find_spec
import torch
@@ -276,11 +277,9 @@ def fp8_paged_mqa_logits_torch(
@functools.lru_cache
def paged_mqa_logits_module():
paged_mqa_logits_module_path = None
if importlib.util.find_spec("aiter.ops.triton.pa_mqa_logits") is not None:
if find_spec("aiter.ops.triton.pa_mqa_logits") is not None:
paged_mqa_logits_module_path = "aiter.ops.triton.pa_mqa_logits"
elif (
importlib.util.find_spec("aiter.ops.triton.attention.pa_mqa_logits") is not None
):
elif find_spec("aiter.ops.triton.attention.pa_mqa_logits") is not None:
paged_mqa_logits_module_path = "aiter.ops.triton.attention.pa_mqa_logits"
if paged_mqa_logits_module_path is not None:
@@ -380,9 +379,9 @@ def fp8_mqa_logits_torch(
Returns:
Logits tensor of shape [M, N], dtype `torch.float32`.
"""
kv, scale = kv
seq_len_kv = kv.shape[0]
k = kv.to(torch.bfloat16)
k_fp8, scale = kv
seq_len_kv = k_fp8.shape[0]
k = k_fp8.to(torch.bfloat16)
q = q.to(torch.bfloat16)
mask_lo = (
@@ -403,12 +402,9 @@ def fp8_mqa_logits_torch(
@functools.lru_cache
def mqa_logits_module():
mqa_logits_module_path = None
if importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None:
if find_spec("aiter.ops.triton.fp8_mqa_logits") is not None:
mqa_logits_module_path = "aiter.ops.triton.fp8_mqa_logits"
elif (
importlib.util.find_spec("aiter.ops.triton.attention.fp8_mqa_logits")
is not None
):
elif find_spec("aiter.ops.triton.attention.fp8_mqa_logits") is not None:
mqa_logits_module_path = "aiter.ops.triton.attention.fp8_mqa_logits"
if mqa_logits_module_path is not None:
@@ -455,8 +451,8 @@ def rocm_fp8_mqa_logits(
if aiter_mqa_logits_module is not None:
fp8_mqa_logits = aiter_mqa_logits_module.fp8_mqa_logits
kv, scale = kv
return fp8_mqa_logits(q, kv, scale, weights, cu_seqlen_ks, cu_seqlen_ke)
k_fp8, scale = kv
return fp8_mqa_logits(q, k_fp8, scale, weights, cu_seqlen_ks, cu_seqlen_ke)
else:
return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
@@ -523,12 +519,14 @@ def rocm_aiter_sparse_attn_indexer(
total_seq_lens,
topk_indices_buffer,
)
attn_metadata = attn_metadata[k_cache_prefix]
assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
slot_mapping = attn_metadata.slot_mapping
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
layer_attn_metadata = attn_metadata[k_cache_prefix]
assert isinstance(layer_attn_metadata, DeepseekV32IndexerMetadata)
assert topk_indices_buffer is not None
assert scale_fmt is not None
slot_mapping = layer_attn_metadata.slot_mapping
has_decode = layer_attn_metadata.num_decodes > 0
has_prefill = layer_attn_metadata.num_prefills > 0
num_decode_tokens = layer_attn_metadata.num_decode_tokens
ops.indexer_k_quant_and_cache(
k,
@@ -540,7 +538,8 @@ def rocm_aiter_sparse_attn_indexer(
topk_indices_buffer[: hidden_states.shape[0]] = -1
if has_prefill:
prefill_metadata = attn_metadata.prefill
prefill_metadata = layer_attn_metadata.prefill
assert prefill_metadata is not None
for chunk in prefill_metadata.chunks:
k_fp8 = torch.empty(
[chunk.total_seq_lens, head_dim],
@@ -585,7 +584,8 @@ def rocm_aiter_sparse_attn_indexer(
)
if has_decode:
decode_metadata = attn_metadata.decode
decode_metadata = layer_attn_metadata.decode
assert decode_metadata is not None
# kv_cache size requirement [num_block, block_size, n_head, head_dim],
# we only have [num_block, block_size, head_dim],
kv_cache = kv_cache.unsqueeze(-2)

View File

@@ -292,6 +292,9 @@ def flashinfer_wrapper(
# RoPE has already made q and k contiguous.
q, k = q.contiguous(), k.contiguous()
assert cu_seqlens is not None
assert max_seqlen is not None
assert sequence_lengths is not None
assert len(cu_seqlens) % 2 == 0, "cu_seqlens must be divisible by 2"
cu_seqlength = len(cu_seqlens) // 2
batch_offsets_qko = cu_seqlens[:cu_seqlength].view(-1, 1, 1, 1)