GLM-4 Update (#20736)
Signed-off-by: zRzRzRzRzRzRzR <2448370773@qq.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: Lu Fang <fanglu@fb.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: Lu Fang <fanglu@fb.com>
This commit is contained in:
@@ -576,7 +576,11 @@ def main(args: argparse.Namespace):
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
elif config.architectures[0] in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"):
|
||||
elif config.architectures[0] in (
|
||||
"DeepseekV3ForCausalLM",
|
||||
"DeepseekV2ForCausalLM",
|
||||
"Glm4MoeForCausalLM",
|
||||
):
|
||||
E = config.n_routed_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
|
||||
@@ -318,6 +318,7 @@ def main(args: argparse.Namespace):
|
||||
elif (
|
||||
config.architectures[0] == "DeepseekV3ForCausalLM"
|
||||
or config.architectures[0] == "DeepseekV2ForCausalLM"
|
||||
or config.architectures[0] == "Glm4MoeForCausalLM"
|
||||
):
|
||||
E = config.n_routed_experts
|
||||
topk = config.num_experts_per_tok
|
||||
|
||||
@@ -576,6 +576,7 @@ Specified using `--task generate`.
|
||||
| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
|
||||
| `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `THUDM/glm-4v-9b`, `THUDM/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `THUDM/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `Glm4MoeForCausalLM` | GLM-4.5 | T + I<sup>E+</sup> + V<sup>E+</sup> | `THUDM/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `GraniteSpeechForConditionalGeneration` | Granite Speech | T + A | `ibm-granite/granite-speech-3.3-8b` | ✅︎ | ✅︎ | ✅︎ |
|
||||
| `H2OVLChatModel` | H2OVL | T + I<sup>E+</sup> | `h2oai/h2ovl-mississippi-800m`, `h2oai/h2ovl-mississippi-2b`, etc. | | ✅︎ | ✅︎ |
|
||||
| `Idefics3ForConditionalGeneration` | Idefics3 | T + I | `HuggingFaceM4/Idefics3-8B-Llama3`, etc. | ✅︎ | | ✅︎ |
|
||||
|
||||
@@ -360,6 +360,9 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
trust_remote_code=True,
|
||||
hf_overrides={"architectures": ["GLM4VForCausalLM"]}), # noqa: E501
|
||||
"Glm4vForConditionalGeneration": _HfExamplesInfo("THUDM/GLM-4.1V-9B-Thinking", min_transformers_version="4.53"), # noqa: E501
|
||||
"Glm4MoeForCausalLM": _HfExamplesInfo("THUDM/GLM-4.5",
|
||||
min_transformers_version="4.54",
|
||||
is_available_online=False), # noqa: E501
|
||||
"H2OVLChatModel": _HfExamplesInfo("h2oai/h2ovl-mississippi-800m",
|
||||
extras={"2b": "h2oai/h2ovl-mississippi-2b"}, # noqa: E501
|
||||
max_transformers_version="4.48", # noqa: E501
|
||||
@@ -485,6 +488,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
||||
is_available_online=False,
|
||||
speculative_model="openbmb/MiniCPM-2B-sft-bf16",
|
||||
tokenizer="openbmb/MiniCPM-2B-sft-bf16"),
|
||||
"Glm4MoeMTPModel": _HfExamplesInfo("THUDM/GLM-4.5",
|
||||
speculative_model="THUDM/GLM-4.5",
|
||||
min_transformers_version="4.54",
|
||||
is_available_online=False),
|
||||
"MiMoMTPModel": _HfExamplesInfo("XiaomiMiMo/MiMo-7B-RL",
|
||||
trust_remote_code=True,
|
||||
speculative_model="XiaomiMiMo/MiMo-7B-RL")
|
||||
|
||||
410
tests/tool_use/test_glm4_moe_tool_parser.py
Normal file
410
tests/tool_use/test_glm4_moe_tool_parser.py
Normal file
@@ -0,0 +1,410 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: E501
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.protocol import FunctionCall, ToolCall
|
||||
from vllm.entrypoints.openai.tool_parsers import Glm4MoeModelToolParser
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
pytest.skip("skip glm4_moe parser test", allow_module_level=True)
|
||||
# Use a common model that is likely to be available
|
||||
MODEL = "THUDM/GLM-4.5"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def glm4_moe_tokenizer():
|
||||
return get_tokenizer(tokenizer_name=MODEL)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def glm4_moe_tool_parser(glm4_moe_tokenizer):
|
||||
return Glm4MoeModelToolParser(glm4_moe_tokenizer)
|
||||
|
||||
|
||||
def assert_tool_calls(actual_tool_calls: list[ToolCall],
|
||||
expected_tool_calls: list[ToolCall]):
|
||||
assert len(actual_tool_calls) == len(expected_tool_calls)
|
||||
|
||||
for actual_tool_call, expected_tool_call in zip(actual_tool_calls,
|
||||
expected_tool_calls):
|
||||
assert isinstance(actual_tool_call.id, str)
|
||||
assert len(actual_tool_call.id) > 0
|
||||
|
||||
assert actual_tool_call.type == "function"
|
||||
assert actual_tool_call.function.name == expected_tool_call.function.name
|
||||
# Compare arguments as JSON objects to handle formatting differences
|
||||
actual_args = json.loads(actual_tool_call.function.arguments)
|
||||
expected_args = json.loads(expected_tool_call.function.arguments)
|
||||
assert actual_args == expected_args
|
||||
|
||||
|
||||
def test_extract_tool_calls_no_tools(glm4_moe_tool_parser):
|
||||
model_output = "This is a test"
|
||||
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
assert not extracted_tool_calls.tools_called
|
||||
assert extracted_tool_calls.tool_calls == []
|
||||
assert extracted_tool_calls.content == model_output
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
ids=[
|
||||
"single_tool_call",
|
||||
"multiple_tool_calls",
|
||||
"tool_call_with_content_before",
|
||||
"tool_call_with_mixed_args",
|
||||
"tool_call_with_chinese_content",
|
||||
],
|
||||
argnames=["model_output", "expected_tool_calls", "expected_content"],
|
||||
argvalues=[
|
||||
(
|
||||
"""<tool_call>get_current_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Dallas</arg_value>
|
||||
<arg_key>state</arg_key>
|
||||
<arg_value>TX</arg_value>
|
||||
<arg_key>unit</arg_key>
|
||||
<arg_value>fahrenheit</arg_value>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
))
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
"""<tool_call>get_current_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Dallas</arg_value>
|
||||
<arg_key>state</arg_key>
|
||||
<arg_value>TX</arg_value>
|
||||
<arg_key>unit</arg_key>
|
||||
<arg_value>fahrenheit</arg_value>
|
||||
</tool_call>
|
||||
<tool_call>get_current_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Orlando</arg_value>
|
||||
<arg_key>state</arg_key>
|
||||
<arg_value>FL</arg_value>
|
||||
<arg_key>unit</arg_key>
|
||||
<arg_value>fahrenheit</arg_value>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Dallas",
|
||||
"state": "TX",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
)),
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Orlando",
|
||||
"state": "FL",
|
||||
"unit": "fahrenheit",
|
||||
}),
|
||||
)),
|
||||
],
|
||||
None,
|
||||
),
|
||||
(
|
||||
"""I'll help you check the weather. <tool_call>get_current_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Seattle</arg_value>
|
||||
<arg_key>state</arg_key>
|
||||
<arg_value>WA</arg_value>
|
||||
<arg_key>unit</arg_key>
|
||||
<arg_value>celsius</arg_value>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Seattle",
|
||||
"state": "WA",
|
||||
"unit": "celsius",
|
||||
}),
|
||||
))
|
||||
],
|
||||
"I'll help you check the weather.",
|
||||
),
|
||||
(
|
||||
"""<tool_call>get_current_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>New York</arg_value>
|
||||
<arg_key>state</arg_key>
|
||||
<arg_value>NY</arg_value>
|
||||
<arg_key>unit</arg_key>
|
||||
<arg_value>celsius</arg_value>
|
||||
</tool_call>""",
|
||||
[
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_current_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "New York",
|
||||
"state": "NY",
|
||||
"unit": "celsius",
|
||||
}),
|
||||
))
|
||||
],
|
||||
None,
|
||||
),
|
||||
("""I will help you get the weather.<tool_call>get_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Beijing</arg_value>
|
||||
<arg_key>date</arg_key>
|
||||
<arg_value>2025-08-01</arg_value>
|
||||
</tool_call>""", [
|
||||
ToolCall(function=FunctionCall(
|
||||
name="get_weather",
|
||||
arguments=json.dumps({
|
||||
"city": "Beijing",
|
||||
"date": "2025-08-01",
|
||||
}),
|
||||
))
|
||||
], "I will help you get the weather."),
|
||||
],
|
||||
)
|
||||
def test_extract_tool_calls(glm4_moe_tool_parser, model_output,
|
||||
expected_tool_calls, expected_content):
|
||||
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
assert extracted_tool_calls.tools_called
|
||||
assert_tool_calls(extracted_tool_calls.tool_calls, expected_tool_calls)
|
||||
|
||||
assert extracted_tool_calls.content == expected_content
|
||||
|
||||
|
||||
def test_extract_tool_calls_with_thinking_tags(glm4_moe_tool_parser):
|
||||
"""Test tool extraction when thinking tags are present."""
|
||||
model_output = """<think>I want to get the weather.</think>
|
||||
|
||||
I will help you get the weather.
|
||||
<tool_call>get_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Beijing</arg_value>
|
||||
<arg_key>date</arg_key>
|
||||
<arg_value>2025-08-01</arg_value>
|
||||
</tool_call>"""
|
||||
|
||||
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
|
||||
assert extracted_tool_calls.tools_called
|
||||
assert len(extracted_tool_calls.tool_calls) == 1
|
||||
assert extracted_tool_calls.tool_calls[0].function.name == "get_weather"
|
||||
|
||||
expected_content = """<think>I want to get the weather.</think>
|
||||
|
||||
I will help you get the weather."""
|
||||
assert extracted_tool_calls.content == expected_content
|
||||
|
||||
|
||||
def test_extract_tool_calls_malformed_xml(glm4_moe_tool_parser):
|
||||
"""Test that malformed XML is handled gracefully."""
|
||||
model_output = """<tool_call>get_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Seattle</arg_value>
|
||||
<arg_key>incomplete_arg
|
||||
<arg_value>value</arg_value>
|
||||
</tool_call>"""
|
||||
|
||||
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
|
||||
# Should handle malformed XML gracefully
|
||||
# The parser should either extract what it can or return no tool calls
|
||||
# depending on how robust we want the parsing to be
|
||||
assert isinstance(extracted_tool_calls.tools_called, bool)
|
||||
assert isinstance(extracted_tool_calls.tool_calls, list)
|
||||
|
||||
|
||||
def test_extract_tool_calls_empty_arguments(glm4_moe_tool_parser):
|
||||
"""Test tool calls with no arguments."""
|
||||
model_output = """<tool_call>get_current_time
|
||||
</tool_call>"""
|
||||
|
||||
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
|
||||
assert extracted_tool_calls.tools_called
|
||||
assert len(extracted_tool_calls.tool_calls) == 1
|
||||
assert extracted_tool_calls.tool_calls[
|
||||
0].function.name == "get_current_time"
|
||||
# Empty arguments should result in empty JSON object
|
||||
assert extracted_tool_calls.tool_calls[0].function.arguments == "{}"
|
||||
|
||||
|
||||
def test_extract_tool_calls_mixed_content(glm4_moe_tool_parser):
|
||||
"""Test extraction with mixed content and multiple tool calls."""
|
||||
model_output = """I will help you get the weather info.
|
||||
|
||||
<tool_call>get_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Beijing</arg_value>
|
||||
<arg_key>date</arg_key>
|
||||
<arg_value>2025-08-01</arg_value>
|
||||
</tool_call>
|
||||
|
||||
meaningwhile, I will also check the weather in Shanghai.
|
||||
|
||||
<tool_call>get_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Shanghai</arg_value>
|
||||
<arg_key>date</arg_key>
|
||||
<arg_value>2025-08-01</arg_value>
|
||||
</tool_call>"""
|
||||
|
||||
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
|
||||
assert extracted_tool_calls.tools_called
|
||||
assert len(extracted_tool_calls.tool_calls) == 2
|
||||
|
||||
# Check first tool call
|
||||
assert extracted_tool_calls.tool_calls[0].function.name == "get_weather"
|
||||
args1 = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
|
||||
assert args1["city"] == "Beijing"
|
||||
assert args1["date"] == "2025-08-01"
|
||||
|
||||
# Check second tool call
|
||||
assert extracted_tool_calls.tool_calls[1].function.name == "get_weather"
|
||||
args2 = json.loads(extracted_tool_calls.tool_calls[1].function.arguments)
|
||||
assert args2["city"] == "Shanghai"
|
||||
assert args2["date"] == "2025-08-01"
|
||||
|
||||
# Content should be everything before the first tool call
|
||||
assert extracted_tool_calls.content == "I will help you get the weather info."
|
||||
|
||||
|
||||
def test_streaming_basic_functionality(glm4_moe_tool_parser):
|
||||
"""Test basic streaming functionality."""
|
||||
# Reset streaming state
|
||||
glm4_moe_tool_parser.current_tool_name_sent = False
|
||||
glm4_moe_tool_parser.prev_tool_call_arr = []
|
||||
glm4_moe_tool_parser.current_tool_id = -1
|
||||
glm4_moe_tool_parser.streamed_args_for_tool = []
|
||||
|
||||
# Test with a simple tool call
|
||||
current_text = """<tool_call>get_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Beijing</arg_value>
|
||||
</tool_call>"""
|
||||
|
||||
# Mock token IDs for testing
|
||||
tool_call_start_id = glm4_moe_tool_parser.tool_call_start_token_id or 12345
|
||||
tool_call_end_id = glm4_moe_tool_parser.tool_call_end_token_id or 12346
|
||||
|
||||
result = glm4_moe_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="",
|
||||
current_text=current_text,
|
||||
delta_text="</tool_call>",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[tool_call_start_id, tool_call_end_id],
|
||||
delta_token_ids=[tool_call_end_id],
|
||||
request=None,
|
||||
)
|
||||
|
||||
# The result behavior depends on the streaming state
|
||||
# This test mainly ensures no exceptions are thrown
|
||||
assert result is None or hasattr(result, 'tool_calls') or hasattr(
|
||||
result, 'content')
|
||||
|
||||
|
||||
def test_streaming_no_tool_calls(glm4_moe_tool_parser):
|
||||
"""Test streaming when there are no tool calls."""
|
||||
current_text = "This is just regular text without any tool calls."
|
||||
|
||||
result = glm4_moe_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="This is just regular text",
|
||||
current_text=current_text,
|
||||
delta_text=" without any tool calls.",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=None,
|
||||
)
|
||||
|
||||
# Should return the delta text as content
|
||||
assert result is not None
|
||||
assert hasattr(result, 'content')
|
||||
assert result.content == " without any tool calls."
|
||||
|
||||
|
||||
def test_streaming_with_content_before_tool_calls(glm4_moe_tool_parser):
|
||||
"""Test streaming when there's content before tool calls."""
|
||||
# Reset streaming state
|
||||
glm4_moe_tool_parser.current_tool_name_sent = False
|
||||
glm4_moe_tool_parser.prev_tool_call_arr = []
|
||||
glm4_moe_tool_parser.current_tool_id = -1
|
||||
glm4_moe_tool_parser.streamed_args_for_tool = []
|
||||
|
||||
current_text = "I will help you get the weather<tool_call>"
|
||||
|
||||
result = glm4_moe_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="I will help you",
|
||||
current_text=current_text,
|
||||
delta_text="get the weather.<tool_call>",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[],
|
||||
delta_token_ids=[],
|
||||
request=None,
|
||||
)
|
||||
|
||||
# Should return content when no tool call tokens are detected
|
||||
assert result is not None
|
||||
assert hasattr(result, 'content')
|
||||
assert result.content == "get the weather.<tool_call>"
|
||||
|
||||
|
||||
def test_extract_tool_calls_special_characters(glm4_moe_tool_parser):
|
||||
"""Test tool calls with special characters and unicode."""
|
||||
model_output = """<tool_call>send_message
|
||||
<arg_key>recipient</arg_key>
|
||||
<arg_value>Amy</arg_value>
|
||||
<arg_key>message</arg_key>
|
||||
<arg_value>It is a nice day</arg_value>
|
||||
<arg_key>priority</arg_key>
|
||||
<arg_value>high</arg_value>
|
||||
</tool_call>"""
|
||||
|
||||
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
|
||||
assert extracted_tool_calls.tools_called
|
||||
assert len(extracted_tool_calls.tool_calls) == 1
|
||||
assert extracted_tool_calls.tool_calls[0].function.name == "send_message"
|
||||
|
||||
args = json.loads(extracted_tool_calls.tool_calls[0].function.arguments)
|
||||
assert args["recipient"] == "Amy"
|
||||
assert args["message"] == "It is a nice day"
|
||||
assert args["priority"] == "high"
|
||||
|
||||
|
||||
def test_extract_tool_calls_incomplete_tool_call(glm4_moe_tool_parser):
|
||||
"""Test incomplete tool calls (missing closing tag)."""
|
||||
model_output = """<tool_call>get_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Beijing</arg_value>
|
||||
<arg_key>date</arg_key>
|
||||
<arg_value>2025-08-01</arg_value>"""
|
||||
|
||||
extracted_tool_calls = glm4_moe_tool_parser.extract_tool_calls(
|
||||
model_output, request=None) # type: ignore[arg-type]
|
||||
|
||||
# Incomplete tool calls should not be extracted
|
||||
assert not extracted_tool_calls.tools_called
|
||||
assert extracted_tool_calls.tool_calls == []
|
||||
assert extracted_tool_calls.content == model_output
|
||||
@@ -1333,7 +1333,8 @@ class ModelConfig:
|
||||
self, parallel_config: "ParallelConfig") -> tuple[int, int]:
|
||||
from vllm.distributed.utils import get_pp_indices
|
||||
if (self.hf_text_config.model_type == "deepseek_mtp"
|
||||
or self.hf_config.model_type == "mimo_mtp"):
|
||||
or self.hf_config.model_type == "mimo_mtp"
|
||||
or self.hf_config.model_type == "glm4_moe_mtp"):
|
||||
total_num_hidden_layers = getattr(self.hf_text_config,
|
||||
"num_nextn_predict_layers", 0)
|
||||
else:
|
||||
@@ -2663,7 +2664,15 @@ class SpeculativeConfig:
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["MiMoMTPModel"]
|
||||
})
|
||||
return hf_config
|
||||
|
||||
if hf_config.architectures[0] == "Glm4MoeForCausalLM":
|
||||
hf_config.model_type = "glm4_moe_mtp"
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update({
|
||||
"num_hidden_layers": 0,
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["Glm4MoeMTPModel"]
|
||||
})
|
||||
|
||||
return hf_config
|
||||
|
||||
@@ -2774,7 +2783,7 @@ class SpeculativeConfig:
|
||||
"mlp_speculator"):
|
||||
self.method = "mlp_speculator"
|
||||
elif (self.draft_model_config.hf_config.model_type
|
||||
in ("deepseek_mtp", "mimo_mtp")):
|
||||
in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
|
||||
self.method = "deepseek_mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
logger.warning(
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
from .abstract_tool_parser import ToolParser, ToolParserManager
|
||||
from .deepseekv3_tool_parser import DeepSeekV3ToolParser
|
||||
from .glm4_moe_tool_parser import Glm4MoeModelToolParser
|
||||
from .granite_20b_fc_tool_parser import Granite20bFCToolParser
|
||||
from .granite_tool_parser import GraniteToolParser
|
||||
from .hermes_tool_parser import Hermes2ProToolParser
|
||||
@@ -19,10 +20,22 @@ from .pythonic_tool_parser import PythonicToolParser
|
||||
from .xlam_tool_parser import xLAMToolParser
|
||||
|
||||
__all__ = [
|
||||
"ToolParser", "ToolParserManager", "Granite20bFCToolParser",
|
||||
"GraniteToolParser", "Hermes2ProToolParser", "MistralToolParser",
|
||||
"Internlm2ToolParser", "Llama3JsonToolParser", "JambaToolParser",
|
||||
"Llama4PythonicToolParser", "PythonicToolParser", "Phi4MiniJsonToolParser",
|
||||
"DeepSeekV3ToolParser", "xLAMToolParser", "MinimaxToolParser",
|
||||
"KimiK2ToolParser", "HunyuanA13BToolParser"
|
||||
"ToolParser",
|
||||
"ToolParserManager",
|
||||
"Granite20bFCToolParser",
|
||||
"GraniteToolParser",
|
||||
"Hermes2ProToolParser",
|
||||
"MistralToolParser",
|
||||
"Internlm2ToolParser",
|
||||
"Llama3JsonToolParser",
|
||||
"JambaToolParser",
|
||||
"Llama4PythonicToolParser",
|
||||
"PythonicToolParser",
|
||||
"Phi4MiniJsonToolParser",
|
||||
"DeepSeekV3ToolParser",
|
||||
"xLAMToolParser",
|
||||
"MinimaxToolParser",
|
||||
"KimiK2ToolParser",
|
||||
"HunyuanA13BToolParser",
|
||||
"Glm4MoeModelToolParser",
|
||||
]
|
||||
|
||||
402
vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py
Normal file
402
vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py
Normal file
@@ -0,0 +1,402 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# code modified from deepseekv3_tool_parser.py
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Union
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("glm4_moe")
|
||||
class Glm4MoeModelToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
self.current_tool_name_sent = False
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.current_tool_id = -1
|
||||
self.streamed_args_for_tool: list[str] = []
|
||||
self.tool_call_start_token = "<tool_call>"
|
||||
self.tool_call_end_token = "</tool_call>"
|
||||
|
||||
self.tool_calls_start_token = self.tool_call_start_token
|
||||
|
||||
# Updated regex for the XML-based format
|
||||
self.tool_call_regex = re.compile(
|
||||
r"<tool_call>\s*"
|
||||
r"(?P<function_name>[^\n<]+)\s*" # 函数名(到换行或 <)
|
||||
r"(?P<arguments>(?:\s*<arg_key>[^<]+</arg_key>\s*"
|
||||
r"<arg_value>[^<]*</arg_value>\s*)*)\s*"
|
||||
r"</tool_call>",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
# Regex for parsing individual arguments
|
||||
self.arg_regex = re.compile(
|
||||
r"<arg_key>(?P<key>[^<]+)</arg_key>\s*<arg_value>(?P<value>[^<]*)</arg_value>",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
# Streaming regex
|
||||
self.stream_tool_call_portion_regex = re.compile(
|
||||
r"(?P<function_name>[^\n<]+)\s*"
|
||||
r"(?P<arguments>(?:\s*<arg_key>[^<]+</arg_key>\s*"
|
||||
r"<arg_value>[^<]*</arg_value>\s*)*)",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
# For streaming, we also need a regex to match just the function name
|
||||
self.stream_tool_call_name_regex = re.compile(
|
||||
r"(?P<function_name>[^\n<]+)",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ToolParser "
|
||||
"constructor during construction.")
|
||||
|
||||
self.tool_call_start_token_id = self.vocab.get(
|
||||
self.tool_call_start_token)
|
||||
self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
|
||||
|
||||
def _parse_arguments(self, args_text: str) -> str:
|
||||
"""Parse XML-based arguments into JSON format."""
|
||||
if not args_text or not args_text.strip():
|
||||
return "{}"
|
||||
|
||||
args_dict = {}
|
||||
matches = self.arg_regex.findall(args_text)
|
||||
|
||||
for key, value in matches:
|
||||
args_dict[key.strip()] = value.strip()
|
||||
|
||||
import json
|
||||
return json.dumps(args_dict, ensure_ascii=False)
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
|
||||
# sanity check; avoid unnecessary processing
|
||||
if self.tool_calls_start_token not in model_output:
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
try:
|
||||
# Find all tool calls in the output
|
||||
function_call_matches = self.tool_call_regex.findall(model_output)
|
||||
|
||||
logger.debug("function_call_matches: %s", function_call_matches)
|
||||
|
||||
if not function_call_matches:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output,
|
||||
)
|
||||
|
||||
tool_calls = []
|
||||
for i, match in enumerate(function_call_matches):
|
||||
function_name, function_args_xml = match
|
||||
function_name = function_name.strip()
|
||||
|
||||
# Parse XML arguments to JSON
|
||||
function_args_json = self._parse_arguments(function_args_xml)
|
||||
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
id=f"call_{i}",
|
||||
type='function',
|
||||
function=FunctionCall(name=function_name,
|
||||
arguments=function_args_json),
|
||||
))
|
||||
|
||||
# Extract content before the first tool call
|
||||
content = model_output[:model_output.find(self.
|
||||
tool_calls_start_token)]
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=bool(tool_calls),
|
||||
tool_calls=tool_calls,
|
||||
content=content.strip() if content.strip() else None,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error in extracting tool call from response.")
|
||||
return ExtractedToolCallInformation(tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
logger.debug("delta_text: %s", delta_text)
|
||||
logger.debug("delta_token_ids: %s", delta_token_ids)
|
||||
# check to see if we should be streaming a tool call - is there a
|
||||
if self.tool_call_start_token_id not in current_token_ids:
|
||||
logger.debug("No tool call tokens found!")
|
||||
return DeltaMessage(content=delta_text)
|
||||
delta_text = delta_text.replace(self.tool_calls_start_token,
|
||||
"").replace(self.tool_call_end_token,
|
||||
"")
|
||||
try:
|
||||
|
||||
# figure out where we are in the parsing by counting tool call
|
||||
# start & end tags
|
||||
prev_tool_start_count = previous_token_ids.count(
|
||||
self.tool_call_start_token_id)
|
||||
prev_tool_end_count = previous_token_ids.count(
|
||||
self.tool_call_end_token_id)
|
||||
cur_tool_start_count = current_token_ids.count(
|
||||
self.tool_call_start_token_id)
|
||||
cur_tool_end_count = current_token_ids.count(
|
||||
self.tool_call_end_token_id)
|
||||
tool_call_portion = None
|
||||
text_portion = None
|
||||
|
||||
# case: if we're generating text, OR rounding out a tool call
|
||||
if (cur_tool_start_count == cur_tool_end_count
|
||||
and prev_tool_end_count == cur_tool_end_count
|
||||
and self.tool_call_end_token not in delta_text):
|
||||
logger.debug("Generating text content! skipping tool parsing.")
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
if self.tool_call_end_token in delta_text:
|
||||
logger.debug("tool_call_end_token in delta_text")
|
||||
full_text = current_text + delta_text
|
||||
tool_call_portion = full_text.split(
|
||||
self.tool_call_start_token)[-1].split(
|
||||
self.tool_call_end_token)[0].rstrip()
|
||||
delta_text = delta_text.split(
|
||||
self.tool_call_end_token)[0].rstrip()
|
||||
text_portion = delta_text.split(
|
||||
self.tool_call_end_token)[-1].lstrip()
|
||||
|
||||
# case -- we're starting a new tool call
|
||||
if (cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count > prev_tool_start_count):
|
||||
if len(delta_token_ids) > 1:
|
||||
tool_call_portion = current_text.split(
|
||||
self.tool_call_start_token)[-1]
|
||||
else:
|
||||
tool_call_portion = None
|
||||
delta = None
|
||||
|
||||
text_portion = None
|
||||
|
||||
# set cursors and state appropriately
|
||||
self.current_tool_id += 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
logger.debug("Starting on a new tool %s", self.current_tool_id)
|
||||
|
||||
# case -- we're updating an existing tool call
|
||||
elif (cur_tool_start_count > cur_tool_end_count
|
||||
and cur_tool_start_count == prev_tool_start_count):
|
||||
|
||||
# get the portion of the text that's the tool call
|
||||
tool_call_portion = current_text.split(
|
||||
self.tool_call_start_token)[-1]
|
||||
text_portion = None
|
||||
|
||||
# case -- the current tool call is being closed.
|
||||
elif (cur_tool_start_count == cur_tool_end_count
|
||||
and cur_tool_end_count >= prev_tool_end_count):
|
||||
if self.prev_tool_call_arr is None or len(
|
||||
self.prev_tool_call_arr) == 0:
|
||||
logger.debug(
|
||||
"attempting to close tool call, but no tool call")
|
||||
return None
|
||||
diff = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||
"arguments")
|
||||
if diff:
|
||||
diff = (diff.encode("utf-8").decode("unicode_escape")
|
||||
if diff is str else diff)
|
||||
if '"}' not in delta_text:
|
||||
return None
|
||||
end_loc = delta_text.rindex('"}')
|
||||
diff = delta_text[:end_loc] + '"}'
|
||||
logger.debug(
|
||||
"Finishing tool and found diff that had not "
|
||||
"been streamed yet: %s",
|
||||
diff,
|
||||
)
|
||||
self.streamed_args_for_tool[self.current_tool_id] += diff
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=diff).model_dump(exclude_none=True),
|
||||
)
|
||||
])
|
||||
|
||||
# case -- otherwise we're just generating text
|
||||
else:
|
||||
text = delta_text.replace(self.tool_call_start_token, "")
|
||||
text = text.replace(self.tool_call_end_token, "")
|
||||
delta = DeltaMessage(tool_calls=[], content=text)
|
||||
return delta
|
||||
|
||||
current_tool_call = dict()
|
||||
if tool_call_portion:
|
||||
current_tool_call_matches = (
|
||||
self.stream_tool_call_portion_regex.match(
|
||||
tool_call_portion))
|
||||
if current_tool_call_matches:
|
||||
tool_id, tool_args = (current_tool_call_matches.groups())
|
||||
tool_name = tool_id.split('.')[1].split(':')[0]
|
||||
current_tool_call['id'] = tool_id
|
||||
current_tool_call["name"] = tool_name
|
||||
current_tool_call["arguments"] = tool_args
|
||||
else:
|
||||
current_tool_call_name_matches = (
|
||||
self.stream_tool_call_name_regex.match(
|
||||
tool_call_portion))
|
||||
if current_tool_call_name_matches:
|
||||
tool_id_str, = current_tool_call_name_matches.groups()
|
||||
tool_name = tool_id_str.split('.')[1].split(':')[0]
|
||||
current_tool_call['id'] = tool_id_str
|
||||
current_tool_call["name"] = tool_name
|
||||
current_tool_call["arguments"] = ""
|
||||
else:
|
||||
logger.debug("Not enough token")
|
||||
return None
|
||||
|
||||
# case - we haven't sent the tool name yet. If it's available, send
|
||||
# it. otherwise, wait until it's available.
|
||||
if not self.current_tool_name_sent:
|
||||
if current_tool_call is None:
|
||||
return None
|
||||
function_name: Union[str, None] = current_tool_call.get("name")
|
||||
tool_id = current_tool_call.get("id")
|
||||
if function_name:
|
||||
self.current_tool_name_sent = True
|
||||
return DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
type="function",
|
||||
id=tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
name=function_name).model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
else:
|
||||
return None
|
||||
|
||||
# case -- otherwise, send the tool call delta
|
||||
|
||||
# if the tool call portion is None, send the delta as text
|
||||
if tool_call_portion is None:
|
||||
# if there's text but not tool calls, send that -
|
||||
# otherwise None to skip chunk
|
||||
delta = (DeltaMessage(
|
||||
content=delta_text) if text_portion is not None else None)
|
||||
return delta
|
||||
|
||||
# now, the nitty-gritty of tool calls
|
||||
# now we have the portion to parse as tool call.
|
||||
|
||||
logger.debug("Trying to parse current tool call with ID %s",
|
||||
self.current_tool_id)
|
||||
|
||||
# if we're starting a new tool call, push an empty object in as
|
||||
# a placeholder for the arguments
|
||||
if len(self.prev_tool_call_arr) <= self.current_tool_id:
|
||||
self.prev_tool_call_arr.append({})
|
||||
|
||||
# main logic for tool parsing here - compare prev. partially-parsed
|
||||
# JSON to the current partially-parsed JSON
|
||||
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||
"arguments")
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
|
||||
logger.debug("diffing old arguments: %s", prev_arguments)
|
||||
logger.debug("against new ones: %s", cur_arguments)
|
||||
|
||||
# case -- no arguments have been created yet. skip sending a delta.
|
||||
if not cur_arguments and not prev_arguments:
|
||||
logger.debug("Skipping text %s - no arguments", delta_text)
|
||||
delta = None
|
||||
|
||||
# case -- prev arguments are defined, but non are now.
|
||||
# probably impossible, but not a fatal error - just keep going
|
||||
elif not cur_arguments and prev_arguments:
|
||||
logger.error("should be impossible to have arguments reset "
|
||||
"mid-call. skipping streaming anything.")
|
||||
delta = None
|
||||
|
||||
# case -- we now have the first info about arguments available from
|
||||
# autocompleting the JSON
|
||||
elif cur_arguments and not prev_arguments:
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=cur_arguments).model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] = cur_arguments
|
||||
|
||||
# last case -- we have an update to existing arguments.
|
||||
elif cur_arguments and prev_arguments:
|
||||
if (isinstance(delta_text, str)
|
||||
and cur_arguments != prev_arguments
|
||||
and len(cur_arguments) > len(prev_arguments)
|
||||
and cur_arguments.startswith(prev_arguments)):
|
||||
delta_arguments = cur_arguments[len(prev_arguments):]
|
||||
logger.debug("got diff %s", delta_text)
|
||||
|
||||
delta = DeltaMessage(tool_calls=[
|
||||
DeltaToolCall(
|
||||
index=self.current_tool_id,
|
||||
function=DeltaFunctionCall(
|
||||
arguments=delta_arguments).model_dump(
|
||||
exclude_none=True),
|
||||
)
|
||||
])
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id] = cur_arguments
|
||||
else:
|
||||
delta = None
|
||||
|
||||
# handle saving the state for the current tool into
|
||||
# the "prev" list for use in diffing for the next iteration
|
||||
if self.current_tool_id == len(self.prev_tool_call_arr) - 1:
|
||||
self.prev_tool_call_arr[
|
||||
self.current_tool_id] = current_tool_call
|
||||
else:
|
||||
self.prev_tool_call_arr.append(current_tool_call)
|
||||
|
||||
return delta
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error trying to handle streaming tool call.")
|
||||
return None # do not stream a delta. skip this token ID.
|
||||
685
vllm/model_executor/models/glm4_moe.py
Normal file
685
vllm/model_executor/models/glm4_moe.py
Normal file
@@ -0,0 +1,685 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright 2025 The ZhipuAI Team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GLM-4.5 model compatible with HuggingFace weights."""
|
||||
import typing
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed import (get_ep_group, get_pp_group,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class Glm4MoeMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
reduce_results: bool = True,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj")
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
prefix=f"{prefix}.down_proj")
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class Glm4MoE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
enable_eplb: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
|
||||
self.ep_group = get_ep_group().device_group
|
||||
self.ep_rank = self.ep_group.rank()
|
||||
self.ep_size = self.ep_group.size()
|
||||
self.n_routed_experts: int = config.n_routed_experts
|
||||
self.n_shared_experts: int = config.n_shared_experts
|
||||
|
||||
if config.hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
|
||||
self.gate = ReplicatedLinear(config.hidden_size,
|
||||
config.n_routed_experts,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.gate")
|
||||
|
||||
# noaux_tc is not set in transformers new config now
|
||||
self.gate.e_score_correction_bias = (nn.Parameter(
|
||||
torch.empty(config.n_routed_experts)))
|
||||
|
||||
# Load balancing settings.
|
||||
vllm_config = get_current_vllm_config()
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self.enable_eplb = enable_eplb
|
||||
|
||||
self.n_redundant_experts = parallel_config.num_redundant_experts
|
||||
self.n_logical_experts = self.n_routed_experts
|
||||
self.n_physical_experts = (self.n_logical_experts +
|
||||
self.n_redundant_experts)
|
||||
self.n_local_physical_experts = self.n_physical_experts // self.ep_size
|
||||
|
||||
self.physical_expert_start = (self.ep_rank *
|
||||
self.n_local_physical_experts)
|
||||
self.physical_expert_end = (self.physical_expert_start +
|
||||
self.n_local_physical_experts)
|
||||
|
||||
self.experts = FusedMoE(
|
||||
num_experts=config.n_routed_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.moe_intermediate_size,
|
||||
reduce_results=False,
|
||||
renormalize=config.norm_topk_prob,
|
||||
quant_config=quant_config,
|
||||
use_grouped_topk=True,
|
||||
num_expert_group=config.n_group,
|
||||
topk_group=config.topk_group,
|
||||
prefix=f"{prefix}.experts",
|
||||
scoring_func="sigmoid",
|
||||
e_score_correction_bias=self.gate.e_score_correction_bias,
|
||||
enable_eplb=self.enable_eplb,
|
||||
num_redundant_experts=self.n_redundant_experts)
|
||||
|
||||
if config.n_shared_experts is not None:
|
||||
intermediate_size = (config.moe_intermediate_size *
|
||||
config.n_shared_experts)
|
||||
self.shared_experts = Glm4MoeMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
reduce_results=self.experts.must_reduce_shared_expert_outputs(
|
||||
),
|
||||
prefix=f"{prefix}.shared_experts",
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
if self.n_shared_experts is not None:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits) * self.routed_scaling_factor
|
||||
if shared_output is not None:
|
||||
final_hidden_states = final_hidden_states + shared_output
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = (
|
||||
self.experts.maybe_all_reduce_tensor_model_parallel(
|
||||
final_hidden_states))
|
||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
|
||||
|
||||
class Glm4MoeAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 131072,
|
||||
head_dim: Optional[int] = None,
|
||||
rms_norm_eps: float = 1e-05,
|
||||
qkv_bias: bool = False,
|
||||
use_qk_norm: bool = False,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = head_dim or (hidden_size // self.total_num_heads)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.use_qk_norm = use_qk_norm
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=qkv_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj")
|
||||
|
||||
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj")
|
||||
|
||||
partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
partial_rotary_factor=partial_rotary_factor,
|
||||
)
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
if self.use_qk_norm:
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
if self.use_qk_norm:
|
||||
q = self.q_norm(q.reshape(-1, self.num_heads,
|
||||
self.head_dim)).reshape(q.shape)
|
||||
k = self.k_norm(k.reshape(-1, self.num_kv_heads,
|
||||
self.head_dim)).reshape(k.shape)
|
||||
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class Glm4MoeDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
enable_eplb: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
131072)
|
||||
# DecoderLayers are created with `make_layers` which passes the prefix
|
||||
# with the layer's index.
|
||||
layer_idx = int(prefix.split(sep='.')[-1])
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.self_attn = Glm4MoeAttention(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
head_dim=config.head_dim,
|
||||
rms_norm_eps=config.rms_norm_eps,
|
||||
qkv_bias=config.attention_bias,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
use_qk_norm=config.use_qk_norm,
|
||||
)
|
||||
|
||||
if (config.n_routed_experts is not None
|
||||
and layer_idx >= config.first_k_dense_replace):
|
||||
self.mlp = Glm4MoE(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
enable_eplb=enable_eplb,
|
||||
)
|
||||
else:
|
||||
self.mlp = Glm4MoeMLP(hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp")
|
||||
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.routed_scaling_factor = config.routed_scaling_factor
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.self_attn(positions=positions,
|
||||
hidden_states=hidden_states)
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class Glm4MoeModel(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
enable_eplb = vllm_config.parallel_config.enable_eplb
|
||||
self.config = config
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.embed_tokens")
|
||||
else:
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: Glm4MoeDecoderLayer(
|
||||
config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
enable_eplb=enable_eplb,
|
||||
),
|
||||
prefix=f"{prefix}.layers")
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
self.make_empty_intermediate_tensors = (
|
||||
make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size))
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({
|
||||
"hidden_states": hidden_states,
|
||||
"residual": residual
|
||||
})
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
def make_empty_intermediate_tensors(
|
||||
self, batch_size: int, dtype: torch.dtype,
|
||||
device: torch.device) -> IntermediateTensors:
|
||||
return IntermediateTensors({
|
||||
"hidden_states":
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
"residual":
|
||||
torch.zeros((batch_size, self.config.hidden_size),
|
||||
dtype=dtype,
|
||||
device=device),
|
||||
})
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.n_routed_experts)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
|
||||
if spec_layer is not None:
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
if weight_name not in name:
|
||||
continue
|
||||
# We have mlp.experts[0].gate_proj in the checkpoint.
|
||||
# Since we handle the experts below in expert_params_mapping,
|
||||
# we need to skip here BEFORE we update the name, otherwise
|
||||
# name will be updated to mlp.experts[0].gate_up_proj, which
|
||||
# will then be updated below in expert_params_mapping
|
||||
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||
if (("mlp.experts." in name) and name not in params_dict):
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
is_expert_weight = False
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
# Anyway, this is an expert weight and should not be
|
||||
# attempted to load as other weights later
|
||||
is_expert_weight = True
|
||||
|
||||
# Do not modify `name` since the loop may continue here
|
||||
# Instead, create a new variable
|
||||
name_mapped = name.replace(weight_name, param_name)
|
||||
|
||||
if is_pp_missing_parameter(name_mapped, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name_mapped]
|
||||
# We should ask the weight loader to return success or not
|
||||
# here since otherwise we may skip experts with other
|
||||
# available replicas.
|
||||
weight_loader = typing.cast(Callable[..., bool],
|
||||
param.weight_loader)
|
||||
success = weight_loader(param,
|
||||
loaded_weight,
|
||||
name_mapped,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id,
|
||||
return_success=True)
|
||||
if success:
|
||||
name = name_mapped
|
||||
break
|
||||
else:
|
||||
if is_expert_weight:
|
||||
# We've checked that this is an expert weight
|
||||
# However it's not mapped locally to this rank
|
||||
# So we simply skip it
|
||||
continue
|
||||
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Glm4MoeForCausalLM(nn.Module, SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
fall_back_to_pt_during_load = False
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.model = Glm4MoeModel(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
if get_pp_group().is_last_rank:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head.weight = self.model.embed_tokens.weight
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
self.expert_weights = []
|
||||
|
||||
# Set MoE hyperparameters
|
||||
self.num_moe_layers = (config.num_hidden_layers -
|
||||
config.first_k_dense_replace)
|
||||
self.num_expert_groups = config.n_group
|
||||
|
||||
self.moe_layers: list[FusedMoE] = []
|
||||
for layer in self.model.layers:
|
||||
assert isinstance(layer, Glm4MoeDecoderLayer)
|
||||
if isinstance(layer.mlp, Glm4MoE):
|
||||
self.moe_layers.append(layer.mlp.experts)
|
||||
|
||||
# Pick last one layer since the first ones may be dense layers.
|
||||
example_moe = typing.cast(
|
||||
Glm4MoE, self.model.layers[config.num_hidden_layers - 1].mlp)
|
||||
self.num_logical_experts = example_moe.n_logical_experts
|
||||
self.num_physical_experts = example_moe.n_physical_experts
|
||||
self.num_local_physical_experts = example_moe.n_local_physical_experts
|
||||
self.num_routed_experts = example_moe.n_routed_experts
|
||||
self.num_shared_experts = example_moe.n_shared_experts
|
||||
self.num_redundant_experts = example_moe.n_redundant_experts
|
||||
|
||||
def set_eplb_state(
|
||||
self,
|
||||
expert_load_view: torch.Tensor,
|
||||
logical_to_physical_map: torch.Tensor,
|
||||
logical_replica_count: torch.Tensor,
|
||||
) -> None:
|
||||
for layer_idx, layer in enumerate(self.moe_layers):
|
||||
# Register the expert weights.
|
||||
self.expert_weights.append(layer.get_expert_weights())
|
||||
layer.set_eplb_state(
|
||||
moe_layer_idx=layer_idx,
|
||||
expert_load_view=expert_load_view,
|
||||
logical_to_physical_map=logical_to_physical_map,
|
||||
logical_replica_count=logical_replica_count,
|
||||
)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(self)
|
||||
return loader.load_weights(weights)
|
||||
|
||||
|
||||
def get_spec_layer_idx_from_weight_name(config: PretrainedConfig,
|
||||
weight_name: str) -> Optional[int]:
|
||||
if hasattr(config,
|
||||
"num_nextn_predict_layers") and (config.num_nextn_predict_layers
|
||||
> 0):
|
||||
layer_idx = config.num_hidden_layers
|
||||
for i in range(config.num_nextn_predict_layers):
|
||||
if f"layers.{layer_idx+i}." in weight_name:
|
||||
return layer_idx + i
|
||||
return None
|
||||
307
vllm/model_executor/models/glm4_moe_mtp.py
Normal file
307
vllm/model_executor/models/glm4_moe_mtp.py
Normal file
@@ -0,0 +1,307 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright 2025 The ZhipuAI Team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GLM-4.5 MTP model compatible with HuggingFace weights."""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .glm4_moe import Glm4MoeDecoderLayer, get_spec_layer_idx_from_weight_name
|
||||
from .interfaces import SupportsPP
|
||||
from .utils import maybe_prefix
|
||||
|
||||
|
||||
class SharedHead(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
return self.norm(hidden_states)
|
||||
|
||||
|
||||
class Glm4MoeMultiTokenPredictorLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
prefix: str,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.eh_proj = nn.Linear(config.hidden_size * 2,
|
||||
config.hidden_size,
|
||||
bias=False)
|
||||
self.shared_head = SharedHead(config=config, quant_config=quant_config)
|
||||
self.mtp_block = Glm4MoeDecoderLayer(config=config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_index: int = 0,
|
||||
) -> torch.Tensor:
|
||||
assert inputs_embeds is not None
|
||||
# masking inputs at position 0, as not needed by MTP
|
||||
inputs_embeds[positions == 0] = 0
|
||||
inputs_embeds = self.enorm(inputs_embeds)
|
||||
previous_hidden_states = self.hnorm(previous_hidden_states)
|
||||
|
||||
hidden_states = self.eh_proj(
|
||||
torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
|
||||
|
||||
hidden_states, residual = self.mtp_block(positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=None)
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Glm4MoeMultiTokenPredictor(nn.Module):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.mtp_start_layer_idx = config.num_hidden_layers
|
||||
self.num_mtp_layers = config.num_nextn_predict_layers
|
||||
# to map the exact layer index from weights
|
||||
self.layers = torch.nn.ModuleDict({
|
||||
str(idx):
|
||||
Glm4MoeMultiTokenPredictorLayer(
|
||||
config,
|
||||
f"{prefix}.layers.{idx}",
|
||||
cache_config=vllm_config.cache_config,
|
||||
quant_config=vllm_config.quant_config,
|
||||
)
|
||||
for idx in range(self.mtp_start_layer_idx,
|
||||
self.mtp_start_layer_idx + self.num_mtp_layers)
|
||||
})
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
current_step_idx = (spec_step_idx % self.num_mtp_layers)
|
||||
return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
|
||||
input_ids,
|
||||
positions,
|
||||
previous_hidden_states,
|
||||
inputs_embeds,
|
||||
current_step_idx,
|
||||
)
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
current_step_idx = (spec_step_idx % self.num_mtp_layers)
|
||||
mtp_layer = self.layers[str(self.mtp_start_layer_idx +
|
||||
current_step_idx)]
|
||||
logits = self.logits_processor(mtp_layer.shared_head.head,
|
||||
mtp_layer.shared_head(hidden_states),
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
|
||||
class Glm4MoeMTP(nn.Module, SupportsPP):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.config = vllm_config.model_config.hf_config
|
||||
self.model = Glm4MoeMultiTokenPredictor(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "model"))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
previous_hidden_states: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
spec_step_idx: int = 0,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions,
|
||||
previous_hidden_states, inputs_embeds,
|
||||
spec_step_idx)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
spec_step_idx: int = 0,
|
||||
) -> Optional[torch.Tensor]:
|
||||
return self.model.compute_logits(hidden_states, sampling_metadata,
|
||||
spec_step_idx)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.n_routed_experts)
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
|
||||
if spec_layer is None:
|
||||
continue
|
||||
name = self._rewrite_spec_layer_name(spec_layer, name)
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
# Skip non-stacked layers and experts (experts handled below).
|
||||
if weight_name not in name:
|
||||
continue
|
||||
# We have mlp.experts[0].gate_proj in the checkpoint.
|
||||
# Since we handle the experts below in expert_params_mapping,
|
||||
# we need to skip here BEFORE we update the name, otherwise
|
||||
# name will be updated to mlp.experts[0].gate_up_proj, which
|
||||
# will then be updated below in expert_params_mapping
|
||||
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||
if (("mlp.experts." in name) and name not in params_dict):
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
for mapping in expert_params_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param,
|
||||
loaded_weight,
|
||||
name,
|
||||
shard_id=shard_id,
|
||||
expert_id=expert_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
# According to DeepSeek-V3 Technical Report, MTP modules
|
||||
# shares embedding layer. We only load the first weights.
|
||||
if (spec_layer != self.model.mtp_start_layer_idx
|
||||
and ".layers" not in name):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
|
||||
"""
|
||||
Rewrite the weight name to match the format of the original model.
|
||||
Add .mtp_block for modules in transformer layer block for spec layer
|
||||
and rename shared layer weights to be top level.
|
||||
"""
|
||||
spec_layer_weight_names = [
|
||||
"embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
|
||||
]
|
||||
shared_weight_names = ["embed_tokens"]
|
||||
spec_layer_weight = False
|
||||
shared_weight = False
|
||||
for weight_name in spec_layer_weight_names:
|
||||
if weight_name in name:
|
||||
spec_layer_weight = True
|
||||
if weight_name in shared_weight_names:
|
||||
shared_weight = True
|
||||
break
|
||||
if not spec_layer_weight:
|
||||
# treat rest weights as weights for transformer layer block
|
||||
name = name.replace(f"model.layers.{spec_layer}.",
|
||||
f"model.layers.{spec_layer}.mtp_block.")
|
||||
elif shared_weight:
|
||||
# treat shared weights as top level weights
|
||||
name = name.replace(f"model.layers.{spec_layer}.", "model.")
|
||||
return name
|
||||
@@ -67,6 +67,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
"Gemma3nForConditionalGeneration": ("gemma3n", "Gemma3nForConditionalGeneration"), # noqa: E501
|
||||
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
|
||||
"Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
|
||||
"Glm4MoeForCausalLM": ("glm4_moe", "Glm4MoeForCausalLM"),
|
||||
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
|
||||
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
|
||||
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
|
||||
@@ -244,6 +245,7 @@ _SPECULATIVE_DECODING_MODELS = {
|
||||
"EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
|
||||
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
||||
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
|
||||
"Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
|
||||
"MedusaModel": ("medusa", "Medusa"),
|
||||
# Temporarily disabled.
|
||||
# # TODO(woosuk): Re-enable this once the MLP Speculator is supported in V1.
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
from .abs_reasoning_parsers import ReasoningParser, ReasoningParserManager
|
||||
from .deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
|
||||
from .glm4_moe_reasoning_parser import Glm4MoeModelReasoningParser
|
||||
from .granite_reasoning_parser import GraniteReasoningParser
|
||||
from .hunyuan_a13b_reasoning_parser import HunyuanA13BReasoningParser
|
||||
from .qwen3_reasoning_parser import Qwen3ReasoningParser
|
||||
@@ -14,4 +15,5 @@ __all__ = [
|
||||
"GraniteReasoningParser",
|
||||
"HunyuanA13BReasoningParser",
|
||||
"Qwen3ReasoningParser",
|
||||
"Glm4MoeModelReasoningParser",
|
||||
]
|
||||
|
||||
151
vllm/reasoning/glm4_moe_reasoning_parser.py
Normal file
151
vllm/reasoning/glm4_moe_reasoning_parser.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional, Union
|
||||
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaMessage)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ReasoningParserManager.register_module("glm4_moe")
|
||||
class Glm4MoeModelReasoningParser(ReasoningParser):
|
||||
"""
|
||||
Reasoning parser for the Glm4MoeModel model.
|
||||
|
||||
The Glm4MoeModel model uses <think>...</think> tokens to denote reasoning
|
||||
text within its output. The model provides a strict switch to disable
|
||||
reasoning output via the 'enable_thinking=False' parameter. This parser
|
||||
extracts the reasoning content enclosed by <think> and </think> tokens
|
||||
from the model's output.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: PreTrainedTokenizerBase):
|
||||
super().__init__(tokenizer)
|
||||
self.think_start_token = "<think>"
|
||||
self.think_end_token = "</think>"
|
||||
|
||||
if not self.model_tokenizer:
|
||||
raise ValueError(
|
||||
"The model tokenizer must be passed to the ReasoningParser "
|
||||
"constructor during construction.")
|
||||
|
||||
self.think_start_token_id = self.vocab.get(self.think_start_token)
|
||||
self.think_end_token_id = self.vocab.get(self.think_end_token)
|
||||
if (self.think_start_token_id is None
|
||||
or self.think_end_token_id is None):
|
||||
raise RuntimeError(
|
||||
"Glm4MoeModel reasoning parser could not locate "
|
||||
"think start/end tokens in the tokenizer!")
|
||||
|
||||
def is_reasoning_end(self, input_ids: list[int]) -> bool:
|
||||
return self.think_end_token_id in input_ids
|
||||
|
||||
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
|
||||
"""
|
||||
Extract the content after the end tokens
|
||||
"""
|
||||
if self.think_end_token_id not in input_ids[:-1]:
|
||||
return []
|
||||
else:
|
||||
return input_ids[input_ids.index(self.think_end_token_id) + 1:]
|
||||
|
||||
def extract_reasoning_content_streaming(
|
||||
self,
|
||||
previous_text: str,
|
||||
current_text: str,
|
||||
delta_text: str,
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
) -> Union[DeltaMessage, None]:
|
||||
"""
|
||||
Extract reasoning content from a delta message.
|
||||
Handles streaming output where previous + delta = current.
|
||||
Uses token IDs for faster processing.
|
||||
For text <think>abc</think>xyz:
|
||||
- 'abc' goes to reasoning_content
|
||||
- 'xyz' goes to content
|
||||
"""
|
||||
# Skip single special tokens
|
||||
if len(delta_token_ids) == 1 and (delta_token_ids[0] in [
|
||||
self.think_start_token_id, self.think_end_token_id
|
||||
]):
|
||||
return None
|
||||
|
||||
if self.think_start_token_id in previous_token_ids:
|
||||
if self.think_end_token_id in delta_token_ids:
|
||||
# <think> in previous, </think> in delta,
|
||||
# extract reasoning content
|
||||
end_index = delta_text.find(self.think_end_token)
|
||||
reasoning_content = delta_text[:end_index]
|
||||
content = delta_text[end_index + len(self.think_end_token):]
|
||||
return DeltaMessage(reasoning_content=reasoning_content,
|
||||
content=content if content else None)
|
||||
elif self.think_end_token_id in previous_token_ids:
|
||||
# <think> in previous, </think> in previous,
|
||||
# reasoning content continues
|
||||
return DeltaMessage(content=delta_text)
|
||||
else:
|
||||
# <think> in previous, no </think> in previous or delta,
|
||||
# reasoning content continues
|
||||
return DeltaMessage(reasoning_content=delta_text)
|
||||
elif self.think_start_token_id in delta_token_ids:
|
||||
if self.think_end_token_id in delta_token_ids:
|
||||
# <think> in delta, </think> in delta, extract reasoning content
|
||||
start_index = delta_text.find(self.think_start_token)
|
||||
end_index = delta_text.find(self.think_end_token)
|
||||
reasoning_content = delta_text[start_index +
|
||||
len(self.think_start_token
|
||||
):end_index]
|
||||
content = delta_text[end_index + len(self.think_end_token):]
|
||||
return DeltaMessage(reasoning_content=reasoning_content,
|
||||
content=content if content else None)
|
||||
else:
|
||||
# <think> in delta, no </think> in delta,
|
||||
# reasoning content continues
|
||||
return DeltaMessage(reasoning_content=delta_text)
|
||||
else:
|
||||
# thinking is disabled, just content
|
||||
return DeltaMessage(content=delta_text)
|
||||
|
||||
def extract_reasoning_content(
|
||||
self, model_output: str, request: ChatCompletionRequest
|
||||
) -> tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Extract reasoning content from the model output.
|
||||
|
||||
For text <think>abc</think>xyz:
|
||||
- 'abc' goes to reasoning_content
|
||||
- 'xyz' goes to content
|
||||
|
||||
Returns:
|
||||
tuple[Optional[str], Optional[str]]: reasoning content and content
|
||||
"""
|
||||
|
||||
# Check if the model output contains the <think> and </think> tokens.
|
||||
if (self.think_start_token not in model_output
|
||||
or self.think_end_token not in model_output):
|
||||
return None, model_output
|
||||
# Check if the <think> is present in the model output, remove it
|
||||
# if it is present.
|
||||
model_output_parts = model_output.partition(self.think_start_token)
|
||||
model_output = model_output_parts[2] if model_output_parts[
|
||||
1] else model_output_parts[0]
|
||||
# Check if the model output contains the </think> tokens.
|
||||
# If the end token is not found, return the model output as is.
|
||||
if self.think_end_token not in model_output:
|
||||
return None, model_output
|
||||
|
||||
# Extract reasoning content from the model output.
|
||||
reasoning_content, _, content = model_output.partition(
|
||||
self.think_end_token)
|
||||
|
||||
final_content = content or None
|
||||
return reasoning_content, final_content
|
||||
@@ -77,7 +77,8 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
"mlp_speculator",
|
||||
"eagle",
|
||||
"deepseek_mtp",
|
||||
"mimo_mtp")) \
|
||||
"glm4_moe_mtp",
|
||||
"mimo_mtp")) \
|
||||
else {"return_hidden_states": True}
|
||||
|
||||
ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
|
||||
|
||||
Reference in New Issue
Block a user