[Model Runner V2] Support Eagle3 (no CUDA graph) (#35029)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
@@ -66,6 +66,9 @@ from vllm.v1.worker.gpu.sample.output import SamplerOutput
|
||||
from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
|
||||
from vllm.v1.worker.gpu.sample.sampler import Sampler
|
||||
from vllm.v1.worker.gpu.spec_decode import init_speculator
|
||||
from vllm.v1.worker.gpu.spec_decode.eagle.eagle3_utils import (
|
||||
set_eagle3_aux_hidden_state_layers,
|
||||
)
|
||||
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
|
||||
from vllm.v1.worker.gpu.spec_decode.utils import DraftTokensHandler
|
||||
from vllm.v1.worker.gpu.states import RequestState
|
||||
@@ -133,14 +136,42 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.output_copy_stream = torch.cuda.Stream(self.device)
|
||||
self.output_copy_event = torch.cuda.Event()
|
||||
|
||||
# Pipeline parallelism.
|
||||
self.pp_size = self.parallel_config.pipeline_parallel_size
|
||||
self.use_pp = self.pp_size > 1
|
||||
if self.use_pp:
|
||||
self.is_first_pp_rank = get_pp_group().is_first_rank
|
||||
self.is_last_pp_rank = get_pp_group().is_last_rank
|
||||
else:
|
||||
self.is_first_pp_rank = True
|
||||
self.is_last_pp_rank = True
|
||||
|
||||
# Decode context parallelism.
|
||||
self.dcp_size = self.parallel_config.decode_context_parallel_size
|
||||
self.use_dcp = self.dcp_size > 1
|
||||
self.dcp_rank = get_dcp_group().rank_in_group if self.use_dcp else 0
|
||||
self.cp_interleave = self.parallel_config.cp_kv_cache_interleave_size
|
||||
|
||||
self.speculator = None
|
||||
self.use_aux_hidden_state_outputs = False
|
||||
if self.speculative_config is not None:
|
||||
self.do_spec_decode = True
|
||||
self.num_speculative_steps = self.speculative_config.num_speculative_tokens
|
||||
self.speculator = init_speculator(self.vllm_config, self.device)
|
||||
if self.is_last_pp_rank:
|
||||
self.speculator = init_speculator(self.vllm_config, self.device)
|
||||
|
||||
if self.speculative_config.method == "eagle3":
|
||||
# EAGLE3 may require auxiliary hidden states from target model outputs.
|
||||
self.use_aux_hidden_state_outputs = True
|
||||
if self.pp_size > 1:
|
||||
raise ValueError("EAGLE3 with pipeline parallel is not supported.")
|
||||
else:
|
||||
self.do_spec_decode = False
|
||||
self.num_speculative_steps = 0
|
||||
self.speculator = None
|
||||
|
||||
# Draft tokens propagation - for spec-dec + struct outputs.
|
||||
self.draft_tokens_handler = DraftTokensHandler(self.device)
|
||||
|
||||
self.req_states = RequestState(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_model_len=self.max_model_len,
|
||||
@@ -176,28 +207,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
)
|
||||
# LoRA-related workers.
|
||||
self.lora_state = LoraState(max_num_reqs=self.max_num_reqs)
|
||||
|
||||
# Draft tokens propagation - for spec-dec + struct outputs.
|
||||
self.draft_tokens_handler = DraftTokensHandler(self.device)
|
||||
|
||||
# KV Connector if configured.
|
||||
self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
|
||||
|
||||
# Pipeline parallelism.
|
||||
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
|
||||
if self.use_pp:
|
||||
self.is_first_pp_rank = get_pp_group().is_first_rank
|
||||
self.is_last_pp_rank = get_pp_group().is_last_rank
|
||||
else:
|
||||
self.is_first_pp_rank = True
|
||||
self.is_last_pp_rank = True
|
||||
|
||||
# Decode context parallelism.
|
||||
self.dcp_size = self.parallel_config.decode_context_parallel_size
|
||||
self.use_dcp = self.dcp_size > 1
|
||||
self.dcp_rank = get_dcp_group().rank_in_group if self.use_dcp else 0
|
||||
self.cp_interleave = self.parallel_config.cp_kv_cache_interleave_size
|
||||
|
||||
def update_max_model_len(self, max_model_len: int) -> None:
|
||||
self.max_model_len = max_model_len
|
||||
self.req_states.max_model_len = max_model_len
|
||||
@@ -220,7 +232,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.model = self.load_lora_model(
|
||||
self.model, self.vllm_config, self.device
|
||||
)
|
||||
if self.do_spec_decode:
|
||||
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
assert self.speculative_config is not None
|
||||
set_eagle3_aux_hidden_state_layers(self.model, self.speculative_config)
|
||||
if self.speculator is not None:
|
||||
self.speculator.load_model(self.model)
|
||||
time_after_load = time.perf_counter()
|
||||
|
||||
@@ -271,7 +287,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.kv_cache_config, self.vllm_config, self.device
|
||||
)
|
||||
check_attention_cp_compatibility(self.vllm_config)
|
||||
if self.do_spec_decode:
|
||||
if self.speculator is not None:
|
||||
# HACK(woosuk)
|
||||
self.speculator.set_attn(
|
||||
self.kv_cache_config,
|
||||
@@ -359,7 +375,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
return None, None
|
||||
|
||||
assert self.execute_model_state is not None
|
||||
hidden_states, input_batch, _ = self.execute_model_state
|
||||
hidden_states, _, input_batch, _ = self.execute_model_state
|
||||
assert hidden_states is not None # Last PP rank always has hidden_states
|
||||
sample_hidden_states = hidden_states[input_batch.logits_indices]
|
||||
return hidden_states, sample_hidden_states
|
||||
@@ -399,7 +415,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
assert sample_hidden_states is not None
|
||||
self._dummy_sampler_run(sample_hidden_states)
|
||||
|
||||
if self.do_spec_decode:
|
||||
if self.speculator is not None:
|
||||
num_tokens_across_dp = make_num_tokens_across_dp(
|
||||
self.parallel_config.data_parallel_size, self.max_num_tokens
|
||||
)
|
||||
@@ -465,7 +481,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
kv_cache_config=self.kv_cache_config,
|
||||
has_lora=self.lora_config is not None,
|
||||
)
|
||||
if self.do_spec_decode:
|
||||
if self.speculator is not None:
|
||||
self.speculator.capture_model()
|
||||
|
||||
end_time = time.perf_counter()
|
||||
@@ -964,9 +980,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# NOTE(woosuk): Here, we don't need to pass the input tensors,
|
||||
# because they are already copied to the CUDA graph input buffers.
|
||||
self.kv_connector.pre_forward(scheduler_output)
|
||||
hidden_states = self.cudagraph_manager.run_fullgraph(
|
||||
model_output = self.cudagraph_manager.run_fullgraph(
|
||||
input_batch.num_tokens_after_padding
|
||||
)
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, aux_hidden_states = model_output
|
||||
else:
|
||||
hidden_states = model_output
|
||||
aux_hidden_states = None
|
||||
else:
|
||||
# For piecewise and eager mode, just call model().
|
||||
positions = input_batch.positions
|
||||
@@ -998,12 +1019,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
slot_mapping=input_batch.slot_mappings,
|
||||
):
|
||||
self.kv_connector.pre_forward(scheduler_output)
|
||||
hidden_states = self.model(
|
||||
model_output = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
inputs_embeds=inputs_embeds,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
)
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
hidden_states, aux_hidden_states = model_output
|
||||
else:
|
||||
hidden_states = model_output
|
||||
aux_hidden_states = None
|
||||
|
||||
kv_connector_output = self.kv_connector.post_forward(scheduler_output)
|
||||
|
||||
@@ -1011,12 +1037,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Non-last PP rank: return IntermediateTensors for sending.
|
||||
assert isinstance(hidden_states, IntermediateTensors)
|
||||
hidden_states.kv_connector_output = kv_connector_output
|
||||
self.execute_model_state = (None, input_batch, kv_connector_output)
|
||||
self.execute_model_state = (None, None, input_batch, kv_connector_output)
|
||||
return hidden_states
|
||||
|
||||
assert isinstance(hidden_states, torch.Tensor)
|
||||
# Last rank (or no PP): hidden_states is a tensor for sampling.
|
||||
self.execute_model_state = (hidden_states, input_batch, kv_connector_output)
|
||||
assert isinstance(hidden_states, torch.Tensor)
|
||||
self.execute_model_state = (
|
||||
hidden_states,
|
||||
aux_hidden_states,
|
||||
input_batch,
|
||||
kv_connector_output,
|
||||
)
|
||||
return None
|
||||
|
||||
@torch.inference_mode()
|
||||
@@ -1024,7 +1055,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self, grammar_output: GrammarOutput | None
|
||||
) -> AsyncOutput | ModelRunnerOutput | None:
|
||||
assert self.execute_model_state is not None
|
||||
hidden_states, input_batch, kv_connector_output = self.execute_model_state
|
||||
hidden_states, aux_hidden_states, input_batch, kv_connector_output = (
|
||||
self.execute_model_state
|
||||
)
|
||||
self.execute_model_state = None # type: ignore
|
||||
|
||||
if not self.is_last_pp_rank:
|
||||
@@ -1084,11 +1117,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.postprocess(
|
||||
input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
|
||||
)
|
||||
if self.do_spec_decode:
|
||||
if self.speculator is not None:
|
||||
draft_tokens = self.propose_draft(
|
||||
input_batch,
|
||||
hidden_states,
|
||||
None, # aux_hidden_states
|
||||
aux_hidden_states,
|
||||
num_sampled,
|
||||
num_rejected,
|
||||
)
|
||||
|
||||
@@ -9,7 +9,7 @@ def init_speculator(vllm_config: VllmConfig, device: torch.device):
|
||||
speculative_config = vllm_config.speculative_config
|
||||
assert speculative_config is not None
|
||||
if speculative_config.use_eagle():
|
||||
from vllm.v1.worker.gpu.spec_decode.eagle import EagleSpeculator
|
||||
from vllm.v1.worker.gpu.spec_decode.eagle.speculator import EagleSpeculator
|
||||
|
||||
return EagleSpeculator(vllm_config, device)
|
||||
raise NotImplementedError(f"{speculative_config.method} is not supported yet.")
|
||||
|
||||
0
vllm/v1/worker/gpu/spec_decode/eagle/__init__.py
Normal file
0
vllm/v1/worker/gpu/spec_decode/eagle/__init__.py
Normal file
46
vllm/v1/worker/gpu/spec_decode/eagle/eagle3_utils.py
Normal file
46
vllm/v1/worker/gpu/spec_decode/eagle/eagle3_utils.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import cast
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import SpeculativeConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.interfaces import SupportsEagle3, supports_eagle3
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def set_eagle3_aux_hidden_state_layers(
    model: nn.Module,
    spec_config: SpeculativeConfig,
) -> None:
    """Configure which target-model layers emit auxiliary hidden states.

    Prefers the layer ids declared in the draft model's HF config and falls
    back to the model's own default selection. Raises ``RuntimeError`` when
    the target model does not implement the EAGLE3 protocol, or when a class
    (rather than an instance) was passed.
    """
    if not supports_eagle3(model):
        raise RuntimeError("Model does not support EAGLE3 interface")
    # supports_eagle3 may narrow to the class-level overload under mypy;
    # reject classes explicitly, then cast to the runtime protocol instance.
    if isinstance(model, type):
        raise RuntimeError("Expected model instance for EAGLE3 configuration")
    target = cast(SupportsEagle3, model)

    layer_ids = get_eagle3_aux_layers_from_config(spec_config)
    if layer_ids:
        logger.info("Using Eagle3 auxiliary layers from config: %s", layer_ids)
    else:
        layer_ids = target.get_eagle3_aux_hidden_state_layers()
        logger.info("Using Eagle3 auxiliary layers from model: %s", layer_ids)
    target.set_aux_hidden_state_layers(layer_ids)
