[CI] Fix mypy for vllm/v1/structured_output (#32722)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-01-22 22:55:51 -05:00
committed by GitHub
parent 5e4e0e51f4
commit 7ef5873752
18 changed files with 32 additions and 25 deletions

View File

@@ -51,7 +51,7 @@ class ReasoningParser:
return self.model_tokenizer.get_vocab()
@abstractmethod
def is_reasoning_end(self, input_ids: list[int]) -> bool:
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
"""
Check if the reasoning content ends in the input_ids.
@@ -68,7 +68,7 @@ class ReasoningParser:
"""
def is_reasoning_end_streaming(
self, input_ids: list[int], delta_ids: list[int]
self, input_ids: Sequence[int], delta_ids: Sequence[int]
) -> bool:
"""
Check if the reasoning content ends in the input_ids on a

View File

@@ -65,7 +65,7 @@ class BaseThinkingReasoningParser(ReasoningParser):
"think start/end tokens in the tokenizer!"
)
def is_reasoning_end(self, input_ids: list[int]) -> bool:
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
start_token_id = self.start_token_id
end_token_id = self.end_token_id
@@ -77,7 +77,7 @@ class BaseThinkingReasoningParser(ReasoningParser):
return False
def is_reasoning_end_streaming(
self, input_ids: list[int], delta_ids: list[int]
self, input_ids: Sequence[int], delta_ids: Sequence[int]
) -> bool:
end_token_id = self.end_token_id
return end_token_id in delta_ids

View File

@@ -41,7 +41,7 @@ class DeepSeekV3ReasoningParser(ReasoningParser):
return self._parser.is_reasoning_end(input_ids)
def is_reasoning_end_streaming(
self, input_ids: list[int], delta_ids: list[int]
self, input_ids: Sequence[int], delta_ids: Sequence[int]
) -> bool:
return self._parser.is_reasoning_end_streaming(input_ids, delta_ids)

View File

@@ -78,7 +78,7 @@ class GptOssReasoningParser(ReasoningParser):
self.reasoning_end_token_ids_suffix = self.model_tokenizer.encode("<|message|>")
self.reasoning_max_num_between_tokens = 20
def is_reasoning_end(self, input_ids: list[int]) -> bool:
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
end_token_ids_prefix = self.reasoning_end_token_ids_prefix
end_token_ids_suffix = self.reasoning_end_token_ids_suffix
assert len(end_token_ids_prefix) > 0, "reasoning_end_token_ids_prefix is empty"

View File

@@ -61,7 +61,7 @@ class Holo2ReasoningParser(ReasoningParser):
return self._parser.is_reasoning_end(input_ids)
def is_reasoning_end_streaming(
self, input_ids: list[int], delta_ids: list[int]
self, input_ids: Sequence[int], delta_ids: Sequence[int]
) -> bool:
return self._parser.is_reasoning_end_streaming(input_ids, delta_ids)

View File

@@ -79,7 +79,7 @@ class HunyuanA13BReasoningParser(ReasoningParser):
self.token_buffer = []
self.text_buffer = ""
def is_reasoning_end(self, input_ids: list[int]) -> bool:
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
return self.current_state == "response"
def extract_content_ids(self, input_ids: list[int]) -> list[int]:

View File

@@ -31,12 +31,12 @@ class IdentityReasoningParser(ReasoningParser):
"constructor during construction."
)
def is_reasoning_end(self, input_ids: list[int]) -> bool:
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
# Always return True, since we never treat reasoning specially
return True
def is_reasoning_end_streaming(
self, input_ids: list[int], delta_ids: list[int]
self, input_ids: Sequence[int], delta_ids: Sequence[int]
) -> bool:
return True

View File

@@ -88,7 +88,7 @@ class MiniMaxM2AppendThinkReasoningParser(ReasoningParser):
super().__init__(tokenizer, *args, **kwargs)
self.end_token_id = self.vocab.get("</think>")
def is_reasoning_end(self, input_ids: list[int]) -> bool:
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
end_token_id = self.end_token_id
return any(input_id == end_token_id for input_id in reversed(input_ids))

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from functools import cached_property
from vllm.entrypoints.openai.chat_completion.protocol import (
@@ -65,7 +66,7 @@ class MistralReasoningParser(BaseThinkingReasoningParser):
return SpecialTokens.end_think
def is_reasoning_end(self, input_ids: list[int]) -> bool:
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
has_eot_token = False
for id in input_ids[::-1]:

View File

@@ -242,7 +242,7 @@ class Olmo3ReasoningParser(ReasoningParser):
think_start=self.think_start, think_end=self.think_end
)
def is_reasoning_end(self, input_ids: list[int]) -> bool:
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
text = self.model_tokenizer.decode(input_ids)
return self.think_end in text

View File

@@ -100,11 +100,11 @@ class Step3ReasoningParser(ReasoningParser):
return reasoning, content
def is_reasoning_end(self, input_ids: list[int]) -> bool:
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
return self.think_end_token_id in input_ids
def is_reasoning_end_streaming(
self, input_ids: list[int], delta_ids: list[int]
self, input_ids: Sequence[int], delta_ids: Sequence[int]
) -> bool:
end_token_id = self.think_end_token_id
return end_token_id in delta_ids

View File

@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from vllm.logger import init_logger
from vllm.platforms import current_platform
@@ -17,15 +19,15 @@ elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops
reshape_and_cache_flash = ipex_ops.reshape_and_cache_flash
flash_attn_varlen_func = ipex_ops.flash_attn_varlen_func
get_scheduler_metadata = ipex_ops.get_scheduler_metadata
flash_attn_varlen_func = ipex_ops.flash_attn_varlen_func # type: ignore[assignment]
get_scheduler_metadata = ipex_ops.get_scheduler_metadata # type: ignore[assignment]
elif current_platform.is_rocm():
try:
from flash_attn import flash_attn_varlen_func # noqa: F401
from flash_attn import flash_attn_varlen_func # type: ignore[no-redef]
except ImportError:
def flash_attn_varlen_func(*args, **kwargs):
def flash_attn_varlen_func(*args: Any, **kwargs: Any) -> Any: # type: ignore[no-redef,misc]
raise ImportError(
"ROCm platform requires upstream flash-attn "
"to be installed. Please install flash-attn first."

View File

@@ -49,7 +49,7 @@ class AiterTritonMLAImpl(AiterMLAImpl):
def _flash_attn_varlen_diff_headdims(
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
):
result = self.flash_attn_varlen_func(
result = self.flash_attn_varlen_func( # type: ignore[call-arg]
q,
k,
v,

View File

@@ -230,7 +230,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
def _flash_attn_varlen_diff_headdims(
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
):
output = self.flash_attn_varlen_func(
output = self.flash_attn_varlen_func( # type: ignore[call-arg]
q=q,
k=k,
v=v,

View File

@@ -294,7 +294,7 @@ class StructuredOutputManager:
assert request.structured_output_request is not None
if request.structured_output_request.reasoning_ended is None:
request.structured_output_request.reasoning_ended = (
self.reasoner.is_reasoning_end(request.prompt_token_ids)
self.reasoner.is_reasoning_end(request.prompt_token_ids or [])
)
return request.structured_output_request.reasoning_ended
return True
@@ -323,8 +323,9 @@ class StructuredOutputManager:
# Check if reasoning ends in *this* step
delta_from = request.num_computed_tokens - request.num_output_placeholders
all_token_ids = request.all_token_ids
if self.reasoner.is_reasoning_end_streaming(
request.all_token_ids, request.all_token_ids[delta_from:]
all_token_ids, all_token_ids[delta_from:]
):
# Reasoning just ended, so we shouldn't advance til
# next pass

View File

@@ -284,6 +284,9 @@ def serialize_guidance_grammar(
def validate_guidance_grammar(
sampling_params: SamplingParams, tokenizer: llguidance.LLTokenizer | None = None
) -> None:
# if structured output is not enabled, there is nothing to validate
if sampling_params.structured_outputs is None:
return
tp, grm = get_structured_output_key(sampling_params.structured_outputs)
guidance_grm = serialize_guidance_grammar(tp, grm)
err = llguidance.LLMatcher.validate_grammar(guidance_grm, tokenizer)

View File

@@ -69,7 +69,7 @@ class XgrammarBackend(StructuredOutputBackend):
if idx < vocab_size:
encoded_vocab[idx] = token
stop_token_ids = [self.tokenizer.eos_token_id]
backend_str = self.tokenizer.tokenizer.backend_tokenizer.to_str()
backend_str = self.tokenizer.tokenizer.backend_tokenizer.to_str() # type: ignore[attr-defined]
metadata = xgr.TokenizerInfo._detect_metadata_from_hf(backend_str)
tokenizer_info = xgr.TokenizerInfo(
encoded_vocab=encoded_vocab,