[Model Runner V2] Support Eagle3 (no CUDA graph) (#35029)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
Woosuk Kwon
2026-02-21 12:55:24 -08:00
committed by GitHub
parent 965fe45935
commit a4047d4ea9
7 changed files with 169 additions and 49 deletions

View File

@@ -66,6 +66,9 @@ from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
from vllm.v1.worker.gpu.sample.sampler import Sampler
from vllm.v1.worker.gpu.spec_decode import init_speculator
from vllm.v1.worker.gpu.spec_decode.eagle.eagle3_utils import (
set_eagle3_aux_hidden_state_layers,
)
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
from vllm.v1.worker.gpu.spec_decode.utils import DraftTokensHandler
from vllm.v1.worker.gpu.states import RequestState
@@ -133,14 +136,42 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.output_copy_stream = torch.cuda.Stream(self.device)
self.output_copy_event = torch.cuda.Event()
# Pipeline parallelism.
self.pp_size = self.parallel_config.pipeline_parallel_size
self.use_pp = self.pp_size > 1
if self.use_pp:
self.is_first_pp_rank = get_pp_group().is_first_rank
self.is_last_pp_rank = get_pp_group().is_last_rank
else:
self.is_first_pp_rank = True
self.is_last_pp_rank = True
# Decode context parallelism.
self.dcp_size = self.parallel_config.decode_context_parallel_size
self.use_dcp = self.dcp_size > 1
self.dcp_rank = get_dcp_group().rank_in_group if self.use_dcp else 0
self.cp_interleave = self.parallel_config.cp_kv_cache_interleave_size
self.speculator = None
self.use_aux_hidden_state_outputs = False
if self.speculative_config is not None:
self.do_spec_decode = True
self.num_speculative_steps = self.speculative_config.num_speculative_tokens
self.speculator = init_speculator(self.vllm_config, self.device)
if self.is_last_pp_rank:
self.speculator = init_speculator(self.vllm_config, self.device)
if self.speculative_config.method == "eagle3":
# EAGLE3 may require auxiliary hidden states from target model outputs.
self.use_aux_hidden_state_outputs = True
if self.pp_size > 1:
raise ValueError("EAGLE3 with pipeline parallel is not supported.")
else:
self.do_spec_decode = False
self.num_speculative_steps = 0
self.speculator = None
# Draft tokens propagation - for spec-dec + struct outputs.
self.draft_tokens_handler = DraftTokensHandler(self.device)
self.req_states = RequestState(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
@@ -176,28 +207,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
# LoRA-related workers.
self.lora_state = LoraState(max_num_reqs=self.max_num_reqs)
# Draft tokens propagation - for spec-dec + struct outputs.
self.draft_tokens_handler = DraftTokensHandler(self.device)
# KV Connector if configured.
self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
# Pipeline parallelism.
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
if self.use_pp:
self.is_first_pp_rank = get_pp_group().is_first_rank
self.is_last_pp_rank = get_pp_group().is_last_rank
else:
self.is_first_pp_rank = True
self.is_last_pp_rank = True
# Decode context parallelism.
self.dcp_size = self.parallel_config.decode_context_parallel_size
self.use_dcp = self.dcp_size > 1
self.dcp_rank = get_dcp_group().rank_in_group if self.use_dcp else 0
self.cp_interleave = self.parallel_config.cp_kv_cache_interleave_size
def update_max_model_len(self, max_model_len: int) -> None:
self.max_model_len = max_model_len
self.req_states.max_model_len = max_model_len
@@ -220,7 +232,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.model = self.load_lora_model(
self.model, self.vllm_config, self.device
)
if self.do_spec_decode:
if self.use_aux_hidden_state_outputs:
assert self.speculative_config is not None
set_eagle3_aux_hidden_state_layers(self.model, self.speculative_config)
if self.speculator is not None:
self.speculator.load_model(self.model)
time_after_load = time.perf_counter()
@@ -271,7 +287,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.kv_cache_config, self.vllm_config, self.device
)
check_attention_cp_compatibility(self.vllm_config)
if self.do_spec_decode:
if self.speculator is not None:
# HACK(woosuk)
self.speculator.set_attn(
self.kv_cache_config,
@@ -359,7 +375,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return None, None
assert self.execute_model_state is not None
hidden_states, input_batch, _ = self.execute_model_state
hidden_states, _, input_batch, _ = self.execute_model_state
assert hidden_states is not None # Last PP rank always has hidden_states
sample_hidden_states = hidden_states[input_batch.logits_indices]
return hidden_states, sample_hidden_states
@@ -399,7 +415,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert sample_hidden_states is not None
self._dummy_sampler_run(sample_hidden_states)
if self.do_spec_decode:
if self.speculator is not None:
num_tokens_across_dp = make_num_tokens_across_dp(
self.parallel_config.data_parallel_size, self.max_num_tokens
)
@@ -465,7 +481,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
kv_cache_config=self.kv_cache_config,
has_lora=self.lora_config is not None,
)
if self.do_spec_decode:
if self.speculator is not None:
self.speculator.capture_model()
end_time = time.perf_counter()
@@ -964,9 +980,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# NOTE(woosuk): Here, we don't need to pass the input tensors,
# because they are already copied to the CUDA graph input buffers.
self.kv_connector.pre_forward(scheduler_output)
hidden_states = self.cudagraph_manager.run_fullgraph(
model_output = self.cudagraph_manager.run_fullgraph(
input_batch.num_tokens_after_padding
)
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = None
else:
# For piecewise and eager mode, just call model().
positions = input_batch.positions
@@ -998,12 +1019,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
slot_mapping=input_batch.slot_mappings,
):
self.kv_connector.pre_forward(scheduler_output)
hidden_states = self.model(
model_output = self.model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
intermediate_tensors=intermediate_tensors,
)
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
else:
hidden_states = model_output
aux_hidden_states = None
kv_connector_output = self.kv_connector.post_forward(scheduler_output)
@@ -1011,12 +1037,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Non-last PP rank: return IntermediateTensors for sending.
assert isinstance(hidden_states, IntermediateTensors)
hidden_states.kv_connector_output = kv_connector_output
self.execute_model_state = (None, input_batch, kv_connector_output)
self.execute_model_state = (None, None, input_batch, kv_connector_output)
return hidden_states
assert isinstance(hidden_states, torch.Tensor)
# Last rank (or no PP): hidden_states is a tensor for sampling.
self.execute_model_state = (hidden_states, input_batch, kv_connector_output)
assert isinstance(hidden_states, torch.Tensor)
self.execute_model_state = (
hidden_states,
aux_hidden_states,
input_batch,
kv_connector_output,
)
return None
@torch.inference_mode()
@@ -1024,7 +1055,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self, grammar_output: GrammarOutput | None
) -> AsyncOutput | ModelRunnerOutput | None:
assert self.execute_model_state is not None
hidden_states, input_batch, kv_connector_output = self.execute_model_state
hidden_states, aux_hidden_states, input_batch, kv_connector_output = (
self.execute_model_state
)
self.execute_model_state = None # type: ignore
if not self.is_last_pp_rank:
@@ -1084,11 +1117,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.postprocess(
input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
)
if self.do_spec_decode:
if self.speculator is not None:
draft_tokens = self.propose_draft(
input_batch,
hidden_states,
None, # aux_hidden_states
aux_hidden_states,
num_sampled,
num_rejected,
)