|
||||
|
||||
|
||||
def get_eagle3_aux_layers_from_config(
    spec_config: SpeculativeConfig,
) -> tuple[int, ...] | None:
    """Read EAGLE3 auxiliary layer ids from the draft model's HF config.

    Returns a tuple of layer indices when the draft model's HF config
    declares a non-empty ``eagle_aux_hidden_state_layer_ids`` list or tuple;
    otherwise returns ``None``.
    """
    draft_config = spec_config.draft_model_config if spec_config else None
    if not draft_config:
        return None
    layer_ids = getattr(
        draft_config.hf_config, "eagle_aux_hidden_state_layer_ids", None
    )
    if isinstance(layer_ids, (list, tuple)) and layer_ids:
        return tuple(layer_ids)
    return None
|
||||
@@ -9,7 +9,6 @@ from vllm.config import VllmConfig
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.forward_context import BatchDescriptor, set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.attention.backend import AttentionMetadataBuilder
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
@@ -20,7 +19,8 @@ from vllm.v1.worker.gpu.attn_utils import (
|
||||
from vllm.v1.worker.gpu.block_table import BlockTables
|
||||
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
|
||||
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
|
||||
from vllm.v1.worker.gpu.spec_decode.eagle_cudagraph import EagleCudaGraphManager
|
||||
from vllm.v1.worker.gpu.spec_decode.eagle.cudagraph import EagleCudaGraphManager
|
||||
from vllm.v1.worker.gpu.spec_decode.eagle.utils import load_eagle_model
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -73,18 +73,7 @@ class EagleSpeculator:
|
||||
self.cudagraph_manager = EagleCudaGraphManager(vllm_config, device)
|
||||
|
||||
def load_model(self, target_model: nn.Module) -> None:
|
||||
from vllm.compilation.backends import set_model_tag
|
||||
|
||||
with set_model_tag("eagle_head"):
|
||||
self.model = get_model(
|
||||
vllm_config=self.vllm_config, model_config=self.draft_model_config
|
||||
)
|
||||
|
||||
share_lm_head = True
|
||||
if share_lm_head and hasattr(target_model, "lm_head"):
|
||||
if hasattr(self.model, "lm_head"):
|
||||
del self.model.lm_head
|
||||
self.model.lm_head = target_model.lm_head
|
||||
self.model = load_eagle_model(target_model, self.vllm_config)
|
||||
|
||||
def set_attn(
|
||||
self,
|
||||
52
vllm/v1/worker/gpu/spec_decode/eagle/utils.py
Normal file
52
vllm/v1/worker/gpu/spec_decode/eagle/utils.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
|
||||
|
||||
def load_eagle_model(target_model: nn.Module, vllm_config: VllmConfig) -> nn.Module:
    """Instantiate the EAGLE draft model and share weights with the target.

    Loads the draft model under the ``"eagle_head"`` model tag, then lets it
    reuse the target model's input embedding table and ``lm_head`` unless the
    draft checkpoint declares its own (via the optional
    ``has_own_embed_tokens`` / ``has_own_lm_head`` flags on the draft model).

    Returns the loaded (and weight-shared) draft model.
    """
    # Local import to avoid a module-level dependency on the compilation
    # backend (mirrors the original call site).
    from vllm.compilation.backends import set_model_tag

    speculative_config = vllm_config.speculative_config
    assert speculative_config is not None
    draft_model_config = speculative_config.draft_model_config
    with set_model_tag("eagle_head"):
        eagle_model = get_model(
            vllm_config=vllm_config, model_config=draft_model_config
        )

    _maybe_share_embed_tokens(eagle_model, target_model)
    _maybe_share_lm_head(eagle_model, target_model)
    return eagle_model


def _maybe_share_embed_tokens(
    eagle_model: nn.Module, target_model: nn.Module
) -> None:
    """Point the draft model's vocab embedding at the target's, unless the
    draft checkpoint ships its own embedding table."""
    # Absent flag means "share" (matches the original default of True).
    if getattr(eagle_model, "has_own_embed_tokens", False):
        return
    language_model = (
        target_model.get_language_model()
        if hasattr(target_model, "get_language_model")
        else target_model
    )
    inner_model = getattr(language_model, "model", None)
    if inner_model is None:
        return
    # Some models name the table "embed_tokens", others "embedding".
    target_embed_tokens = getattr(
        inner_model, "embed_tokens", getattr(inner_model, "embedding", None)
    )
    if target_embed_tokens is None or not hasattr(eagle_model, "model"):
        return
    # Drop the draft's own (unused) table before aliasing the target's, so
    # nn.Module re-registers the shared parameter cleanly.
    if hasattr(eagle_model.model, "embed_tokens"):
        del eagle_model.model.embed_tokens
    eagle_model.model.embed_tokens = target_embed_tokens


def _maybe_share_lm_head(eagle_model: nn.Module, target_model: nn.Module) -> None:
    """Reuse the target model's lm_head when the draft doesn't own one."""
    # Absent flag means "share" (matches the original default of True).
    if getattr(eagle_model, "has_own_lm_head", False):
        return
    if not hasattr(target_model, "lm_head"):
        return
    if hasattr(eagle_model, "lm_head"):
        del eagle_model.lm_head
    eagle_model.lm_head = target_model.lm_head
|
||||
Reference in New Issue
Block a user