[Core] Async Scheduling X Spec Decoding Compatibility (#24799)
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
|
||||
import ast
|
||||
import hashlib
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
from typing import TYPE_CHECKING, Any, Literal, get_args
|
||||
|
||||
from pydantic import Field, SkipValidation, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
@@ -29,31 +29,25 @@ else:
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
SpeculativeMethod = Literal[
|
||||
"ngram",
|
||||
"eagle",
|
||||
"eagle3",
|
||||
"medusa",
|
||||
"mlp_speculator",
|
||||
"draft_model",
|
||||
"deepseek_mtp",
|
||||
"ernie_mtp",
|
||||
"qwen3_next_mtp",
|
||||
"mimo_mtp",
|
||||
"longcat_flash_mtp",
|
||||
"pangu_ultra_moe_mtp",
|
||||
"mtp",
|
||||
"suffix",
|
||||
]
|
||||
MTP_MODEL_TYPES = (
|
||||
MTPModelTypes = Literal[
|
||||
"deepseek_mtp",
|
||||
"mimo_mtp",
|
||||
"glm4_moe_mtp",
|
||||
"ernie_mtp",
|
||||
"qwen3_next_mtp",
|
||||
"longcat_flash_mtp",
|
||||
"mtp",
|
||||
"pangu_ultra_moe_mtp",
|
||||
)
|
||||
]
|
||||
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
|
||||
SpeculativeMethod = Literal[
|
||||
"ngram",
|
||||
"medusa",
|
||||
"mlp_speculator",
|
||||
"draft_model",
|
||||
"suffix",
|
||||
EagleModelTypes,
|
||||
]
|
||||
|
||||
|
||||
@config
|
||||
@@ -244,7 +238,7 @@ class SpeculativeConfig:
|
||||
# can not be detected, it will be considered as the "draft_model" by
|
||||
# default.
|
||||
|
||||
if self.method in MTP_MODEL_TYPES:
|
||||
if self.method in get_args(MTPModelTypes) and self.method != "mtp":
|
||||
logger.warning(
|
||||
"method `%s` is deprecated and replaced with mtp.", self.method
|
||||
)
|
||||
@@ -361,7 +355,9 @@ class SpeculativeConfig:
|
||||
self.method = "medusa"
|
||||
elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
|
||||
self.method = "mlp_speculator"
|
||||
elif self.draft_model_config.hf_config.model_type in MTP_MODEL_TYPES:
|
||||
elif self.draft_model_config.hf_config.model_type in get_args(
|
||||
MTPModelTypes
|
||||
):
|
||||
self.method = "mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
logger.warning(
|
||||
|
||||
Reference in New Issue
Block a user