[New Model]: add support for telechat3 (#38510)
Signed-off-by: xiayongqiang <xiayq1@chinatelecom.cn> Co-authored-by: xiayongqiang <xiayq1@chinatelecom.cn>
This commit is contained in:
@@ -481,6 +481,7 @@ th {
|
||||
| `Step3p5ForCausalLM` | Step-3.5-flash | `stepfun-ai/Step-3.5-Flash`, etc. | | ✅︎ |
|
||||
| `TeleChatForCausalLM` | TeleChat | `chuhac/TeleChat2-35B`, etc. | ✅︎ | ✅︎ |
|
||||
| `TeleChat2ForCausalLM` | TeleChat2 | `Tele-AI/TeleChat2-3B`, `Tele-AI/TeleChat2-7B`, `Tele-AI/TeleChat2-35B`, etc. | ✅︎ | ✅︎ |
|
||||
| `TeleChat3ForCausalLM` | TeleChat3 | `Tele-AI/TeleChat3-36B-Thinking`, `Tele-AI/TeleChat3-Coder-36B-Thinking`, etc. | ✅︎ | ✅︎ |
|
||||
| `TeleFLMForCausalLM` | TeleFLM | `CofeAI/FLM-2-52B-Instruct-2407`, `CofeAI/Tele-FLM`, etc. | ✅︎ | ✅︎ |
|
||||
| `XverseForCausalLM` | XVERSE | `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. | ✅︎ | ✅︎ |
|
||||
| `MiniMaxM1ForCausalLM` | MiniMax-Text | `MiniMaxAI/MiniMax-M1-40k`, `MiniMaxAI/MiniMax-M1-80k`, etc. | | |
|
||||
|
||||
@@ -537,6 +537,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"TeleChat2ForCausalLM": _HfExamplesInfo(
|
||||
"Tele-AI/TeleChat2-3B", trust_remote_code=True
|
||||
),
|
||||
"TeleChat3ForCausalLM": _HfExamplesInfo(
|
||||
"Tele-AI/TeleChat3-36B-Thinking", trust_remote_code=True
|
||||
),
|
||||
"TeleFLMForCausalLM": _HfExamplesInfo(
|
||||
"CofeAI/FLM-2-52B-Instruct-2407", trust_remote_code=True
|
||||
),
|
||||
|
||||
@@ -20,6 +20,7 @@ from .mrope import MRotaryEmbedding
|
||||
from .mrope_interleaved import MRotaryEmbeddingInterleaved
|
||||
from .ntk_scaling_rope import NTKScalingRotaryEmbedding
|
||||
from .phi3_long_rope_scaled_rope import Phi3LongRoPEScaledRotaryEmbedding
|
||||
from .telechat3_scaling_rope import TeleChat3RoPEScaledRotaryEmbedding
|
||||
from .xdrope import XDRotaryEmbedding
|
||||
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding
|
||||
|
||||
@@ -334,6 +335,36 @@ def get_rope(
|
||||
)
|
||||
else:
|
||||
raise ValueError("Pangu mrope lacks necessary parameters.")
|
||||
elif scaling_type == "telechat3-yarn":
|
||||
scaling_factor = rope_parameters["factor"]
|
||||
if "original_max_position_embeddings" in rope_parameters:
|
||||
original_max_position = rope_parameters["original_max_position_embeddings"]
|
||||
scaling_factor = max_position / original_max_position
|
||||
else:
|
||||
original_max_position = max_position
|
||||
extra_kwargs = {
|
||||
k: v
|
||||
for k, v in rope_parameters.items()
|
||||
if k
|
||||
in (
|
||||
"extrapolation_factor",
|
||||
"attn_factor",
|
||||
"beta_fast",
|
||||
"beta_slow",
|
||||
"mscale",
|
||||
"mscale_all_dim",
|
||||
)
|
||||
}
|
||||
rotary_emb = TeleChat3RoPEScaledRotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
original_max_position,
|
||||
base,
|
||||
is_neox_style,
|
||||
scaling_factor,
|
||||
dtype,
|
||||
**extra_kwargs,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||
_ROPE_DICT[key] = rotary_emb
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
from .base import RotaryEmbedding
|
||||
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding
|
||||
|
||||
|
||||
class TeleChat3RoPEScaledRotaryEmbedding(YaRNScalingRotaryEmbedding):
    """Rotary embedding for TeleChat3, which uses a variant of YaRN.

    The class derives from ``YaRNScalingRotaryEmbedding`` to reuse its
    methods, but its constructor does not run the YaRN constructor.
    Instead it stores the YaRN hyper-parameters itself, computes
    ``mscale`` with a TeleChat3-specific formula, and then calls
    ``RotaryEmbedding.__init__`` directly.

    NOTE(review): the ``telechat3-yarn`` branch of ``get_rope`` may forward
    ``mscale`` / ``mscale_all_dim`` from the rope parameters, but this
    constructor does not accept those keywords -- confirm whether such
    configs can occur.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
        truncate: bool = True,
    ) -> None:
        """Store the scaling hyper-parameters and initialise the base class.

        Args:
            head_size: Forwarded unchanged to ``RotaryEmbedding.__init__``.
            rotary_dim: Forwarded unchanged to ``RotaryEmbedding.__init__``.
            max_position_embeddings: Forwarded unchanged to
                ``RotaryEmbedding.__init__``.
            base: Forwarded unchanged to ``RotaryEmbedding.__init__``.
            is_neox_style: Forwarded unchanged to
                ``RotaryEmbedding.__init__``.
            scaling_factor: Context-extension ratio; stored on the instance
                and used to derive ``mscale``.
            dtype: Forwarded unchanged to ``RotaryEmbedding.__init__``.
            extrapolation_factor: Stored on the instance.
            attn_factor: Stored on the instance and multiplied into
                ``mscale``.
            beta_fast: Stored on the instance.
            beta_slow: Stored on the instance.
            truncate: Stored on the instance.
        """

        def _telechat3_mscale(ratio, weight=1):
            # No magnitude correction when the context is not extended
            # (ratio <= 1); otherwise grow logarithmically with the ratio
            # using TeleChat3's 0.07 coefficient.
            return 1.0 if ratio <= 1 else 0.07 * weight * math.log(ratio) + 1.0

        # Record the hyper-parameters on the instance before the base
        # initializer runs.
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        self.truncate = truncate

        scaled_magnitude = _telechat3_mscale(self.scaling_factor) * attn_factor
        self.mscale = float(scaled_magnitude)

        # The base initializer must run only after ``mscale`` is set,
        # otherwise ``mscale`` has no effect (presumably because the base
        # class builds its cos/sin cache during __init__ -- TODO confirm).
        RotaryEmbedding.__init__(
            self,
            head_size,
            rotary_dim,
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
        )
|
||||
@@ -206,6 +206,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
"SolarForCausalLM": ("solar", "SolarForCausalLM"),
|
||||
"TeleChatForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
|
||||
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
|
||||
"TeleChat3ForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
"TeleFLMForCausalLM": ("teleflm", "TeleFLMForCausalLM"),
|
||||
"XverseForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
"Zamba2ForCausalLM": ("zamba2", "Zamba2ForCausalLM"),
|
||||
|
||||
Reference in New Issue
Block a user