diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 5c3668392..17aad591e 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -399,6 +399,7 @@ th { | `GraniteMoeSharedForCausalLM` | Granite MoE Shared | `ibm-research/moe-7b-1b-active-shared-experts` (test model) | ✅︎ | ✅︎ | | `GritLM` | GritLM | `parasail-ai/GritLM-7B-vllm`. | ✅︎ | ✅︎ | | `Grok1ModelForCausalLM` | Grok1 | `hpcai-tech/grok-1`. | ✅︎ | ✅︎ | +| `Grok1ForCausalLM` | Grok2 | `xai-org/grok-2` | ✅︎ | ✅︎ | | `HunYuanDenseV1ForCausalLM` | Hunyuan Dense | `tencent/Hunyuan-7B-Instruct` | ✅︎ | ✅︎ | | `HunYuanMoEV1ForCausalLM` | Hunyuan-A13B | `tencent/Hunyuan-A13B-Instruct`, `tencent/Hunyuan-A13B-Pretrain`, `tencent/Hunyuan-A13B-Instruct-FP8`, etc. | ✅︎ | ✅︎ | | `HCXVisionForCausalLM` | HyperCLOVAX-SEED-Vision-Instruct-3B | `naver-hyperclovax/HyperCLOVAX-SEED-Vision-Instruct-3B` | | | @@ -459,6 +460,9 @@ th { | `Zamba2ForCausalLM` | Zamba2 | `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. | | | | `LongcatFlashForCausalLM` | LongCat-Flash | `meituan-longcat/LongCat-Flash-Chat`, `meituan-longcat/LongCat-Flash-Chat-FP8` | ✅︎ | ✅︎ | +!!! note + Grok2 requires `tokenizer.tok.json` with `tiktoken` installed. You can optionally override MoE router renormalization with `moe_router_renormalize`. + Some models are supported only via the [Transformers modeling backend](#transformers). The purpose of the table below is to acknowledge models which we officially support in this way. The logs will say that the Transformers modeling backend is being used, and you will see no warning that this is fallback behaviour. This means that, if you have issues with any of the models listed below, please [make an issue](https://github.com/vllm-project/vllm/issues/new/choose) and we'll do our best to fix it! | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | diff --git a/tests/models/language/generation/test_grok.py b/tests/models/language/generation/test_grok.py new file mode 100644 index 000000000..a2f1e8b44 --- /dev/null +++ b/tests/models/language/generation/test_grok.py @@ -0,0 +1,43 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from ...utils import dummy_hf_overrides + +MODELS = ["xai-org/grok-2"] + + +def _grok2_dummy_overrides(hf_config): + hf_config = dummy_hf_overrides(hf_config, model_arch="Grok1ForCausalLM") + text_config = hf_config.get_text_config() + text_config.update( + { + "hidden_size": 256, + "intermediate_size": 512, + "moe_intermediate_size": 256, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "head_dim": 64, + } + ) + return hf_config + + +@pytest.mark.parametrize("model", MODELS) +def test_dummy_generate(vllm_runner, monkeypatch, model: str) -> None: + with monkeypatch.context() as m: + m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") + with vllm_runner( + model, + load_format="dummy", + max_model_len=128, + hf_overrides=_grok2_dummy_overrides, + enforce_eager=True, + ) as llm: + prompt = "Hello from Grok-2" + tokenizer = llm.get_llm().get_tokenizer() + prompt_len = len(tokenizer.encode(prompt)) + outputs = llm.generate_greedy([prompt], max_tokens=1) + output_ids, output_str = outputs[0] + assert len(output_ids) > prompt_len + assert output_str is not None diff --git a/tests/models/registry.py b/tests/models/registry.py index 9778678b3..790105e57 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -289,6 +289,7 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "Grok1ModelForCausalLM": _HfExamplesInfo( "hpcai-tech/grok-1", trust_remote_code=True ), + "Grok1ForCausalLM": _HfExamplesInfo("xai-org/grok-2", trust_remote_code=True), "HunYuanDenseV1ForCausalLM": _HfExamplesInfo("tencent/Hunyuan-7B-Instruct"), "HunYuanMoEV1ForCausalLM": _HfExamplesInfo( "tencent/Hunyuan-A13B-Instruct", trust_remote_code=True diff --git a/tests/tokenizers_/test_basic.py b/tests/tokenizers_/test_basic.py index 0510261ea..b5c26a659 100644 --- a/tests/tokenizers_/test_basic.py +++ b/tests/tokenizers_/test_basic.py @@ -10,6 +10,7 @@ from transformers import ( ) from vllm.tokenizers import TokenizerLike, get_tokenizer +from vllm.tokenizers.grok2 import Grok2Tokenizer from vllm.tokenizers.mistral import MistralTokenizer @@ -37,6 +38,10 @@ def test_tokenizer_like_protocol(): assert isinstance(tokenizer, MistralTokenizer) _assert_tokenizer_like(tokenizer) + tokenizer = get_tokenizer("xai-org/grok-2", tokenizer_mode="grok2") + assert isinstance(tokenizer, Grok2Tokenizer) + _assert_tokenizer_like(tokenizer) + @pytest.mark.parametrize("tokenizer_name", ["facebook/opt-125m", "gpt2"]) def test_tokenizer_revision(tokenizer_name: str): diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py index 2c41f2a11..43c658a2c 100644 --- a/vllm/model_executor/models/grok1.py +++ b/vllm/model_executor/models/grok1.py @@ -21,8 +21,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Inference-only Grok1 model.""" +"""Inference-only Grok (Grok1/Grok2) model.""" +import math from collections.abc import Iterable from itertools import islice from typing import Any @@ -35,9 +36,12 @@ from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( + MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear, @@ -68,6 +72,100 @@ from .utils import ( DEFAULT_ATTN_OUTPUT_MULTIPLIER = 0.08838834764831845 DEFAULT_OUTPUT_MULTIPLIER_SCALE = 0.5773502691896257 DEFAULT_EMBEDDING_MULTIPLIER_SCALE = 78.38367176906169 +DEFAULT_ROUTER_LOGIT_SOFTCAP = 30.0 + +logger = init_logger(__name__) + + +def _get_num_experts(config) -> int: + return getattr(config, "num_experts", getattr(config, "num_local_experts", 8)) + + +def _get_moe_intermediate_size(config) -> int: + return getattr(config, "moe_intermediate_size", config.intermediate_size) + + +def _get_grok_version(config) -> str: + """Detect Grok version from HF config using multiple heuristics.""" + # Check for Grok2-specific attributes (both for robust detection) + has_residual_moe = getattr(config, "residual_moe", False) + has_moe_intermediate_size = hasattr(config, "moe_intermediate_size") + + if has_residual_moe or has_moe_intermediate_size: + return "grok2" + + return "grok1" # Default to Grok1 + + +def _get_rope_parameters(config) -> dict[str, Any] | None: + rope_parameters = getattr(config, "rope_parameters", None) + if rope_parameters is None: + rope_type = getattr(config, "rope_type", None) + if rope_type is None: + return None + rope_parameters = {"rope_type": rope_type} + rope_theta = getattr(config, "rope_theta", None) + if rope_theta is not None: + rope_parameters["rope_theta"] = rope_theta + scaling_factor = getattr(config, "scaling_factor", None) + if scaling_factor is not None: + rope_parameters["factor"] = scaling_factor + for name in ( + "original_max_position_embeddings", + "extrapolation_factor", + "attn_factor", + "beta_fast", + "beta_slow", + ): + value = getattr(config, name, None) + if value is not None: + rope_parameters[name] = value + + if rope_parameters.get("rope_type") == "original": + rope_parameters = dict(rope_parameters) + rope_parameters["rope_type"] = "default" + return rope_parameters + + +def _get_moe_renormalize(config) -> bool: + explicit_value = getattr( + config, "moe_router_renormalize", getattr(config, "moe_renormalize", None) + ) + if explicit_value is not None: + return bool(explicit_value) + return not getattr(config, "residual_moe", False) + + +class Grok1MLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + self.act_fn = GeluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x class Grok1MoE(nn.Module): @@ -85,9 +183,11 @@ class Grok1MoE(nn.Module): top_k: int, hidden_size: int, intermediate_size: int, + router_logit_soft_cap: float = 0.0, params_dtype: torch.dtype | None = None, quant_config: QuantizationConfig | None = None, tp_size: int | None = None, + renormalize: bool = False, prefix: str = "", ): super().__init__() @@ -110,12 +210,13 @@ class Grok1MoE(nn.Module): intermediate_size=intermediate_size, params_dtype=params_dtype, reduce_results=True, - renormalize=True, + renormalize=renormalize, quant_config=quant_config, tp_size=tp_size, activation="gelu", prefix=f"{prefix}.experts", ) + self.router_logit_soft_cap = router_logit_soft_cap def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. @@ -123,7 +224,10 @@ class Grok1MoE(nn.Module): hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - router_logits = 30.0 * F.tanh(router_logits / 30.0) + if self.router_logit_soft_cap > 0: + router_logits = self.router_logit_soft_cap * F.tanh( + router_logits / self.router_logit_soft_cap + ) final_hidden_states = self.experts(hidden_states, router_logits) return final_hidden_states.view(orig_shape) @@ -187,6 +291,15 @@ class Grok1Attention(nn.Module): ) attn_logits_soft_cap = max(getattr(config, "attn_logit_softcapping", 30.0), 0.0) + attn_logit_softcapping_method = getattr( + config, "attn_logit_softcapping_method", None + ) + if attn_logit_softcapping_method not in (None, "tanh"): + logger.warning_once( + "Grok attention logit softcapping method '%s' is not " + "supported; falling back to default behavior.", + attn_logit_softcapping_method, + ) self.attn = Attention( self.num_heads, @@ -238,30 +351,50 @@ class Grok1DecoderLayer(nn.Module): num_heads=config.num_attention_heads, max_position=config.max_position_embeddings, num_kv_heads=config.num_key_value_heads, - rope_parameters=getattr(config, "rope_parameters", None), + rope_parameters=_get_rope_parameters(config), cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.attn", config=config, ) # Pass config to Grok1Attention - # Grok1 uses "num_experts" in its config - num_experts = getattr(config, "num_experts", 8) + num_experts = _get_num_experts(config) num_experts_per_tok = getattr(config, "num_experts_per_tok", 2) + moe_intermediate_size = _get_moe_intermediate_size(config) + moe_renormalize = _get_moe_renormalize(config) self.moe_block = Grok1MoE( num_experts=num_experts, top_k=num_experts_per_tok, hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, + intermediate_size=moe_intermediate_size, + router_logit_soft_cap=max( + getattr( + config, + "router_logit_softcapping", + DEFAULT_ROUTER_LOGIT_SOFTCAP, + ), + 0.0, + ), quant_config=quant_config, + renormalize=moe_renormalize, prefix=f"{prefix}.moe_block", ) + self.residual_moe = getattr(config, "residual_moe", False) + self.residual_moe_scale = 1.0 / math.sqrt(2.0) self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.mlp = None + if self.residual_moe: + self.mlp = Grok1MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) def forward( self, @@ -286,7 +419,13 @@ class Grok1DecoderLayer(nn.Module): # MoE block with normalization hidden_states, residual = self.pre_moe_norm(hidden_states, residual) - hidden_states = self.moe_block(hidden_states) + if self.residual_moe: + assert self.mlp is not None + hidden_states = ( + self.moe_block(hidden_states) + self.mlp(hidden_states) + ) * self.residual_moe_scale + else: + hidden_states = self.moe_block(hidden_states) hidden_states = self.post_moe_norm(hidden_states) return hidden_states, residual @@ -294,7 +433,16 @@ class Grok1DecoderLayer(nn.Module): @support_torch_compile class Grok1Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + ckpt_gate_proj_name: str = "linear", + ckpt_down_proj_name: str = "linear_1", + ckpt_up_proj_name: str = "linear_v", + weight_name_remapping: dict[str, str] | None = None, + ): super().__init__() config = vllm_config.model_config.hf_config @@ -305,6 +453,12 @@ class Grok1Model(nn.Module): self.quant_config = quant_config self.padding_idx = config.pad_token_id + # Store expert naming for weight loading + self.ckpt_gate_proj_name = ckpt_gate_proj_name + self.ckpt_down_proj_name = ckpt_down_proj_name + self.ckpt_up_proj_name = ckpt_up_proj_name + self.weight_name_remapping = weight_name_remapping or {} + self.vocab_size = config.vocab_size self.embedding_multiplier_scale = getattr( @@ -365,14 +519,13 @@ class Grok1Model(nn.Module): return hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: - # Map Grok1's unique expert parameter names to standard names - # Grok1 uses "num_experts" in its config - num_experts = getattr(self.config, "num_experts", 8) + # Map expert parameter names to standard names + num_experts = _get_num_experts(self.config) return FusedMoE.make_expert_params_mapping( self, - ckpt_gate_proj_name="linear", # Grok1 specific - ckpt_down_proj_name="linear_1", # Grok1 specific - ckpt_up_proj_name="linear_v", # Grok1 specific + ckpt_gate_proj_name=self.ckpt_gate_proj_name, + ckpt_down_proj_name=self.ckpt_down_proj_name, + ckpt_up_proj_name=self.ckpt_up_proj_name, num_experts=num_experts, ) @@ -382,12 +535,18 @@ class Grok1Model(nn.Module): ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), + ("mlp.gate_up_proj", "mlp.gate_proj", 0), + ("mlp.gate_up_proj", "mlp.up_proj", 1), ] params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: + # Apply version-specific weight name remapping + for old_pattern, new_pattern in self.weight_name_remapping.items(): + if old_pattern in name: + name = name.replace(old_pattern, new_pattern) if self.quant_config is not None and ( scale_name := self.quant_config.get_cache_scale(name) ): @@ -418,6 +577,8 @@ class Grok1Model(nn.Module): name = maybe_remap_kv_scale_name(name, params_dict) if name is None: continue + if name not in params_dict: + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -464,6 +625,8 @@ class Grok1Model(nn.Module): if "norm.scale" in name: name = name.replace("scale", "weight") + if name not in params_dict: + continue param = params_dict[name] weight_loader = getattr( param, "weight_loader", default_weight_loader @@ -473,9 +636,12 @@ class Grok1Model(nn.Module): return loaded_params -class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): +class GrokBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + """Base class for Grok models with shared logic.""" + fall_back_to_pt_during_load = False + # Subclasses should override these packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -484,6 +650,15 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ], } + # Expert weight naming - subclasses override these + ckpt_gate_proj_name: str = "linear" + ckpt_down_proj_name: str = "linear_1" + ckpt_up_proj_name: str = "linear_v" + + def get_weight_name_remapping(self) -> dict[str, str]: + """Return weight name remapping for this version. Override in subclasses.""" + return {} + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() @@ -491,11 +666,15 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): quant_config = vllm_config.quant_config self.config = config - self.quant_config = quant_config self.model = Grok1Model( - vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ckpt_gate_proj_name=self.ckpt_gate_proj_name, + ckpt_down_proj_name=self.ckpt_down_proj_name, + ckpt_up_proj_name=self.ckpt_up_proj_name, + weight_name_remapping=self.get_weight_name_remapping(), ) self.lm_head = ParallelLMHead( @@ -512,7 +691,9 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): config, "output_multiplier_scale", DEFAULT_OUTPUT_MULTIPLIER_SCALE ) self.logits_processor = LogitsProcessor( - config.vocab_size, scale=self.output_multiplier_scale + config.vocab_size, + scale=self.output_multiplier_scale, + soft_cap=getattr(config, "final_logit_softcapping", None), ) self.make_empty_intermediate_tensors = ( @@ -553,3 +734,70 @@ class Grok1ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping() + + +class Grok1ForCausalLM(GrokBaseForCausalLM): + """Grok1-specific implementation.""" + + # Grok1 expert weight naming + ckpt_gate_proj_name = "linear" + ckpt_down_proj_name = "linear_1" + ckpt_up_proj_name = "linear_v" + + def get_weight_name_remapping(self) -> dict[str, str]: + # Grok1 uses standard naming, no remapping needed + return {} + + +class Grok2ForCausalLM(GrokBaseForCausalLM): + """Grok2-specific implementation.""" + + # Grok2 has additional packed modules for MLP + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # Grok2 expert weight naming + ckpt_gate_proj_name = "w1" + ckpt_down_proj_name = "w2" + ckpt_up_proj_name = "w3" + + def get_weight_name_remapping(self) -> dict[str, str]: + # Grok2 checkpoint uses different naming conventions + return { + ".self_attn.": ".attn.", + ".block_sparse_moe.": ".moe_block.", + } + + +# Version dispatch mapping +_GROK_VERSIONS: dict[str, type[GrokBaseForCausalLM]] = { + "grok1": Grok1ForCausalLM, + "grok2": Grok2ForCausalLM, +} + + +class GrokForCausalLM(GrokBaseForCausalLM): + """Factory class that dispatches to version-specific implementation.""" + + def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""): + config = vllm_config.model_config.hf_config + version = _get_grok_version(config) + + instance_cls = _GROK_VERSIONS.get(version) + if instance_cls is None: + raise ValueError(f"Unsupported Grok version: {version}") + + # Merge class attributes for LoRA/quantization compatibility + cls.packed_modules_mapping = dict(cls.packed_modules_mapping) + cls.packed_modules_mapping.update(instance_cls.packed_modules_mapping) + + return instance_cls(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index a25267fc2..f50651512 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -119,7 +119,8 @@ _TEXT_GENERATION_MODELS = { "GraniteMoeHybridForCausalLM": ("granitemoehybrid", "GraniteMoeHybridForCausalLM"), # noqa: E501 "GraniteMoeSharedForCausalLM": ("granitemoeshared", "GraniteMoeSharedForCausalLM"), # noqa: E501 "GritLM": ("gritlm", "GritLM"), - "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"), + "Grok1ModelForCausalLM": ("grok1", "GrokForCausalLM"), + "Grok1ForCausalLM": ("grok1", "GrokForCausalLM"), "HunYuanMoEV1ForCausalLM": ("hunyuan_v1", "HunYuanMoEV1ForCausalLM"), "HunYuanDenseV1ForCausalLM": ("hunyuan_v1", "HunYuanDenseV1ForCausalLM"), "HCXVisionForCausalLM": ("hyperclovax_vision", "HCXVisionForCausalLM"), diff --git a/vllm/tokenizers/grok2.py b/vllm/tokenizers/grok2.py new file mode 100644 index 000000000..a4071908d --- /dev/null +++ b/vllm/tokenizers/grok2.py @@ -0,0 +1,443 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Tokenizer for Grok-2 .tok.json format.""" + +import functools +import json +from collections.abc import Collection, Set +from pathlib import Path +from typing import Any, Literal, overload + +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import ( + EntryNotFoundError, + HfHubHTTPError, + RepositoryNotFoundError, + RevisionNotFoundError, +) +from transformers import BatchEncoding +from transformers.utils import chat_template_utils as hf_chat_utils + +from vllm.entrypoints.chat_utils import ChatCompletionMessageParam +from vllm.logger import init_logger + +from .protocol import TokenizerLike + +logger = init_logger(__name__) + +PAD = "<|pad|>" +EOS = "<|eos|>" +SEP = "<|separator|>" +RESERVED_TOKEN_TEXTS = [f"<|reserved_{i}|>" for i in range(3, 128)] +CONTROL_TOKEN_TEXTS = [f"<|control{i}|>" for i in range(1, 705)] +DEFAULT_SPECIAL_TOKENS = [PAD, SEP, EOS] +DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": SEP, "eos": EOS} +DEFAULT_CHAT_TEMPLATE = ( + "{% for message in messages %}" + "{% if message['role'] == 'user' %}" + "{{ 'Human: ' + message['content'].strip() + '<|separator|>\\n\\n' }}" + "{% elif message['role'] == 'system' %}" + "{{ 'System: ' + message['content'].strip() + '<|separator|>\\n\\n' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ 'Assistant: ' + message['content'] + '<|separator|>\\n\\n' }}" + "{% endif %}" + "{% endfor %}" + "{% if add_generation_prompt %}" + "{{ 'Assistant:' }}" + "{% endif %}" +) + +# Default + separate each single digit. +PAT_STR_B = ( + r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}|""" + r""" ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" +) + + +def _maybe_load_tokenizer_config( + model_path: Path, + *, + repo_id: str | None, + revision: str | None, + download_dir: str | None, +) -> dict[str, Any]: + config_path = model_path / "tokenizer_config.json" + if config_path.is_file(): + with config_path.open("r", encoding="utf-8") as f: + return json.load(f) + + if repo_id is None: + return {} + + try: + config_file = hf_hub_download( + repo_id=repo_id, + filename="tokenizer_config.json", + revision=revision, + cache_dir=download_dir, + ) + except (RepositoryNotFoundError, RevisionNotFoundError, EntryNotFoundError): + # If the repo, revision, or file does not exist, fall back silently. + return {} + except HfHubHTTPError as exc: + logger.warning( + "Failed to download tokenizer_config.json from %s. " + "This may be due to a network or authentication issue. " + "The default chat template will be used. Error: %s", + repo_id, + exc, + ) + return {} + + try: + with Path(config_file).open("r", encoding="utf-8") as f: + return json.load(f) + except json.JSONDecodeError as exc: + logger.warning( + "Failed to parse tokenizer_config.json. " + "The default chat template will be used. Error: %s", + exc, + ) + return {} + except OSError as exc: + logger.warning( + "Failed to open tokenizer_config.json. " + "The default chat template will be used. Error: %s", + exc, + ) + return {} + + +def _load_tiktoken_encoding( + vocab_file: Path, +) -> tuple[Any, dict[str, int]]: + try: + import tiktoken + except ImportError as exc: + raise ImportError("Grok-2 tokenizer requires the `tiktoken` package.") from exc + + with vocab_file.open("rb") as f: + xtok_dict = json.load(f) + + mergeable_ranks = { + bytes(item["bytes"]): item["token"] + for item in xtok_dict.get("regular_tokens", []) + } + special_tokens = { + bytes(item["bytes"]).decode("utf-8", errors="replace"): item["token"] + for item in xtok_dict.get("special_tokens", []) + } + + if xtok_dict.get("word_split") == "V1": + pat_str = PAT_STR_B + else: + raise ValueError(f"Unknown word_split: {xtok_dict.get('word_split')!r}") + + pat_str = xtok_dict.get("pat_str", pat_str) + + kwargs = { + "name": str(vocab_file), + "pat_str": pat_str, + "mergeable_ranks": mergeable_ranks, + "special_tokens": special_tokens, + } + + if "vocab_size" in xtok_dict: + kwargs["explicit_n_vocab"] = xtok_dict["vocab_size"] + + tokenizer = tiktoken.Encoding(**kwargs) + + default_allowed_special: set[str] | None = None + if "default_allowed_special" in xtok_dict: + default_allowed_special = { + bytes(bytes_list).decode("utf-8", errors="replace") + for bytes_list in xtok_dict["default_allowed_special"] + } + + tokenizer._default_allowed_special = default_allowed_special or set() + tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS + + def encode_patched( + self, + text: str, + *, + allowed_special: Literal["all"] | Set[str] = set(), + disallowed_special: Literal["all"] | Collection[str] = "all", + ) -> list[int]: + del disallowed_special + if isinstance(allowed_special, set): + allowed_special |= self._default_allowed_special + return tiktoken.Encoding.encode( + self, + text, + allowed_special=allowed_special, + disallowed_special=(), + ) + + tokenizer.encode = functools.partial(encode_patched, tokenizer) + tokenizer._default_allowed_special |= set(DEFAULT_CONTROL_TOKENS.values()) + tokenizer._default_allowed_special |= set( + CONTROL_TOKEN_TEXTS + RESERVED_TOKEN_TEXTS + ) + + return tokenizer, special_tokens + + +class Grok2Tokenizer(TokenizerLike): + @classmethod + def from_pretrained( + cls, + path_or_repo_id: str | Path, + *args, + trust_remote_code: bool = False, + revision: str | None = None, + download_dir: str | None = None, + **kwargs, + ) -> "Grok2Tokenizer": + if args: + logger.debug_once("Ignoring extra positional args for Grok2Tokenizer.") + + path = Path(path_or_repo_id) + if path.is_file(): + vocab_file = path + model_path = path.parent + repo_id = None + elif path.is_dir(): + vocab_file = path / "tokenizer.tok.json" + model_path = path + repo_id = None + else: + vocab_file = Path( + hf_hub_download( + repo_id=str(path_or_repo_id), + filename="tokenizer.tok.json", + revision=revision, + cache_dir=download_dir, + ) + ) + model_path = vocab_file.parent + repo_id = str(path_or_repo_id) + + if not vocab_file.is_file(): + raise FileNotFoundError(f"tokenizer.tok.json not found at {vocab_file}.") + + config = _maybe_load_tokenizer_config( + model_path, + repo_id=repo_id, + revision=revision, + download_dir=download_dir, + ) + + return cls( + vocab_file=vocab_file, + name_or_path=str(path_or_repo_id), + truncation_side=kwargs.get("truncation_side", "left"), + chat_template=config.get("chat_template"), + init_kwargs=config, + ) + + def __init__( + self, + *, + vocab_file: Path, + name_or_path: str, + truncation_side: str, + chat_template: str | None, + init_kwargs: dict[str, Any] | None = None, + ) -> None: + super().__init__() + self.name_or_path = name_or_path + self._truncation_side = truncation_side + self.init_kwargs = init_kwargs or {} + self._chat_template = chat_template or DEFAULT_CHAT_TEMPLATE + + self._tokenizer, self._special_tokens = _load_tiktoken_encoding(vocab_file) + + self._token_to_id: dict[str, int] = {} + self._id_to_token: dict[int, str] = {} + for token, token_id in self._tokenizer._mergeable_ranks.items(): + token_str = token.decode("utf-8", errors="replace") + self._token_to_id[token_str] = token_id + self._id_to_token[token_id] = token_str + + for token, token_id in self._special_tokens.items(): + self._token_to_id[token] = token_id + self._id_to_token[token_id] = token + + bos_token_id = self._special_tokens.get(SEP) + if bos_token_id is None: + bos_token_id = self._special_tokens.get(PAD) + if bos_token_id is None: + bos_token_id = self._special_tokens.get(EOS) + if bos_token_id is None: + bos_token_id = 0 + self._bos_token_id = bos_token_id + + self._eos_token_id = self._special_tokens.get(EOS, self._bos_token_id) + self._pad_token_id = self._special_tokens.get(PAD, self._eos_token_id) + self._unk_token_id = self._pad_token_id + + def num_special_tokens_to_add(self) -> int: + return 0 + + @property + def all_special_tokens(self) -> list[str]: + return list(self._special_tokens.keys()) + + @property + def all_special_ids(self) -> list[int]: + return list(self._special_tokens.values()) + + @property + def bos_token_id(self) -> int: + return self._bos_token_id + + @property + def eos_token_id(self) -> int: + return self._eos_token_id + + @property + def pad_token_id(self) -> int: + return self._pad_token_id + + @property + def is_fast(self) -> bool: + return False + + @property + def vocab_size(self) -> int: + return self._tokenizer.n_vocab + + @property + def max_token_id(self) -> int: + return self._tokenizer.n_vocab - 1 + + @property + def truncation_side(self) -> str: + return self._truncation_side + + def get_vocab(self) -> dict[str, int]: + return dict(self._token_to_id) + + def get_added_vocab(self) -> dict[str, int]: + return dict(self._special_tokens) + + def _maybe_truncate(self, tokens: list[int], max_length: int | None) -> list[int]: + if max_length is None or len(tokens) <= max_length: + return tokens + if self.truncation_side == "left": + return tokens[-max_length:] + return tokens[:max_length] + + def encode( + self, + text: str, + truncation: bool | None = None, + max_length: int | None = None, + add_special_tokens: bool = True, + ) -> list[int]: + del add_special_tokens + tokens = self._tokenizer.encode(text) + if truncation: + tokens = self._maybe_truncate(tokens, max_length) + return tokens + + def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str: + if isinstance(ids, int): + ids = [ids] + if skip_special_tokens: + ids = [ + token_id + for token_id in ids + if token_id not in self._special_tokens.values() + ] + return self._tokenizer.decode(ids) + + @overload + def convert_tokens_to_ids(self, tokens: str) -> int: ... + + @overload + def convert_tokens_to_ids(self, tokens: list[str]) -> list[int]: ... + + def convert_tokens_to_ids(self, tokens: str | list[str]) -> int | list[int]: + if isinstance(tokens, str): + return self._token_to_id.get(tokens, self._unk_token_id) + return [self._token_to_id.get(token, self._unk_token_id) for token in tokens] + + def convert_ids_to_tokens( + self, ids: list[int], skip_special_tokens: bool = False + ) -> list[str]: + tokens = [] + for token_id in ids: + if skip_special_tokens and token_id in self._special_tokens.values(): + continue + tokens.append(self._id_to_token.get(token_id, "<|unk|>")) + return tokens + + def convert_tokens_to_string(self, tokens: list[str]) -> str: + token_ids = self.convert_tokens_to_ids(tokens) + return self.decode(token_ids, skip_special_tokens=False) + + def __call__( + self, + text: str | list[str], + text_pair: str | None = None, + add_special_tokens: bool = True, + truncation: bool = False, + max_length: int | None = None, + ) -> BatchEncoding: + if text_pair is not None: + raise NotImplementedError("text_pair is not supported for Grok2Tokenizer.") + + if isinstance(text, list): + input_ids_batch: list[list[int]] = [ + self.encode( + item, + truncation=truncation, + max_length=max_length, + add_special_tokens=add_special_tokens, + ) + for item in text + ] + attention_mask_batch = [[1] * len(ids) for ids in input_ids_batch] + return BatchEncoding( + {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} + ) + + input_ids = self.encode( + text, + truncation=truncation, + max_length=max_length, + add_special_tokens=add_special_tokens, + ) + attention_mask = [1] * len(input_ids) + return BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask}) + + def get_chat_template( + self, chat_template: str | None, tools: list[dict[str, Any]] | None = None + ) -> str | None: + del tools + return chat_template or self._chat_template + + def apply_chat_template( + self, + messages: list[ChatCompletionMessageParam], + tools: list[dict[str, Any]] | None = None, + chat_template: str | None = None, + tokenize: bool = False, + **kwargs, + ) -> str | list[int]: + template = self.get_chat_template(chat_template, tools=tools) + if template is None: + raise ValueError( + "No chat template available. Provide `chat_template` explicitly." + ) + prompt = hf_chat_utils.apply_chat_template( + conversation=messages, + chat_template=template, + tools=tools, + **kwargs, + ) + if tokenize: + return self.encode(prompt, add_special_tokens=False) + return prompt diff --git a/vllm/tokenizers/registry.py b/vllm/tokenizers/registry.py index f8610bb47..b5088a116 100644 --- a/vllm/tokenizers/registry.py +++ b/vllm/tokenizers/registry.py @@ -31,6 +31,7 @@ logger = init_logger(__name__) _VLLM_TOKENIZERS = { "deepseek_v32": ("deepseek_v32", "DeepseekV32Tokenizer"), + "grok2": ("grok2", "Grok2Tokenizer"), "hf": ("hf", "CachedHfTokenizer"), "mistral": ("mistral", "MistralTokenizer"), } @@ -151,6 +152,17 @@ def resolve_tokenizer_args( if len(files_list) > 0: tokenizer_mode = "mistral" + # Try to use Grok2 tiktoken tokenizer if possible + if tokenizer_mode == "auto": + allow_patterns = ["tokenizer.tok.json"] + files_list = list_filtered_repo_files( + model_name_or_path=str(tokenizer_name), + allow_patterns=allow_patterns, + revision=revision, + ) + if len(files_list) > 0: + tokenizer_mode = "grok2" + # Fallback to HF tokenizer if tokenizer_mode == "auto": tokenizer_mode = "hf"