View File

@@ -9,7 +9,7 @@ def init_speculator(vllm_config: VllmConfig, device: torch.device):
speculative_config = vllm_config.speculative_config
assert speculative_config is not None
if speculative_config.use_eagle():
from vllm.v1.worker.gpu.spec_decode.eagle import EagleSpeculator
from vllm.v1.worker.gpu.spec_decode.eagle.speculator import EagleSpeculator
return EagleSpeculator(vllm_config, device)
raise NotImplementedError(f"{speculative_config.method} is not supported yet.")

View File

@@ -0,0 +1,46 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import cast
import torch.nn as nn
from vllm.config import SpeculativeConfig
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import SupportsEagle3, supports_eagle3
logger = init_logger(__name__)
def set_eagle3_aux_hidden_state_layers(
    model: nn.Module,
    spec_config: SpeculativeConfig,
) -> None:
    """Configure the target model to emit the auxiliary hidden states EAGLE3 needs.

    Layer indices come from the draft model config when it declares them;
    otherwise the target model's own defaults are used.

    Raises:
        RuntimeError: if the model does not implement the EAGLE3 interface,
            or a class object was passed instead of a model instance.
    """
    if not supports_eagle3(model):
        raise RuntimeError("Model does not support EAGLE3 interface")
    # supports_eagle3 also accepts the class object; require an instance
    # before narrowing to the runtime protocol for mypy's sake.
    if isinstance(model, type):
        raise RuntimeError("Expected model instance for EAGLE3 configuration")
    target = cast(SupportsEagle3, model)

    layers = get_eagle3_aux_layers_from_config(spec_config)
    if layers:
        logger.info("Using Eagle3 auxiliary layers from config: %s", layers)
    else:
        layers = target.get_eagle3_aux_hidden_state_layers()
        logger.info("Using Eagle3 auxiliary layers from model: %s", layers)
    target.set_aux_hidden_state_layers(layers)
