[Core] Async Scheduling X Spec Decoding Compatibility (#24799)

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
This commit is contained in:
Ronald
2025-11-18 04:16:20 +08:00
committed by GitHub
parent f8b19c0ffd
commit d8874c61a5
11 changed files with 314 additions and 98 deletions

View File

@@ -3,7 +3,7 @@
import ast
import hashlib
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, get_args
from pydantic import Field, SkipValidation, model_validator
from pydantic.dataclasses import dataclass
@@ -29,31 +29,25 @@ else:
logger = init_logger(__name__)
SpeculativeMethod = Literal[
"ngram",
"eagle",
"eagle3",
"medusa",
"mlp_speculator",
"draft_model",
"deepseek_mtp",
"ernie_mtp",
"qwen3_next_mtp",
"mimo_mtp",
"longcat_flash_mtp",
"pangu_ultra_moe_mtp",
"mtp",
"suffix",
]
MTP_MODEL_TYPES = (
MTPModelTypes = Literal[
"deepseek_mtp",
"mimo_mtp",
"glm4_moe_mtp",
"ernie_mtp",
"qwen3_next_mtp",
"longcat_flash_mtp",
"mtp",
"pangu_ultra_moe_mtp",
)
]
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
SpeculativeMethod = Literal[
"ngram",
"medusa",
"mlp_speculator",
"draft_model",
"suffix",
EagleModelTypes,
]
@config
@@ -244,7 +238,7 @@ class SpeculativeConfig:
# can not be detected, it will be considered as the "draft_model" by
# default.
if self.method in MTP_MODEL_TYPES:
if self.method in get_args(MTPModelTypes) and self.method != "mtp":
logger.warning(
"method `%s` is deprecated and replaced with mtp.", self.method
)
@@ -361,7 +355,9 @@ class SpeculativeConfig:
self.method = "medusa"
elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
self.method = "mlp_speculator"
elif self.draft_model_config.hf_config.model_type in MTP_MODEL_TYPES:
elif self.draft_model_config.hf_config.model_type in get_args(
MTPModelTypes
):
self.method = "mtp"
if self.num_speculative_tokens > 1:
logger.warning(