[CI] Fix mypy for vllm/v1/structured_output (#32722)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
@@ -51,7 +51,7 @@ class ReasoningParser:
         return self.model_tokenizer.get_vocab()
 
     @abstractmethod
-    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
         """
         Check if the reasoning content ends in the input_ids.
@@ -68,7 +68,7 @@ class ReasoningParser:
         """
 
     def is_reasoning_end_streaming(
-        self, input_ids: list[int], delta_ids: list[int]
+        self, input_ids: Sequence[int], delta_ids: Sequence[int]
     ) -> bool:
         """
         Check if the reasoning content ends in the input_ids on a
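These first two hunks widen the abstract `ReasoningParser` signatures from `list[int]` to `collections.abc.Sequence[int]`, because callers hand these methods read-only token views rather than plain lists. A minimal sketch of the mypy failure the old annotation caused (a `tuple` stands in here for vLLM's internal read-only token container):

```python
from collections.abc import Sequence


def ends_list(input_ids: list[int]) -> bool:
    return 2 in input_ids


def ends_seq(input_ids: Sequence[int]) -> bool:
    return 2 in input_ids


tokens: tuple[int, ...] = (1, 2, 3)  # any read-only sequence of token ids
ends_seq(tokens)   # OK: tuple[int, ...] is a Sequence[int]
ends_list(tokens)  # type: ignore[arg-type]  # rejected: expected "list[int]"
```

`Sequence[int]` is the right bound here because the parsers only iterate, index, and slice; none of them mutate the ids.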
@@ -65,7 +65,7 @@ class BaseThinkingReasoningParser(ReasoningParser):
                 "think start/end tokens in the tokenizer!"
             )
 
-    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
         start_token_id = self.start_token_id
         end_token_id = self.end_token_id
@@ -77,7 +77,7 @@ class BaseThinkingReasoningParser(ReasoningParser):
             return False
 
     def is_reasoning_end_streaming(
-        self, input_ids: list[int], delta_ids: list[int]
+        self, input_ids: Sequence[int], delta_ids: Sequence[int]
     ) -> bool:
         end_token_id = self.end_token_id
         return end_token_id in delta_ids
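Once the base class is widened, the subclass overrides in the remaining hunks must follow: mypy's `[override]` check forbids an override from narrowing a parameter type. A hedged sketch of the error that would otherwise appear:

```python
from collections.abc import Sequence


class Base:
    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        return False


class Narrower(Base):
    # mypy: Argument 1 of "is_reasoning_end" is incompatible with supertype
    # "Base"; supertype defines the argument type as "Sequence[int]"  [override]
    def is_reasoning_end(self, input_ids: list[int]) -> bool:  # type: ignore[override]
        return len(input_ids) > 0
```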
@@ -41,7 +41,7 @@ class DeepSeekV3ReasoningParser(ReasoningParser):
         return self._parser.is_reasoning_end(input_ids)
 
     def is_reasoning_end_streaming(
-        self, input_ids: list[int], delta_ids: list[int]
+        self, input_ids: Sequence[int], delta_ids: Sequence[int]
     ) -> bool:
         return self._parser.is_reasoning_end_streaming(input_ids, delta_ids)
@@ -78,7 +78,7 @@ class GptOssReasoningParser(ReasoningParser):
         self.reasoning_end_token_ids_suffix = self.model_tokenizer.encode("<|message|>")
         self.reasoning_max_num_between_tokens = 20
 
-    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
         end_token_ids_prefix = self.reasoning_end_token_ids_prefix
         end_token_ids_suffix = self.reasoning_end_token_ids_suffix
         assert len(end_token_ids_prefix) > 0, "reasoning_end_token_ids_prefix is empty"
@@ -61,7 +61,7 @@ class Holo2ReasoningParser(ReasoningParser):
         return self._parser.is_reasoning_end(input_ids)
 
     def is_reasoning_end_streaming(
-        self, input_ids: list[int], delta_ids: list[int]
+        self, input_ids: Sequence[int], delta_ids: Sequence[int]
     ) -> bool:
         return self._parser.is_reasoning_end_streaming(input_ids, delta_ids)
@@ -79,7 +79,7 @@ class HunyuanA13BReasoningParser(ReasoningParser):
         self.token_buffer = []
         self.text_buffer = ""
 
-    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
         return self.current_state == "response"
 
     def extract_content_ids(self, input_ids: list[int]) -> list[int]:
@@ -31,12 +31,12 @@ class IdentityReasoningParser(ReasoningParser):
                 "constructor during construction."
             )
 
-    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
         # Always return True, since we never treat reasoning specially
         return True
 
     def is_reasoning_end_streaming(
-        self, input_ids: list[int], delta_ids: list[int]
+        self, input_ids: Sequence[int], delta_ids: Sequence[int]
     ) -> bool:
         return True
@@ -88,7 +88,7 @@ class MiniMaxM2AppendThinkReasoningParser(ReasoningParser):
         super().__init__(tokenizer, *args, **kwargs)
         self.end_token_id = self.vocab.get("</think>")
 
-    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
         end_token_id = self.end_token_id
         return any(input_id == end_token_id for input_id in reversed(input_ids))
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from collections.abc import Sequence
 from functools import cached_property
 
 from vllm.entrypoints.openai.chat_completion.protocol import (
@@ -65,7 +66,7 @@ class MistralReasoningParser(BaseThinkingReasoningParser):
 
         return SpecialTokens.end_think
 
-    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
         has_eot_token = False
 
         for id in input_ids[::-1]:
@@ -242,7 +242,7 @@ class Olmo3ReasoningParser(ReasoningParser):
             think_start=self.think_start, think_end=self.think_end
         )
 
-    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
         text = self.model_tokenizer.decode(input_ids)
         return self.think_end in text
@@ -100,11 +100,11 @@ class Step3ReasoningParser(ReasoningParser):
 
         return reasoning, content
 
-    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
         return self.think_end_token_id in input_ids
 
     def is_reasoning_end_streaming(
-        self, input_ids: list[int], delta_ids: list[int]
+        self, input_ids: Sequence[int], delta_ids: Sequence[int]
     ) -> bool:
         end_token_id = self.think_end_token_id
         return end_token_id in delta_ids
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from typing import Any
+
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -17,15 +19,15 @@ elif current_platform.is_xpu():
     from vllm._ipex_ops import ipex_ops
 
     reshape_and_cache_flash = ipex_ops.reshape_and_cache_flash
-    flash_attn_varlen_func = ipex_ops.flash_attn_varlen_func
-    get_scheduler_metadata = ipex_ops.get_scheduler_metadata
+    flash_attn_varlen_func = ipex_ops.flash_attn_varlen_func  # type: ignore[assignment]
+    get_scheduler_metadata = ipex_ops.get_scheduler_metadata  # type: ignore[assignment]
 
 elif current_platform.is_rocm():
     try:
-        from flash_attn import flash_attn_varlen_func  # noqa: F401
+        from flash_attn import flash_attn_varlen_func  # type: ignore[no-redef]
     except ImportError:
 
-        def flash_attn_varlen_func(*args, **kwargs):
+        def flash_attn_varlen_func(*args: Any, **kwargs: Any) -> Any:  # type: ignore[no-redef,misc]
             raise ImportError(
                 "ROCm platform requires upstream flash-attn "
                 "to be installed. Please install flash-attn first."
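The `[assignment]` and `[no-redef]` ignores above are needed because this module binds the same names to different implementations per platform; each binding is valid at runtime, but mypy sees conflicting redefinitions. The pattern reduced to its core, with the same shape as the ROCm branch (the stub takes over only when the import fails):

```python
from typing import Any

try:
    from flash_attn import flash_attn_varlen_func  # real kernel, if installed
except ImportError:

    def flash_attn_varlen_func(*args: Any, **kwargs: Any) -> Any:  # type: ignore[no-redef]
        # Second binding of the same name -> mypy [no-redef]; at runtime
        # only one of the two definitions ever exists.
        raise ImportError("flash-attn is not installed")
```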
@@ -49,7 +49,7 @@ class AiterTritonMLAImpl(AiterMLAImpl):
     def _flash_attn_varlen_diff_headdims(
         self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
     ):
-        result = self.flash_attn_varlen_func(
+        result = self.flash_attn_varlen_func(  # type: ignore[call-arg]
             q,
             k,
             v,
@@ -230,7 +230,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
     def _flash_attn_varlen_diff_headdims(
         self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
     ):
-        output = self.flash_attn_varlen_func(
+        output = self.flash_attn_varlen_func(  # type: ignore[call-arg]
             q=q,
             k=k,
             v=v,
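The `[call-arg]` ignores appear to be needed because `flash_attn_varlen_func` can be bound to any of several backend implementations, and mypy checks calls against the one signature it can see statically, which may lack keyword arguments the runtime binding accepts. A toy reproduction (names are illustrative, not vLLM's actual bindings):

```python
from typing import Any


def fallback(q: Any, k: Any, v: Any) -> Any:
    # Stand-in for the import-error stub; real kernels take more kwargs.
    raise NotImplementedError


flash_attn_varlen_func = fallback


def demo() -> Any:
    # mypy: Unexpected keyword argument "return_softmax_lse"  [call-arg],
    # even though the runtime binding may well accept it.
    return flash_attn_varlen_func(1, 2, 3, return_softmax_lse=True)  # type: ignore[call-arg]
```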
@@ -294,7 +294,7 @@ class StructuredOutputManager:
             assert request.structured_output_request is not None
             if request.structured_output_request.reasoning_ended is None:
                 request.structured_output_request.reasoning_ended = (
-                    self.reasoner.is_reasoning_end(request.prompt_token_ids)
+                    self.reasoner.is_reasoning_end(request.prompt_token_ids or [])
                 )
             return request.structured_output_request.reasoning_ended
         return True
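`request.prompt_token_ids` is Optional, so passing it straight into the new `Sequence[int]` parameter fails type checking; `or []` substitutes an empty list when it is `None` (and is a no-op for any non-empty prompt). Reduced sketch:

```python
from collections.abc import Sequence


def is_reasoning_end(input_ids: Sequence[int]) -> bool:
    return 2 in input_ids


prompt_token_ids: list[int] | None = None

# Without the guard: Argument 1 has incompatible type "list[int] | None";
# expected "Sequence[int]"  [arg-type]
ended = is_reasoning_end(prompt_token_ids or [])
```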
@@ -323,8 +323,9 @@ class StructuredOutputManager:
 
         # Check if reasoning ends in *this* step
         delta_from = request.num_computed_tokens - request.num_output_placeholders
+        all_token_ids = request.all_token_ids
         if self.reasoner.is_reasoning_end_streaming(
-            request.all_token_ids, request.all_token_ids[delta_from:]
+            all_token_ids, all_token_ids[delta_from:]
         ):
             # Reasoning just ended, so we shouldn't advance til
             # next pass
@@ -284,6 +284,9 @@ def serialize_guidance_grammar(
 def validate_guidance_grammar(
     sampling_params: SamplingParams, tokenizer: llguidance.LLTokenizer | None = None
 ) -> None:
+    # if structured output is not enabled, there is nothing to validate
+    if sampling_params.structured_outputs is None:
+        return
     tp, grm = get_structured_output_key(sampling_params.structured_outputs)
     guidance_grm = serialize_guidance_grammar(tp, grm)
     err = llguidance.LLMatcher.validate_grammar(guidance_grm, tokenizer)
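The three added lines are a standard early-return guard: `structured_outputs` is Optional, and the helpers below need the non-None value, so returning early both skips needless work and lets mypy narrow the attribute for the rest of the function. The narrowing in isolation (types simplified):

```python
class Params:
    structured_outputs: dict[str, str] | None = None


def validate(sampling_params: Params) -> None:
    # if structured output is not enabled, there is nothing to validate
    if sampling_params.structured_outputs is None:
        return
    # mypy narrows the attribute to "dict[str, str]" from here on
    _ = sampling_params.structured_outputs.items()
```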
@@ -69,7 +69,7 @@ class XgrammarBackend(StructuredOutputBackend):
                 if idx < vocab_size:
                     encoded_vocab[idx] = token
             stop_token_ids = [self.tokenizer.eos_token_id]
-            backend_str = self.tokenizer.tokenizer.backend_tokenizer.to_str()
+            backend_str = self.tokenizer.tokenizer.backend_tokenizer.to_str()  # type: ignore[attr-defined]
             metadata = xgr.TokenizerInfo._detect_metadata_from_hf(backend_str)
             tokenizer_info = xgr.TokenizerInfo(
                 encoded_vocab=encoded_vocab,
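`backend_tokenizer` exists only on the fast (Rust-backed) HF tokenizer, so the attribute is absent from the statically declared type and mypy reports `[attr-defined]`; the ignore asserts it is present at runtime. A minimal analogue:

```python
class Slow:
    pass


class Fast(Slow):
    def to_str(self) -> str:
        return "{}"


tok: Slow = Fast()
# mypy: "Slow" has no attribute "to_str"  [attr-defined]; fine at runtime
backend_str = tok.to_str()  # type: ignore[attr-defined]
```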