[Core] Async Scheduling X Spec Decoding Compatibility (#24799)

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
Signed-off-by: Nick Hill <nhill@redhat.com>
Signed-off-by: Benjamin Chislett <chislett.ben@gmail.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Co-authored-by: Benjamin Chislett <chislett.ben@gmail.com>
This commit is contained in:
Ronald
2025-11-18 04:16:20 +08:00
committed by GitHub
parent f8b19c0ffd
commit d8874c61a5
11 changed files with 314 additions and 98 deletions

View File

@@ -3,7 +3,7 @@
import ast
import hashlib
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, get_args
from pydantic import Field, SkipValidation, model_validator
from pydantic.dataclasses import dataclass
@@ -29,31 +29,25 @@ else:
logger = init_logger(__name__)
SpeculativeMethod = Literal[
"ngram",
"eagle",
"eagle3",
"medusa",
"mlp_speculator",
"draft_model",
"deepseek_mtp",
"ernie_mtp",
"qwen3_next_mtp",
"mimo_mtp",
"longcat_flash_mtp",
"pangu_ultra_moe_mtp",
"mtp",
"suffix",
]
MTP_MODEL_TYPES = (
MTPModelTypes = Literal[
"deepseek_mtp",
"mimo_mtp",
"glm4_moe_mtp",
"ernie_mtp",
"qwen3_next_mtp",
"longcat_flash_mtp",
"mtp",
"pangu_ultra_moe_mtp",
)
]
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
SpeculativeMethod = Literal[
"ngram",
"medusa",
"mlp_speculator",
"draft_model",
"suffix",
EagleModelTypes,
]
@config
@@ -244,7 +238,7 @@ class SpeculativeConfig:
# can not be detected, it will be considered as the "draft_model" by
# default.
if self.method in MTP_MODEL_TYPES:
if self.method in get_args(MTPModelTypes) and self.method != "mtp":
logger.warning(
"method `%s` is deprecated and replaced with mtp.", self.method
)
@@ -361,7 +355,9 @@ class SpeculativeConfig:
self.method = "medusa"
elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
self.method = "mlp_speculator"
elif self.draft_model_config.hf_config.model_type in MTP_MODEL_TYPES:
elif self.draft_model_config.hf_config.model_type in get_args(
MTPModelTypes
):
self.method = "mtp"
if self.num_speculative_tokens > 1:
logger.warning(

View File

@@ -14,13 +14,14 @@ from dataclasses import replace
from datetime import datetime
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar
from typing import TYPE_CHECKING, Any, TypeVar, get_args
import torch
from pydantic import ConfigDict, Field, model_validator
from pydantic.dataclasses import dataclass
import vllm.envs as envs
from vllm.config.speculative import EagleModelTypes
from vllm.logger import enable_trace_function_call, init_logger
from vllm.transformers_utils.runai_utils import is_runai_obj_uri
from vllm.utils import random_uuid
@@ -374,10 +375,22 @@ class VllmConfig:
"Async scheduling is not yet compatible with "
"pipeline_parallel_size > 1."
)
# Currently, async scheduling only support eagle speculative
# decoding.
if self.speculative_config is not None:
raise ValueError(
"Async scheduling is not yet compatible with speculative decoding."
)
if self.speculative_config.method not in get_args(EagleModelTypes):
raise ValueError(
"Currently, async scheduling is only supported "
"with EAGLE/MTP kind of speculative decoding"
)
if self.speculative_config.disable_padded_drafter_batch:
raise ValueError(
"async scheduling for EAGLE/MTP kind of speculative "
"decoding is enabled, but disable_padded_drafter_batch=True "
"disable_padded_drafter_batch=True is not supported for "
"this situation now. please set "
"disable_padded_drafter_batch=Fasle"
)
if not executor_supports_async_sched:
raise ValueError(
"Currently, async scheduling only supports `mp`, `uni`, or "