diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 0590f01b7..c987acfa3 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -481,6 +481,7 @@ th {
 | `Step3p5ForCausalLM` | Step-3.5-flash | `stepfun-ai/Step-3.5-Flash`, etc. | | ✅︎ |
 | `TeleChatForCausalLM` | TeleChat | `chuhac/TeleChat2-35B`, etc. | ✅︎ | ✅︎ |
 | `TeleChat2ForCausalLM` | TeleChat2 | `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. | ✅︎ | ✅︎ |
+| `TeleChat3ForCausalLM` | TeleChat3 | `Tele-AI/TeleChat3-36B-Thinking`, `Tele-AI/TeleChat3-Coder-36B-Thinking`, etc. | ✅︎ | ✅︎ |
 | `TeleFLMForCausalLM` | TeleFLM | `CofeAI/FLM-2-52B-Instruct-2407`, `CofeAI/Tele-FLM`, etc. | ✅︎ | ✅︎ |
 | `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ |
 | `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | |
diff --git a/tests/models/registry.py b/tests/models/registry.py
index f75dae4ef..acfc4786e 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -537,6 +537,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "TeleChat2ForCausalLM": _HfExamplesInfo(
         "Tele-AI/TeleChat2-3B", trust_remote_code=True
     ),
+    "TeleChat3ForCausalLM": _HfExamplesInfo(
+        "Tele-AI/TeleChat3-36B-Thinking", trust_remote_code=True
+    ),
     "TeleFLMForCausalLM": _HfExamplesInfo(
         "CofeAI/FLM-2-52B-Instruct-2407", trust_remote_code=True
     ),
diff --git a/vllm/model_executor/layers/rotary_embedding/__init__.py b/vllm/model_executor/layers/rotary_embedding/__init__.py
index 28157daab..9a5418775 100644
--- a/vllm/model_executor/layers/rotary_embedding/__init__.py
+++ b/vllm/model_executor/layers/rotary_embedding/__init__.py
@@ -20,6 +20,7 @@ from .mrope import MRotaryEmbedding
 from .mrope_interleaved import MRotaryEmbeddingInterleaved
 from .ntk_scaling_rope import NTKScalingRotaryEmbedding
 from .phi3_long_rope_scaled_rope import Phi3LongRoPEScaledRotaryEmbedding
+from .telechat3_scaling_rope import TeleChat3RoPEScaledRotaryEmbedding
 from .xdrope import XDRotaryEmbedding
 from .yarn_scaling_rope import YaRNScalingRotaryEmbedding
 
@@ -334,6 +335,36 @@ def get_rope(
             )
         else:
             raise ValueError("Pangu mrope lacks necessary parameters.")
+    elif scaling_type == "telechat3-yarn":
+        scaling_factor = rope_parameters["factor"]
+        if "original_max_position_embeddings" in rope_parameters:
+            original_max_position = rope_parameters["original_max_position_embeddings"]
+            scaling_factor = max_position / original_max_position
+        else:
+            original_max_position = max_position
+        extra_kwargs = {
+            k: v
+            for k, v in rope_parameters.items()
+            if k
+            in (
+                "extrapolation_factor",
+                "attn_factor",
+                "beta_fast",
+                "beta_slow",
+                "mscale",
+                "mscale_all_dim",
+            )
+        }
+        rotary_emb = TeleChat3RoPEScaledRotaryEmbedding(
+            head_size,
+            rotary_dim,
+            original_max_position,
+            base,
+            is_neox_style,
+            scaling_factor,
+            dtype,
+            **extra_kwargs,
+        )
     else:
         raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
     _ROPE_DICT[key] = rotary_emb
diff --git a/vllm/model_executor/layers/rotary_embedding/telechat3_scaling_rope.py b/vllm/model_executor/layers/rotary_embedding/telechat3_scaling_rope.py
new file mode 100644
index 000000000..dd2fb9c32
--- /dev/null
+++ b/vllm/model_executor/layers/rotary_embedding/telechat3_scaling_rope.py
@@ -0,0 +1,57 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import math
+
+import torch
+
+from .base import RotaryEmbedding
+from .yarn_scaling_rope import YaRNScalingRotaryEmbedding
+
+
+class TeleChat3RoPEScaledRotaryEmbedding(YaRNScalingRotaryEmbedding):
+    """TeleChat3 uses a variant of the YaRN method.
+
+    To reuse as much of the YaRN implementation as possible, only the
+    `get_mscale` computation is redefined here, inside `__init__`.
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factor: float,
+        dtype: torch.dtype,
+        *,
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
+        truncate: bool = True,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        self.truncate = truncate
+
+        def get_mscale(scale, mscale=1):
+            if scale <= 1:
+                return 1.0
+            return 0.07 * mscale * math.log(scale) + 1.0
+
+        self.mscale = float(get_mscale(self.scaling_factor) * attn_factor)
+        # Initialize the base class after mscale is set; the cos/sin cache uses it.
+        RotaryEmbedding.__init__(
+            self,
+            head_size,
+            rotary_dim,
+            max_position_embeddings,
+            base,
+            is_neox_style,
+            dtype,
+        )
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 58c59a29a..d52a3e48a 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -206,6 +206,7 @@ _TEXT_GENERATION_MODELS = {
     "SolarForCausalLM": ("solar", "SolarForCausalLM"),
     "TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
     "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
+    "TeleChat3ForCausalLM": ("llama", "LlamaForCausalLM"),
     "TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
     "XverseForCausalLM": ("llama", "LlamaForCausalLM"),
     "Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
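
For reviewers: the key visible difference in `telechat3_scaling_rope.py` relative to vLLM's stock `YaRNScalingRotaryEmbedding` is the attention-scaling coefficient inside `get_mscale`, 0.07 here versus the standard 0.1 from the YaRN formula. The snippet below is a minimal standalone sketch of that difference only; the scale value is a hypothetical context-extension factor chosen for illustration and does not come from any released TeleChat3 config.

```python
# Minimal sketch (not part of the diff): standard YaRN mscale vs. the
# TeleChat3 variant defined in telechat3_scaling_rope.py above.
import math


def yarn_mscale(scale: float, mscale: float = 1.0) -> float:
    # Standard YaRN attention scaling: 0.1 * mscale * ln(scale) + 1.0
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


def telechat3_mscale(scale: float, mscale: float = 1.0) -> float:
    # TeleChat3 variant from the new class: 0.07 coefficient instead of 0.1
    if scale <= 1:
        return 1.0
    return 0.07 * mscale * math.log(scale) + 1.0


if __name__ == "__main__":
    scale = 4.0  # hypothetical context-extension factor, illustration only
    print(f"yarn      mscale: {yarn_mscale(scale):.4f}")       # ~1.1386
    print(f"telechat3 mscale: {telechat3_mscale(scale):.4f}")  # ~1.0970
```

With the smaller coefficient, the cos/sin cache (and hence the effective attention scaling) is amplified less aggressively as the context-extension factor grows, which is the behavioral knob this subclass exists to change.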