[CI] Fix mypy for vllm/v1/ops (#39219)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -36,8 +36,6 @@ SEPARATE_GROUPS = [
|
|||||||
EXCLUDE = [
|
EXCLUDE = [
|
||||||
"vllm/model_executor/models",
|
"vllm/model_executor/models",
|
||||||
"vllm/model_executor/layers/fla/ops",
|
"vllm/model_executor/layers/fla/ops",
|
||||||
# Ignore triton kernels in ops.
|
|
||||||
"vllm/v1/attention/ops",
|
|
||||||
# TODO: Remove these entries after fixing mypy errors.
|
# TODO: Remove these entries after fixing mypy errors.
|
||||||
"vllm/benchmarks",
|
"vllm/benchmarks",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -4,6 +4,8 @@
|
|||||||
# The kernels in this file are adapted from LightLLM's context_attention_fwd:
|
# The kernels in this file are adapted from LightLLM's context_attention_fwd:
|
||||||
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
|
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
@@ -780,7 +782,7 @@ def context_attention_fwd(
|
|||||||
return
|
return
|
||||||
|
|
||||||
max_seq_len = 0 if max_seq_len is None else max_seq_len
|
max_seq_len = 0 if max_seq_len is None else max_seq_len
|
||||||
extra_kargs = {}
|
extra_kargs: dict[str, Any] = {}
|
||||||
if current_platform.is_rocm():
|
if current_platform.is_rocm():
|
||||||
extra_kargs = {}
|
extra_kargs = {}
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import functools
|
import functools
|
||||||
import importlib
|
import importlib
|
||||||
|
from importlib.util import find_spec
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -276,11 +277,9 @@ def fp8_paged_mqa_logits_torch(
|
|||||||
@functools.lru_cache
|
@functools.lru_cache
|
||||||
def paged_mqa_logits_module():
|
def paged_mqa_logits_module():
|
||||||
paged_mqa_logits_module_path = None
|
paged_mqa_logits_module_path = None
|
||||||
if importlib.util.find_spec("aiter.ops.triton.pa_mqa_logits") is not None:
|
if find_spec("aiter.ops.triton.pa_mqa_logits") is not None:
|
||||||
paged_mqa_logits_module_path = "aiter.ops.triton.pa_mqa_logits"
|
paged_mqa_logits_module_path = "aiter.ops.triton.pa_mqa_logits"
|
||||||
elif (
|
elif find_spec("aiter.ops.triton.attention.pa_mqa_logits") is not None:
|
||||||
importlib.util.find_spec("aiter.ops.triton.attention.pa_mqa_logits") is not None
|
|
||||||
):
|
|
||||||
paged_mqa_logits_module_path = "aiter.ops.triton.attention.pa_mqa_logits"
|
paged_mqa_logits_module_path = "aiter.ops.triton.attention.pa_mqa_logits"
|
||||||
|
|
||||||
if paged_mqa_logits_module_path is not None:
|
if paged_mqa_logits_module_path is not None:
|
||||||
@@ -380,9 +379,9 @@ def fp8_mqa_logits_torch(
|
|||||||
Returns:
|
Returns:
|
||||||
Logits tensor of shape [M, N], dtype `torch.float32`.
|
Logits tensor of shape [M, N], dtype `torch.float32`.
|
||||||
"""
|
"""
|
||||||
kv, scale = kv
|
k_fp8, scale = kv
|
||||||
seq_len_kv = kv.shape[0]
|
seq_len_kv = k_fp8.shape[0]
|
||||||
k = kv.to(torch.bfloat16)
|
k = k_fp8.to(torch.bfloat16)
|
||||||
q = q.to(torch.bfloat16)
|
q = q.to(torch.bfloat16)
|
||||||
|
|
||||||
mask_lo = (
|
mask_lo = (
|
||||||
@@ -403,12 +402,9 @@ def fp8_mqa_logits_torch(
|
|||||||
@functools.lru_cache
|
@functools.lru_cache
|
||||||
def mqa_logits_module():
|
def mqa_logits_module():
|
||||||
mqa_logits_module_path = None
|
mqa_logits_module_path = None
|
||||||
if importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None:
|
if find_spec("aiter.ops.triton.fp8_mqa_logits") is not None:
|
||||||
mqa_logits_module_path = "aiter.ops.triton.fp8_mqa_logits"
|
mqa_logits_module_path = "aiter.ops.triton.fp8_mqa_logits"
|
||||||
elif (
|
elif find_spec("aiter.ops.triton.attention.fp8_mqa_logits") is not None:
|
||||||
importlib.util.find_spec("aiter.ops.triton.attention.fp8_mqa_logits")
|
|
||||||
is not None
|
|
||||||
):
|
|
||||||
mqa_logits_module_path = "aiter.ops.triton.attention.fp8_mqa_logits"
|
mqa_logits_module_path = "aiter.ops.triton.attention.fp8_mqa_logits"
|
||||||
|
|
||||||
if mqa_logits_module_path is not None:
|
if mqa_logits_module_path is not None:
|
||||||
@@ -455,8 +451,8 @@ def rocm_fp8_mqa_logits(
|
|||||||
|
|
||||||
if aiter_mqa_logits_module is not None:
|
if aiter_mqa_logits_module is not None:
|
||||||
fp8_mqa_logits = aiter_mqa_logits_module.fp8_mqa_logits
|
fp8_mqa_logits = aiter_mqa_logits_module.fp8_mqa_logits
|
||||||
kv, scale = kv
|
k_fp8, scale = kv
|
||||||
return fp8_mqa_logits(q, kv, scale, weights, cu_seqlen_ks, cu_seqlen_ke)
|
return fp8_mqa_logits(q, k_fp8, scale, weights, cu_seqlen_ks, cu_seqlen_ke)
|
||||||
else:
|
else:
|
||||||
return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
|
return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
|
||||||
|
|
||||||
@@ -523,12 +519,14 @@ def rocm_aiter_sparse_attn_indexer(
|
|||||||
total_seq_lens,
|
total_seq_lens,
|
||||||
topk_indices_buffer,
|
topk_indices_buffer,
|
||||||
)
|
)
|
||||||
attn_metadata = attn_metadata[k_cache_prefix]
|
layer_attn_metadata = attn_metadata[k_cache_prefix]
|
||||||
assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
|
assert isinstance(layer_attn_metadata, DeepseekV32IndexerMetadata)
|
||||||
slot_mapping = attn_metadata.slot_mapping
|
assert topk_indices_buffer is not None
|
||||||
has_decode = attn_metadata.num_decodes > 0
|
assert scale_fmt is not None
|
||||||
has_prefill = attn_metadata.num_prefills > 0
|
slot_mapping = layer_attn_metadata.slot_mapping
|
||||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
has_decode = layer_attn_metadata.num_decodes > 0
|
||||||
|
has_prefill = layer_attn_metadata.num_prefills > 0
|
||||||
|
num_decode_tokens = layer_attn_metadata.num_decode_tokens
|
||||||
|
|
||||||
ops.indexer_k_quant_and_cache(
|
ops.indexer_k_quant_and_cache(
|
||||||
k,
|
k,
|
||||||
@@ -540,7 +538,8 @@ def rocm_aiter_sparse_attn_indexer(
|
|||||||
|
|
||||||
topk_indices_buffer[: hidden_states.shape[0]] = -1
|
topk_indices_buffer[: hidden_states.shape[0]] = -1
|
||||||
if has_prefill:
|
if has_prefill:
|
||||||
prefill_metadata = attn_metadata.prefill
|
prefill_metadata = layer_attn_metadata.prefill
|
||||||
|
assert prefill_metadata is not None
|
||||||
for chunk in prefill_metadata.chunks:
|
for chunk in prefill_metadata.chunks:
|
||||||
k_fp8 = torch.empty(
|
k_fp8 = torch.empty(
|
||||||
[chunk.total_seq_lens, head_dim],
|
[chunk.total_seq_lens, head_dim],
|
||||||
@@ -585,7 +584,8 @@ def rocm_aiter_sparse_attn_indexer(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if has_decode:
|
if has_decode:
|
||||||
decode_metadata = attn_metadata.decode
|
decode_metadata = layer_attn_metadata.decode
|
||||||
|
assert decode_metadata is not None
|
||||||
# kv_cache size requirement [num_block, block_size, n_head, head_dim],
|
# kv_cache size requirement [num_block, block_size, n_head, head_dim],
|
||||||
# we only have [num_block, block_size, head_dim],
|
# we only have [num_block, block_size, head_dim],
|
||||||
kv_cache = kv_cache.unsqueeze(-2)
|
kv_cache = kv_cache.unsqueeze(-2)
|
||||||
|
|||||||
@@ -292,6 +292,9 @@ def flashinfer_wrapper(
|
|||||||
# RoPE has already made q and k contiguous.
|
# RoPE has already made q and k contiguous.
|
||||||
q, k = q.contiguous(), k.contiguous()
|
q, k = q.contiguous(), k.contiguous()
|
||||||
|
|
||||||
|
assert cu_seqlens is not None
|
||||||
|
assert max_seqlen is not None
|
||||||
|
assert sequence_lengths is not None
|
||||||
assert len(cu_seqlens) % 2 == 0, "cu_seqlens must be divisible by 2"
|
assert len(cu_seqlens) % 2 == 0, "cu_seqlens must be divisible by 2"
|
||||||
cu_seqlength = len(cu_seqlens) // 2
|
cu_seqlength = len(cu_seqlens) // 2
|
||||||
batch_offsets_qko = cu_seqlens[:cu_seqlength].view(-1, 1, 1, 1)
|
batch_offsets_qko = cu_seqlens[:cu_seqlength].view(-1, 1, 1, 1)
|
||||||
|
|||||||
Reference in New Issue
Block a user