[Core] Async Scheduling X Spec Decoding Compatibility (#24799)
Signed-off-by: Ronald1995 <ronaldautomobile@163.com> Signed-off-by: Nick Hill <nhill@redhat.com> Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
|
||||
import ast
|
||||
import hashlib
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
from typing import TYPE_CHECKING, Any, Literal, get_args
|
||||
|
||||
from pydantic import Field, SkipValidation, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
@@ -29,31 +29,25 @@ else:
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
SpeculativeMethod = Literal[
|
||||
"ngram",
|
||||
"eagle",
|
||||
"eagle3",
|
||||
"medusa",
|
||||
"mlp_speculator",
|
||||
"draft_model",
|
||||
"deepseek_mtp",
|
||||
"ernie_mtp",
|
||||
"qwen3_next_mtp",
|
||||
"mimo_mtp",
|
||||
"longcat_flash_mtp",
|
||||
"pangu_ultra_moe_mtp",
|
||||
"mtp",
|
||||
"suffix",
|
||||
]
|
||||
MTP_MODEL_TYPES = (
|
||||
MTPModelTypes = Literal[
|
||||
"deepseek_mtp",
|
||||
"mimo_mtp",
|
||||
"glm4_moe_mtp",
|
||||
"ernie_mtp",
|
||||
"qwen3_next_mtp",
|
||||
"longcat_flash_mtp",
|
||||
"mtp",
|
||||
"pangu_ultra_moe_mtp",
|
||||
)
|
||||
]
|
||||
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
|
||||
SpeculativeMethod = Literal[
|
||||
"ngram",
|
||||
"medusa",
|
||||
"mlp_speculator",
|
||||
"draft_model",
|
||||
"suffix",
|
||||
EagleModelTypes,
|
||||
]
|
||||
|
||||
|
||||
@config
|
||||
@@ -244,7 +238,7 @@ class SpeculativeConfig:
|
||||
# can not be detected, it will be considered as the "draft_model" by
|
||||
# default.
|
||||
|
||||
if self.method in MTP_MODEL_TYPES:
|
||||
if self.method in get_args(MTPModelTypes) and self.method != "mtp":
|
||||
logger.warning(
|
||||
"method `%s` is deprecated and replaced with mtp.", self.method
|
||||
)
|
||||
@@ -361,7 +355,9 @@ class SpeculativeConfig:
|
||||
self.method = "medusa"
|
||||
elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
|
||||
self.method = "mlp_speculator"
|
||||
elif self.draft_model_config.hf_config.model_type in MTP_MODEL_TYPES:
|
||||
elif self.draft_model_config.hf_config.model_type in get_args(
|
||||
MTPModelTypes
|
||||
):
|
||||
self.method = "mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
logger.warning(
|
||||
|
||||
@@ -14,13 +14,14 @@ from dataclasses import replace
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, TypeVar, get_args
|
||||
|
||||
import torch
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config.speculative import EagleModelTypes
|
||||
from vllm.logger import enable_trace_function_call, init_logger
|
||||
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
|
||||
from vllm.utils import random_uuid
|
||||
@@ -374,10 +375,22 @@ class VllmConfig:
|
||||
"Async scheduling is not yet compatible with "
|
||||
"pipeline_parallel_size > 1."
|
||||
)
|
||||
# Currently, async scheduling only support eagle speculative
|
||||
# decoding.
|
||||
if self.speculative_config is not None:
|
||||
raise ValueError(
|
||||
"Async scheduling is not yet compatible with speculative decoding."
|
||||
)
|
||||
if self.speculative_config.method not in get_args(EagleModelTypes):
|
||||
raise ValueError(
|
||||
"Currently, async scheduling is only supported "
|
||||
"with EAGLE/MTP kind of speculative decoding"
|
||||
)
|
||||
if self.speculative_config.disable_padded_drafter_batch:
|
||||
raise ValueError(
|
||||
"async scheduling for EAGLE/MTP kind of speculative "
|
||||
"decoding is enabled, but disable_padded_drafter_batch=True "
|
||||
"disable_padded_drafter_batch=True is not supported for "
|
||||
"this situation now. please set "
|
||||
"disable_padded_drafter_batch=Fasle"
|
||||
)
|
||||
if not executor_supports_async_sched:
|
||||
raise ValueError(
|
||||
"Currently, async scheduling only supports `mp`, `uni`, or "
|
||||
|
||||
Reference in New Issue
Block a user