[Spec Decode] Integrate Suffix Decoding from Arctic Inference (#25784)

Co-authored-by: Aurick Qiao <aurick.qiao@snowflake.com>
This commit is contained in:
Aurick Qiao
2025-11-03 09:23:31 -08:00
committed by GitHub
parent 4bc400f47e
commit 2c19d96777
8 changed files with 304 additions and 11 deletions

View File

@@ -12,7 +12,7 @@ from typing_extensions import Self
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.utils.import_utils import LazyLoader
from vllm.utils.import_utils import LazyLoader, has_arctic_inference
if TYPE_CHECKING:
from transformers import PretrainedConfig
@@ -42,6 +42,7 @@ SpeculativeMethod = Literal[
"mimo_mtp",
"longcat_flash_mtp",
"mtp",
"suffix",
]
MTP_MODEL_TYPES = (
"deepseek_mtp",
@@ -129,6 +130,27 @@ class SpeculativeConfig:
draft_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore
"""The parallel configuration for the draft model initialized internal."""
# Suffix decoding configuration
# The four fields below are range-checked by _validate_suffix_decoding(),
# which is invoked when method == "suffix".
# Must be >= 1. Also used as the default for num_speculative_tokens when that
# value is left unset.
suffix_decoding_max_tree_depth: int = 24
"""The maximum depth of the suffix decoding global and prompt trees. The
tree depth limits the sum of the prefix match and speculation lengths."""
# Must be >= 0.
suffix_decoding_max_cached_requests: int = 10000
"""The maximum number of requests to cache in the global suffix tree. If
exceeded, will trigger eviction in FIFO order. If set to 0, the global
suffix tree is disabled and past responses are not cached (prompt trees
are still used)."""
# Must be >= 0.
suffix_decoding_max_spec_factor: float = 1.0
"""The maximum spec factor for suffix decoding. The spec factor controls
speculation lengths based on the prefix match length: max_spec_tokens =
max_spec_factor * prefix_match_length."""
# Must lie in the closed interval [0, 1].
suffix_decoding_min_token_prob: float = 0.1
"""The minimum token probability for suffix decoding. Will only speculate
tokens with estimated probability (based on frequency counts) greater than
or equal to this value."""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
@@ -235,6 +257,8 @@ class SpeculativeConfig:
self.quantization = self.target_model_config.quantization
elif self.method in ("ngram", "[ngram]"):
self.model = "ngram"
elif self.method == "suffix":
self.model = "suffix"
else:
raise ValueError(
"num_speculative_tokens was provided but without speculative model."
@@ -282,6 +306,8 @@ class SpeculativeConfig:
# draft related config as None here.
self.draft_model_config = self.target_model_config
self.draft_parallel_config = self.target_parallel_config
elif self.method == "suffix":
self._validate_suffix_decoding()
else:
self.prompt_lookup_max = 0
self.prompt_lookup_min = 0
@@ -430,6 +456,42 @@ class SpeculativeConfig:
)
return self
def _validate_suffix_decoding(self) -> None:
    """Check the prerequisites and settings for suffix decoding.

    Confirms that Arctic Inference is available, range-checks every
    ``suffix_decoding_*`` field, and finally, if ``num_speculative_tokens``
    is unset, defaults it to ``suffix_decoding_max_tree_depth``.

    The range checks run before the defaulting step so that an invalid
    tree depth is rejected without first being copied into
    ``num_speculative_tokens`` and without logging a "Defaulted ..."
    warning for a configuration that is about to be refused.

    Raises:
        ImportError: If ``has_arctic_inference()`` reports it is missing.
        ValueError: If any suffix decoding setting is out of range.
    """
    if not has_arctic_inference():
        raise ImportError(
            "Arctic Inference is required for suffix decoding. "
            "Install via `pip install arctic-inference==0.1.0`."
        )

    # Validate values before mutating any state on the config.
    if self.suffix_decoding_max_tree_depth < 1:
        raise ValueError(
            f"suffix_decoding_max_tree_depth="
            f"{self.suffix_decoding_max_tree_depth} must be >= 1"
        )
    if self.suffix_decoding_max_cached_requests < 0:
        raise ValueError(
            f"suffix_decoding_max_cached_requests="
            f"{self.suffix_decoding_max_cached_requests} must be >= 0"
        )
    if self.suffix_decoding_max_spec_factor < 0:
        raise ValueError(
            f"suffix_decoding_max_spec_factor="
            f"{self.suffix_decoding_max_spec_factor} must be >= 0"
        )
    if not 0 <= self.suffix_decoding_min_token_prob <= 1:
        raise ValueError(
            f"suffix_decoding_min_token_prob="
            f"{self.suffix_decoding_min_token_prob} must be in [0, 1]"
        )

    if self.num_speculative_tokens is None:
        # Suffix decoding decides the actual number of speculative tokens
        # dynamically and treats num_speculative_tokens as a maximum limit.
        self.num_speculative_tokens = self.suffix_decoding_max_tree_depth
        logger.warning(
            "Defaulted num_speculative_tokens to %s for suffix decoding.",
            self.num_speculative_tokens,
        )
@staticmethod
def _maybe_override_draft_max_model_len(
speculative_max_model_len: int | None,
@@ -582,6 +644,6 @@ class SpeculativeConfig:
def __repr__(self) -> str:
    """Return a compact summary of the speculative decoding setup.

    The summary lists the method, the draft model and the number of
    speculative tokens. For the ``ngram`` and ``suffix`` methods the model
    is reported as ``None`` and ``draft_model_config`` is never read.
    """
    method = self.method
    # Single assignment only: an earlier ngram-only variant of this line
    # dereferenced draft_model_config even when method == "suffix".
    model = None if method in ("ngram", "suffix") else self.draft_model_config.model
    num_spec_tokens = self.num_speculative_tokens
    return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"