def get_eagle3_aux_layers_from_config(
    spec_config: SpeculativeConfig,
) -> tuple[int, ...] | None:
    """Read explicit EAGLE3 aux-layer indices from the draft model's HF config.

    Returns:
        A tuple of layer indices, or ``None`` when the speculative/draft
        config is missing, the HF config does not declare
        ``eagle_aux_hidden_state_layer_ids``, or the declared value is
        empty or not a list/tuple.
    """
    if not spec_config or not spec_config.draft_model_config:
        return None
    hf_config = spec_config.draft_model_config.hf_config
    if not hasattr(hf_config, "eagle_aux_hidden_state_layer_ids"):
        return None
    ids = hf_config.eagle_aux_hidden_state_layer_ids
    if isinstance(ids, (list, tuple)) and ids:
        return tuple(ids)
    return None

View File

@@ -9,7 +9,6 @@ from vllm.config import VllmConfig
from vllm.config.compilation import CUDAGraphMode
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import AttentionMetadataBuilder
from vllm.v1.kv_cache_interface import KVCacheConfig
@@ -20,7 +19,8 @@ from vllm.v1.worker.gpu.attn_utils import (
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.spec_decode.eagle_cudagraph import EagleCudaGraphManager
from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager
from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model
logger = init_logger(__name__)
@@ -73,18 +73,7 @@ class EagleSpeculator:
self.cudagraph_manager = EagleCudaGraphManager(vllm_config, device)
def load_model(self, target_model: nn.Module) -> None:
from vllm.compilation.backends import set_model_tag
with set_model_tag("eagle_head"):
self.model = get_model(
vllm_config=self.vllm_config, model_config=self.draft_model_config
)
share_lm_head = True
if share_lm_head and hasattr(target_model, "lm_head"):
if hasattr(self.model, "lm_head"):
del self.model.lm_head
self.model.lm_head = target_model.lm_head
self.model = load_eagle_model(target_model, self.vllm_config)
def set_attn(
self,

View File

@@ -0,0 +1,52 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch.nn as nn
from vllm.config import VllmConfig
from vllm.model_executor.model_loader import get_model
def load_eagle_model(target_model: nn.Module, vllm_config: VllmConfig) -> nn.Module:
    """Instantiate the EAGLE draft head, sharing target weights where possible.

    The draft model reuses the target model's input embedding table and
    ``lm_head`` unless its checkpoint declares its own via
    ``has_own_embed_tokens`` / ``has_own_lm_head``.

    Returns:
        The loaded (and possibly weight-sharing) draft model.
    """
    from vllm.compilation.backends import set_model_tag

    spec_config = vllm_config.speculative_config
    assert spec_config is not None
    with set_model_tag("eagle_head"):
        draft = get_model(
            vllm_config=vllm_config, model_config=spec_config.draft_model_config
        )

    # Share the target vocab embedding when the draft checkpoint lacks one.
    if not getattr(draft, "has_own_embed_tokens", False):
        language_model = (
            target_model.get_language_model()
            if hasattr(target_model, "get_language_model")
            else target_model
        )
        inner = getattr(language_model, "model", None)
        shared_embed = None
        if inner is not None:
            if hasattr(inner, "embed_tokens"):
                shared_embed = inner.embed_tokens
            elif hasattr(inner, "embedding"):
                shared_embed = inner.embedding
        if shared_embed is not None and hasattr(draft, "model"):
            # Remove the draft's own table (if any) before binding the
            # shared one, so the attribute is freshly registered.
            if hasattr(draft.model, "embed_tokens"):
                del draft.model.embed_tokens
            draft.model.embed_tokens = shared_embed

    # Likewise, share the target lm_head unless the draft owns one.
    if not getattr(draft, "has_own_lm_head", False) and hasattr(
        target_model, "lm_head"
    ):
        if hasattr(draft, "lm_head"):
            del draft.lm_head
        draft.lm_head = target_model.lm_head
    return draft