[CI] Fix mypy for vllm/v1/structured_output (#32722)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -17,15 +19,15 @@ elif current_platform.is_xpu():
|
||||
from vllm._ipex_ops import ipex_ops
|
||||
|
||||
reshape_and_cache_flash = ipex_ops.reshape_and_cache_flash
|
||||
flash_attn_varlen_func = ipex_ops.flash_attn_varlen_func
|
||||
get_scheduler_metadata = ipex_ops.get_scheduler_metadata
|
||||
flash_attn_varlen_func = ipex_ops.flash_attn_varlen_func # type: ignore[assignment]
|
||||
get_scheduler_metadata = ipex_ops.get_scheduler_metadata # type: ignore[assignment]
|
||||
|
||||
elif current_platform.is_rocm():
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_func # noqa: F401
|
||||
from flash_attn import flash_attn_varlen_func # type: ignore[no-redef]
|
||||
except ImportError:
|
||||
|
||||
def flash_attn_varlen_func(*args, **kwargs):
|
||||
def flash_attn_varlen_func(*args: Any, **kwargs: Any) -> Any: # type: ignore[no-redef,misc]
|
||||
raise ImportError(
|
||||
"ROCm platform requires upstream flash-attn "
|
||||
"to be installed. Please install flash-attn first."
|
||||
|
||||
@@ -49,7 +49,7 @@ class AiterTritonMLAImpl(AiterMLAImpl):
|
||||
def _flash_attn_varlen_diff_headdims(
|
||||
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
|
||||
):
|
||||
result = self.flash_attn_varlen_func(
|
||||
result = self.flash_attn_varlen_func( # type: ignore[call-arg]
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
|
||||
@@ -230,7 +230,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
def _flash_attn_varlen_diff_headdims(
|
||||
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
|
||||
):
|
||||
output = self.flash_attn_varlen_func(
|
||||
output = self.flash_attn_varlen_func( # type: ignore[call-arg]
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
|
||||
Reference in New Issue
Block a user