[CI] Fix mypy for vllm/v1/ops (#39219)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-04-08 23:06:34 -04:00
committed by GitHub
parent 2a49284c8a
commit aec18492d0
4 changed files with 28 additions and 25 deletions

View File

@@ -36,8 +36,6 @@ SEPARATE_GROUPS = [
 EXCLUDE = [
     "vllm/model_executor/models",
     "vllm/model_executor/layers/fla/ops",
-    # Ignore triton kernels in ops.
-    "vllm/v1/attention/ops",
     # TODO: Remove these entries after fixing mypy errors.
     "vllm/benchmarks",
 ]

View File

@@ -4,6 +4,8 @@
 # The kernels in this file are adapted from LightLLM's context_attention_fwd:
 # https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
+from typing import Any
+
 import torch

 from vllm.platforms import current_platform
@@ -780,7 +782,7 @@ def context_attention_fwd(
         return
     max_seq_len = 0 if max_seq_len is None else max_seq_len
-    extra_kargs = {}
+    extra_kargs: dict[str, Any] = {}
     if current_platform.is_rocm():
         extra_kargs = {}

View File

@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import functools
 import importlib
+from importlib.util import find_spec

 import torch

@@ -276,11 +277,9 @@ def fp8_paged_mqa_logits_torch(
 @functools.lru_cache
 def paged_mqa_logits_module():
     paged_mqa_logits_module_path = None
-    if importlib.util.find_spec("aiter.ops.triton.pa_mqa_logits") is not None:
+    if find_spec("aiter.ops.triton.pa_mqa_logits") is not None:
         paged_mqa_logits_module_path = "aiter.ops.triton.pa_mqa_logits"
-    elif (
-        importlib.util.find_spec("aiter.ops.triton.attention.pa_mqa_logits") is not None
-    ):
+    elif find_spec("aiter.ops.triton.attention.pa_mqa_logits") is not None:
         paged_mqa_logits_module_path = "aiter.ops.triton.attention.pa_mqa_logits"
     if paged_mqa_logits_module_path is not None:
@@ -380,9 +379,9 @@ def fp8_mqa_logits_torch(
     Returns:
         Logits tensor of shape [M, N], dtype `torch.float32`.
     """
-    kv, scale = kv
-    seq_len_kv = kv.shape[0]
-    k = kv.to(torch.bfloat16)
+    k_fp8, scale = kv
+    seq_len_kv = k_fp8.shape[0]
+    k = k_fp8.to(torch.bfloat16)
     q = q.to(torch.bfloat16)

     mask_lo = (
mask_lo = ( mask_lo = (
@@ -403,12 +402,9 @@ def fp8_mqa_logits_torch(
 @functools.lru_cache
 def mqa_logits_module():
     mqa_logits_module_path = None
-    if importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None:
+    if find_spec("aiter.ops.triton.fp8_mqa_logits") is not None:
         mqa_logits_module_path = "aiter.ops.triton.fp8_mqa_logits"
-    elif (
-        importlib.util.find_spec("aiter.ops.triton.attention.fp8_mqa_logits")
-        is not None
-    ):
+    elif find_spec("aiter.ops.triton.attention.fp8_mqa_logits") is not None:
         mqa_logits_module_path = "aiter.ops.triton.attention.fp8_mqa_logits"
     if mqa_logits_module_path is not None:
@@ -455,8 +451,8 @@ def rocm_fp8_mqa_logits(
     if aiter_mqa_logits_module is not None:
         fp8_mqa_logits = aiter_mqa_logits_module.fp8_mqa_logits
-        kv, scale = kv
-        return fp8_mqa_logits(q, kv, scale, weights, cu_seqlen_ks, cu_seqlen_ke)
+        k_fp8, scale = kv
+        return fp8_mqa_logits(q, k_fp8, scale, weights, cu_seqlen_ks, cu_seqlen_ke)
     else:
         return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
@@ -523,12 +519,14 @@ def rocm_aiter_sparse_attn_indexer(
         total_seq_lens,
         topk_indices_buffer,
     )
-    attn_metadata = attn_metadata[k_cache_prefix]
-    assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
-    slot_mapping = attn_metadata.slot_mapping
-    has_decode = attn_metadata.num_decodes > 0
-    has_prefill = attn_metadata.num_prefills > 0
-    num_decode_tokens = attn_metadata.num_decode_tokens
+    layer_attn_metadata = attn_metadata[k_cache_prefix]
+    assert isinstance(layer_attn_metadata, DeepseekV32IndexerMetadata)
+    assert topk_indices_buffer is not None
+    assert scale_fmt is not None
+    slot_mapping = layer_attn_metadata.slot_mapping
+    has_decode = layer_attn_metadata.num_decodes > 0
+    has_prefill = layer_attn_metadata.num_prefills > 0
+    num_decode_tokens = layer_attn_metadata.num_decode_tokens

     ops.indexer_k_quant_and_cache(
         k,
@@ -540,7 +538,8 @@ def rocm_aiter_sparse_attn_indexer(
     topk_indices_buffer[: hidden_states.shape[0]] = -1
     if has_prefill:
-        prefill_metadata = attn_metadata.prefill
+        prefill_metadata = layer_attn_metadata.prefill
+        assert prefill_metadata is not None
         for chunk in prefill_metadata.chunks:
             k_fp8 = torch.empty(
                 [chunk.total_seq_lens, head_dim],
@@ -585,7 +584,8 @@ def rocm_aiter_sparse_attn_indexer(
     )
     if has_decode:
-        decode_metadata = attn_metadata.decode
+        decode_metadata = layer_attn_metadata.decode
+        assert decode_metadata is not None
         # kv_cache size requirement [num_block, block_size, n_head, head_dim],
         # we only have [num_block, block_size, head_dim],
         kv_cache = kv_cache.unsqueeze(-2)

View File

@@ -292,6 +292,9 @@ def flashinfer_wrapper(
     # RoPE has already made q and k contiguous.
     q, k = q.contiguous(), k.contiguous()
+    assert cu_seqlens is not None
+    assert max_seqlen is not None
+    assert sequence_lengths is not None
     assert len(cu_seqlens) % 2 == 0, "cu_seqlens must be divisible by 2"
     cu_seqlength = len(cu_seqlens) // 2
     batch_offsets_qko = cu_seqlens[:cu_seqlength].view(-1, 1, 1, 1)