diff --git a/tools/pre_commit/mypy.py b/tools/pre_commit/mypy.py
index 1ba1f8156..41c05efd2 100755
--- a/tools/pre_commit/mypy.py
+++ b/tools/pre_commit/mypy.py
@@ -36,8 +36,6 @@ SEPARATE_GROUPS = [
 EXCLUDE = [
     "vllm/model_executor/models",
     "vllm/model_executor/layers/fla/ops",
-    # Ignore triton kernels in ops.
-    "vllm/v1/attention/ops",
     # TODO: Remove these entries after fixing mypy errors.
     "vllm/benchmarks",
 ]
diff --git a/vllm/v1/attention/ops/prefix_prefill.py b/vllm/v1/attention/ops/prefix_prefill.py
index 13c82f586..afa5f5178 100644
--- a/vllm/v1/attention/ops/prefix_prefill.py
+++ b/vllm/v1/attention/ops/prefix_prefill.py
@@ -4,6 +4,8 @@
 # The kernels in this file are adapted from LightLLM's context_attention_fwd:
 # https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
 
+from typing import Any
+
 import torch
 
 from vllm.platforms import current_platform
@@ -780,7 +782,7 @@ def context_attention_fwd(
         return
 
     max_seq_len = 0 if max_seq_len is None else max_seq_len
-    extra_kargs = {}
+    extra_kargs: dict[str, Any] = {}
     if current_platform.is_rocm():
         extra_kargs = {}
 
diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
index 9d1da5b53..886b9410c 100644
--- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
+++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import functools
 import importlib
+from importlib.util import find_spec
 
 import torch
 
@@ -276,11 +277,9 @@ def fp8_paged_mqa_logits_torch(
 @functools.lru_cache
 def paged_mqa_logits_module():
     paged_mqa_logits_module_path = None
-    if importlib.util.find_spec("aiter.ops.triton.pa_mqa_logits") is not None:
+    if find_spec("aiter.ops.triton.pa_mqa_logits") is not None:
         paged_mqa_logits_module_path = "aiter.ops.triton.pa_mqa_logits"
-    elif (
-        importlib.util.find_spec("aiter.ops.triton.attention.pa_mqa_logits") is not None
-    ):
+    elif find_spec("aiter.ops.triton.attention.pa_mqa_logits") is not None:
         paged_mqa_logits_module_path = "aiter.ops.triton.attention.pa_mqa_logits"
 
     if paged_mqa_logits_module_path is not None:
@@ -380,9 +379,9 @@ def fp8_mqa_logits_torch(
     Returns:
         Logits tensor of shape [M, N], dtype `torch.float32`.
     """
-    kv, scale = kv
-    seq_len_kv = kv.shape[0]
-    k = kv.to(torch.bfloat16)
+    k_fp8, scale = kv
+    seq_len_kv = k_fp8.shape[0]
+    k = k_fp8.to(torch.bfloat16)
     q = q.to(torch.bfloat16)
 
     mask_lo = (
@@ -403,12 +402,9 @@ def fp8_mqa_logits_torch(
 @functools.lru_cache
 def mqa_logits_module():
     mqa_logits_module_path = None
-    if importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None:
+    if find_spec("aiter.ops.triton.fp8_mqa_logits") is not None:
         mqa_logits_module_path = "aiter.ops.triton.fp8_mqa_logits"
-    elif (
-        importlib.util.find_spec("aiter.ops.triton.attention.fp8_mqa_logits")
-        is not None
-    ):
+    elif find_spec("aiter.ops.triton.attention.fp8_mqa_logits") is not None:
         mqa_logits_module_path = "aiter.ops.triton.attention.fp8_mqa_logits"
 
     if mqa_logits_module_path is not None:
@@ -455,8 +451,8 @@ def rocm_fp8_mqa_logits(
 
     if aiter_mqa_logits_module is not None:
         fp8_mqa_logits = aiter_mqa_logits_module.fp8_mqa_logits
-        kv, scale = kv
-        return fp8_mqa_logits(q, kv, scale, weights, cu_seqlen_ks, cu_seqlen_ke)
+        k_fp8, scale = kv
+        return fp8_mqa_logits(q, k_fp8, scale, weights, cu_seqlen_ks, cu_seqlen_ke)
     else:
         return fp8_mqa_logits_torch(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke)
@@ -523,12 +519,14 @@ def rocm_aiter_sparse_attn_indexer(
             total_seq_lens,
             topk_indices_buffer,
         )
-    attn_metadata = attn_metadata[k_cache_prefix]
-    assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
-    slot_mapping = attn_metadata.slot_mapping
-    has_decode = attn_metadata.num_decodes > 0
-    has_prefill = attn_metadata.num_prefills > 0
-    num_decode_tokens = attn_metadata.num_decode_tokens
+    layer_attn_metadata = attn_metadata[k_cache_prefix]
+    assert isinstance(layer_attn_metadata, DeepseekV32IndexerMetadata)
+    assert topk_indices_buffer is not None
+    assert scale_fmt is not None
+    slot_mapping = layer_attn_metadata.slot_mapping
+    has_decode = layer_attn_metadata.num_decodes > 0
+    has_prefill = layer_attn_metadata.num_prefills > 0
+    num_decode_tokens = layer_attn_metadata.num_decode_tokens
 
     ops.indexer_k_quant_and_cache(
         k,
@@ -540,7 +538,8 @@ def rocm_aiter_sparse_attn_indexer(
 
     topk_indices_buffer[: hidden_states.shape[0]] = -1
     if has_prefill:
-        prefill_metadata = attn_metadata.prefill
+        prefill_metadata = layer_attn_metadata.prefill
+        assert prefill_metadata is not None
         for chunk in prefill_metadata.chunks:
             k_fp8 = torch.empty(
                 [chunk.total_seq_lens, head_dim],
@@ -585,7 +584,8 @@ def rocm_aiter_sparse_attn_indexer(
         )
 
     if has_decode:
-        decode_metadata = attn_metadata.decode
+        decode_metadata = layer_attn_metadata.decode
+        assert decode_metadata is not None
         # kv_cache size requirement [num_block, block_size, n_head, head_dim],
         # we only have [num_block, block_size, head_dim],
         kv_cache = kv_cache.unsqueeze(-2)
diff --git a/vllm/v1/attention/ops/vit_attn_wrappers.py b/vllm/v1/attention/ops/vit_attn_wrappers.py
index 6ffe110ad..2d3505f51 100644
--- a/vllm/v1/attention/ops/vit_attn_wrappers.py
+++ b/vllm/v1/attention/ops/vit_attn_wrappers.py
@@ -292,6 +292,9 @@ def flashinfer_wrapper(
     # RoPE has already made q and k contiguous.
     q, k = q.contiguous(), k.contiguous()
 
+    assert cu_seqlens is not None
+    assert max_seqlen is not None
+    assert sequence_lengths is not None
     assert len(cu_seqlens) % 2 == 0, "cu_seqlens must be divisible by 2"
     cu_seqlength = len(cu_seqlens) // 2
     batch_offsets_qko = cu_seqlens[:cu_seqlength].view(-1, 1, 1, 1)