[Bugfix] Fix Qwen-VL tokenizer implementation (#36140)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
(cherry picked from commit 7196348157)
This commit is contained in:
Cyrus Leung
2026-03-06 00:07:19 +08:00
committed by khluu
parent 9a474ce7a4
commit fa78ec8a72
9 changed files with 118 additions and 66 deletions

View File

@@ -29,7 +29,8 @@ def test_tokenizer_like_protocol():
_assert_tokenizer_like(tokenizer)
tokenizer = get_tokenizer(
"mistralai/Mistral-7B-Instruct-v0.3", tokenizer_mode="mistral"
"mistralai/Mistral-7B-Instruct-v0.3",
tokenizer_mode="mistral",
)
assert isinstance(tokenizer, MistralTokenizer)
_assert_tokenizer_like(tokenizer)
@@ -40,11 +41,20 @@ def test_tokenizer_like_protocol():
tokenizer = get_tokenizer("deepseek-ai/DeepSeek-V3", tokenizer_mode="deepseek_v32")
assert isinstance(tokenizer, HfTokenizer)
# Verify it's a fast tokenizer (required for FastIncrementalDetokenizer)
assert isinstance(tokenizer, PreTrainedTokenizerFast)
assert "DSV32" in tokenizer.__class__.__name__
_assert_tokenizer_like(tokenizer)
tokenizer = get_tokenizer(
"Qwen/Qwen-VL",
tokenizer_mode="qwen_vl",
trust_remote_code=True,
)
assert isinstance(tokenizer, HfTokenizer)
assert "WithoutImagePad" in tokenizer.__class__.__name__
@pytest.mark.parametrize("tokenizer_name", ["facebook/opt-125m", "gpt2"])
def test_tokenizer_revision(tokenizer_name: str):

View File

@@ -1321,6 +1321,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
- "slow" will always use the slow tokenizer.\n
- "mistral" will always use the tokenizer from `mistral_common`.\n
- "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n
- "qwen_vl" will always use the tokenizer from `qwen_vl`.\n
- Other custom values can be supported via plugins.""",
)
parser.add_argument("--use-beam-search", action="store_true")

View File

@@ -126,6 +126,7 @@ class ModelConfig:
- "slow" will always use the slow tokenizer.\n
- "mistral" will always use the tokenizer from `mistral_common`.\n
- "deepseek_v32" will always use the tokenizer from `deepseek_v32`.\n
- "qwen_vl" will always use the tokenizer from `qwen_vl`.\n
- Other custom values can be supported via plugins."""
trust_remote_code: bool = False
"""Trust remote code (e.g., from HuggingFace) when downloading the model

View File

@@ -6,11 +6,9 @@
# Copyright (c) Alibaba Cloud.
"""Inference-only Qwen-VL model compatible with HuggingFace weights."""
import copy
import math
import unicodedata
from collections.abc import Callable, Collection, Mapping, Sequence, Set
from functools import lru_cache, partial
from collections.abc import Callable, Mapping, Sequence
from functools import partial
from typing import Annotated, Literal, TypeAlias
import regex as re
@@ -436,60 +434,6 @@ class QwenVLModel(QWenModel):
)
@lru_cache(maxsize=1)
def _get_tokenizer_without_image_pad(
tokenizer: PreTrainedTokenizer,
) -> PreTrainedTokenizer:
"""
The logic of adding image pad tokens should only be applied in
[`QwenVLProcessor`][vllm.model_executor.models.qwen_vl.QwenVLProcessor],
so they are patched out here.
The definition of the wrapped tokenizer can be found here:
https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py
"""
new_tokenizer = copy.deepcopy(tokenizer)
class TokenizerWithoutImagePad(tokenizer.__class__): # type: ignore
def tokenize(
self,
text: str,
allowed_special: Set[str] | str = "all",
disallowed_special: Collection[str] | str = (),
**kwargs,
) -> list[bytes | str]:
text = unicodedata.normalize("NFC", text)
return [
self.decoder[t]
for t in self.tokenizer.encode(
text,
allowed_special=allowed_special,
disallowed_special=disallowed_special,
)
]
def _decode(
self,
token_ids: int | list[int],
skip_special_tokens: bool = False,
errors: str | None = None,
**kwargs,
) -> str:
if isinstance(token_ids, int):
token_ids = [token_ids]
return self.tokenizer.decode(
token_ids,
errors=errors or self.errors,
)
TokenizerWithoutImagePad.__name__ = f"{tokenizer.__class__.__name__}WithoutImagePad"
new_tokenizer.__class__ = TokenizerWithoutImagePad
return new_tokenizer
class QwenVLProcessor:
"""
This model doesn't define its own HF processor,
@@ -574,12 +518,6 @@ class QwenVLProcessor:
class QwenVLProcessingInfo(BaseProcessingInfo):
def get_tokenizer(self) -> PreTrainedTokenizer:
tokenizer = self.ctx.get_tokenizer()
assert isinstance(tokenizer, PreTrainedTokenizer)
return _get_tokenizer_without_image_pad(tokenizer)
def get_hf_processor(self, **kwargs: object) -> QwenVLProcessor:
return self.ctx.init_processor(
QwenVLProcessor,

29
vllm/renderers/qwen_vl.py Normal file
View File

@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from vllm.config import VllmConfig
from vllm.tokenizers import cached_get_tokenizer
from vllm.tokenizers.qwen_vl import QwenVLTokenizer
from .base import BaseRenderer
from .hf import HfRenderer
class QwenVLRenderer(BaseRenderer[QwenVLTokenizer]):
@classmethod
def from_config( # type: ignore[override]
cls,
config: VllmConfig,
tokenizer_kwargs: dict[str, Any],
) -> "HfRenderer":
model_config = config.model_config
if model_config.skip_tokenizer_init:
tokenizer = None
else:
tokenizer = cached_get_tokenizer(
tokenizer_cls=QwenVLTokenizer,
**tokenizer_kwargs,
)
return HfRenderer(config, tokenizer)

View File

@@ -20,6 +20,7 @@ _VLLM_RENDERERS = {
"hf": ("hf", "HfRenderer"),
"grok2": ("grok2", "Grok2Renderer"),
"mistral": ("mistral", "MistralRenderer"),
"qwen_vl": ("qwen_vl", "QwenVLRenderer"),
"terratorch": ("terratorch", "TerratorchRenderer"),
}

View File

@@ -7,9 +7,9 @@ from transformers import AutoTokenizer
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from . import TokenizerLike
from .deepseek_v32_encoding import encode_messages
from .hf import HfTokenizer, get_cached_tokenizer
from .protocol import TokenizerLike
def get_deepseek_v32_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer:

View File

@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import unicodedata
from collections.abc import Collection, Set
from transformers import AutoTokenizer
from .hf import HfTokenizer, get_cached_tokenizer
from .protocol import TokenizerLike
def get_qwen_vl_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer:
"""
The logic of adding image pad tokens should only be applied in
`QwenVLProcessor`, so they are patched out here.
The definition of the wrapped tokenizer can be found here:
https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py
"""
new_tokenizer = copy.copy(tokenizer)
class TokenizerWithoutImagePad(tokenizer.__class__): # type: ignore
def tokenize(
self,
text: str,
allowed_special: Set[str] | str = "all",
disallowed_special: Collection[str] | str = (),
**kwargs,
) -> list[bytes | str]:
text = unicodedata.normalize("NFC", text)
return [
self.decoder[t]
for t in self.tokenizer.encode(
text,
allowed_special=allowed_special,
disallowed_special=disallowed_special,
)
]
def _decode(
self,
token_ids: int | list[int],
skip_special_tokens: bool = False,
errors: str | None = None,
**kwargs,
) -> str:
if isinstance(token_ids, int):
token_ids = [token_ids]
return self.tokenizer.decode(
token_ids,
errors=errors or self.errors,
)
TokenizerWithoutImagePad.__name__ = f"{tokenizer.__class__.__name__}WithoutImagePad"
new_tokenizer.__class__ = TokenizerWithoutImagePad
return new_tokenizer
class QwenVLTokenizer(TokenizerLike):
@classmethod
def from_pretrained(cls, *args, **kwargs) -> HfTokenizer:
tokenizer = AutoTokenizer.from_pretrained(*args, **kwargs)
return get_cached_tokenizer(get_qwen_vl_tokenizer(tokenizer))

View File

@@ -36,6 +36,7 @@ _VLLM_TOKENIZERS = {
"grok2": ("grok2", "Grok2Tokenizer"),
"hf": ("hf", "CachedHfTokenizer"),
"mistral": ("mistral", "MistralTokenizer"),
"qwen_vl": ("qwen_vl", "QwenVLTokenizer"),
}
@@ -165,6 +166,10 @@ def resolve_tokenizer_args(
):
tokenizer_mode = "grok2"
# Model-specific tokenizers
if tokenizer_mode == "auto" and "/Qwen-VL" in str(tokenizer_name):
tokenizer_mode = "qwen_vl"
# Fallback to HF tokenizer
if tokenizer_mode == "auto":
tokenizer_mode = "hf"