[Model Runner V2] Implement Single-step Eagle 1 (#29300)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2025-11-24 09:32:27 -08:00
committed by GitHub
parent 26a465584a
commit cc313cb73d
5 changed files with 300 additions and 2 deletions

View File

@@ -45,6 +45,7 @@ from vllm.v1.worker.gpu.input_batch import (
prepare_prefill_inputs,
)
from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs
from vllm.v1.worker.gpu.spec_decode import init_speculator
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask
@@ -97,16 +98,20 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if self.use_async_scheduling:
self.input_prep_event = torch.cuda.Event()
self.structured_outputs_event = torch.cuda.Event()
self.spec_decode_event = torch.cuda.Event()
else:
self.input_prep_event = None
self.structured_outputs_event = None
self.spec_decode_event = None
if self.speculative_config is not None:
self.do_spec_decode = True
self.num_speculative_steps = self.speculative_config.num_speculative_tokens
self.speculator = init_speculator(self.vllm_config, self.device)
else:
self.do_spec_decode = False
self.num_speculative_steps = 0
self.speculator = None
self.req_states = RequestState(
max_num_reqs=self.max_num_reqs,
@@ -153,6 +158,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.vllm_config,
self.device,
)
if self.do_spec_decode:
self.speculator.load_model(self.model)
time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory
@@ -285,6 +292,33 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits = self.model.compute_logits(hidden_states)
self.sampler(logits, sampling_metadata)
@torch.inference_mode()
def _dummy_speculator_run(
self,
hidden_states: torch.Tensor,
aux_hidden_states: list[torch.Tensor] | None,
) -> None:
num_tokens = hidden_states.shape[0]
num_reqs = min(num_tokens, self.max_num_reqs)
input_batch = InputBatch.make_dummy(
num_reqs=num_reqs,
num_tokens=num_tokens,
input_buffers=self.input_buffers,
device=self.device,
)
sampling_metadata = SamplingMetadata.make_dummy(
num_reqs=num_reqs,
device=self.device,
)
num_sampled = torch.ones(num_reqs, dtype=torch.int32, device=self.device)
self.propose_draft(
input_batch=input_batch,
sampling_metadata=sampling_metadata,
last_hidden_states=hidden_states,
aux_hidden_states=aux_hidden_states,
num_sampled=num_sampled,
)
@torch.inference_mode()
def profile_run(self) -> None:
hidden_states, sample_hidden_states = self._dummy_run(
@@ -292,6 +326,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
skip_attn=True,
)
self._dummy_sampler_run(sample_hidden_states)
if self.do_spec_decode:
self._dummy_speculator_run(hidden_states, None)
torch.cuda.synchronize()
del hidden_states, sample_hidden_states
gc.collect()
@@ -727,6 +763,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.req_states.prefill_len.np[idx_mapping_np],
)
@torch.inference_mode()
def propose_draft(
self,
input_batch: InputBatch,
sampling_metadata: SamplingMetadata,
last_hidden_states: torch.Tensor,
aux_hidden_states: list[torch.Tensor] | None,
num_sampled: torch.Tensor,
) -> torch.Tensor:
num_reqs = input_batch.num_reqs
idx_mapping_np = input_batch.idx_mapping_np
with async_barrier(self.spec_decode_event):
self.input_buffers.next_prefill_tokens.np[:num_reqs] = (
self.req_states.prefill_token_ids[
idx_mapping_np,
self.req_states.num_computed_prefill_tokens[idx_mapping_np],
]
)
next_prefill_tokens = self.input_buffers.next_prefill_tokens.copy_to_gpu(
num_reqs
)
assert self.speculator is not None
draft_tokens = self.speculator.propose(
input_batch,
sampling_metadata,
last_hidden_states,
aux_hidden_states,
num_sampled,
self.req_states.last_sampled_tokens,
next_prefill_tokens,
)
self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
return draft_tokens
def get_cudagraph_and_dp_padding(
self,
scheduler_output: SchedulerOutput,
@@ -913,6 +984,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.postprocess(
input_batch, sampler_output.sampled_token_ids, num_sampled_tokens
)
if self.do_spec_decode:
_ = self.propose_draft(
input_batch,
sampling_metadata,
hidden_states,
None, # aux_hidden_states
num_sampled_tokens,
)
if self.use_async_scheduling:
return async_output