[Model Runner V2] Simplify Eagle bookkeeping with num_rejected (#29347)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon
Date: 2025-11-24 13:54:59 -08:00
Committed by: GitHub
Parent: 3cfa63ad99
Commit: f32c7d6f54

4 changed files with 50 additions and 30 deletions

@@ -46,7 +46,10 @@ from vllm.v1.worker.gpu.input_batch import (
 )
 from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs
 from vllm.v1.worker.gpu.spec_decode import init_speculator
-from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
+from vllm.v1.worker.gpu.spec_decode.rejection_sample import (
+    get_num_rejected,
+    rejection_sample,
+)
 from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
 from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask
 from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
@@ -311,12 +314,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             device=self.device,
         )
         num_sampled = torch.ones(num_reqs, dtype=torch.int32, device=self.device)
+        num_rejected = torch.zeros(num_reqs, dtype=torch.int32, device=self.device)
         self.propose_draft(
             input_batch=input_batch,
             sampling_metadata=sampling_metadata,
             last_hidden_states=hidden_states,
             aux_hidden_states=aux_hidden_states,
             num_sampled=num_sampled,
+            num_rejected=num_rejected,
         )

     @torch.inference_mode()
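This first hunk is the no-spec/dummy path: every request trivially counts as one sampled token and zero rejected drafts. A minimal sanity sketch of the invariant the two tensors are presumably meant to satisfy (my reading of the diff, not a documented vLLM contract):

```python
import torch

# Hypothetical invariant: per request, accepted + rejected positions
# partition the logit positions that were scored for it.
num_reqs = 3
cu_num_logits = torch.arange(num_reqs + 1, dtype=torch.int32)  # one logit each
num_sampled = torch.ones(num_reqs, dtype=torch.int32)
num_rejected = torch.zeros(num_reqs, dtype=torch.int32)

num_logits = cu_num_logits[1:] - cu_num_logits[:-1]
assert torch.equal(num_sampled + num_rejected, num_logits)
```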
@@ -606,7 +611,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         input_batch: InputBatch,
         sampling_metadata: SamplingMetadata,
         grammar_output: GrammarOutput | None,
-    ) -> tuple[SamplerOutput, torch.Tensor]:
+    ) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
         sample_hidden_states = hidden_states[input_batch.logits_indices]
         logits = self.model.compute_logits(sample_hidden_states)
         if grammar_output is not None:
@@ -632,6 +637,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             # No draft tokens (common case).
             # 0 if chunked-prefilling, 1 if not.
             num_sampled = (~is_chunked_prefilling).int()
+            num_rejected = torch.zeros_like(num_sampled)
         else:
             # Draft tokens for spec decoding.
             input_ids = input_batch.input_ids[input_batch.logits_indices]
@@ -642,9 +648,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 self.num_speculative_steps,
             )
             num_sampled *= ~is_chunked_prefilling
+            num_rejected = get_num_rejected(
+                input_batch.cu_num_logits,
+                num_sampled,
+            )
             sampler_output.sampled_token_ids = sampled_tokens
             # TODO(woosuk): Support logprobs with spec decoding.
-        return sampler_output, num_sampled
+        return sampler_output, num_sampled, num_rejected

     def compute_prompt_logprobs(
         self,
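`get_num_rejected` is new in this PR; from its call site it takes the prefix sums of per-request logit counts plus the accepted-token counts produced by `rejection_sample`. A plausible reconstruction of its semantics, for orientation only (the real helper in `vllm.v1.worker.gpu.spec_decode.rejection_sample` may be a fused kernel and may treat chunked prefill differently):

```python
import torch

def get_num_rejected_sketch(
    cu_num_logits: torch.Tensor,  # [num_reqs + 1], prefix sums of logits per request
    num_sampled: torch.Tensor,    # [num_reqs], accepted tokens (0 if chunked-prefilling)
) -> torch.Tensor:
    # Each request scored (num_drafts + 1) positions; whatever rejection
    # sampling did not accept counts as rejected.
    num_logits = cu_num_logits[1:] - cu_num_logits[:-1]
    return torch.clamp(num_logits - num_sampled, min=0)
```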
@@ -750,6 +760,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         input_batch: InputBatch,
         sampled_tokens: torch.Tensor,
         num_sampled: torch.Tensor,
+        num_rejected: torch.Tensor,
     ) -> None:
         # Update the number of computed tokens.
         post_update(
@@ -758,8 +769,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.req_states.last_sampled_tokens,
             sampled_tokens,
             num_sampled,
+            num_rejected,
             input_batch.query_start_loc,
-            input_batch.cu_num_logits,
         )
         # Update the number of computed prefill tokens.
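`post_update` now receives the precomputed `num_rejected` and, as far as the argument swap in this hunk suggests, no longer has to derive rejections from `cu_num_logits` itself. A sketch of the bookkeeping this enables, with hypothetical names, assuming the scheduler initially counts every scheduled draft position as computed:

```python
import torch

def post_update_sketch(
    num_computed_tokens: torch.Tensor,  # [num_reqs], updated in place
    query_start_loc: torch.Tensor,      # [num_reqs + 1], scheduled-token prefix sums
    num_rejected: torch.Tensor,         # [num_reqs]
) -> None:
    num_scheduled = query_start_loc[1:] - query_start_loc[:-1]
    # Rejected drafts were scheduled but must not count as computed.
    num_computed_tokens += num_scheduled - num_rejected
```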
@@ -779,6 +790,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         last_hidden_states: torch.Tensor,
         aux_hidden_states: list[torch.Tensor] | None,
         num_sampled: torch.Tensor,
+        num_rejected: torch.Tensor,
     ) -> torch.Tensor:
         num_reqs = input_batch.num_reqs
         idx_mapping_np = input_batch.idx_mapping_np
@@ -800,6 +812,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             last_hidden_states,
             aux_hidden_states,
             num_sampled,
+            num_rejected,
             self.req_states.last_sampled_tokens,
             next_prefill_tokens,
         )
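Threading `num_rejected` into `propose_draft`, and from there into the speculator, means the Eagle drafter no longer has to re-derive acceptance counts. One thing this plausibly makes cheap, sketched with hypothetical names (this is not the actual speculator code), is locating each request's last accepted token, whose hidden state an Eagle-style drafter conditions on:

```python
import torch

def last_accepted_indices(
    query_start_loc: torch.Tensor,  # [num_reqs + 1]
    num_rejected: torch.Tensor,     # [num_reqs]
) -> torch.Tensor:
    # The last scheduled position of request i is query_start_loc[i + 1] - 1;
    # rejected drafts sit at the tail, so step back over them.
    return query_start_loc[1:] - 1 - num_rejected
```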
@@ -958,7 +971,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.execute_model_state = None  # type: ignore
         assert sampling_metadata is not None
-        sampler_output, num_sampled_tokens = self.sample(
+        sampler_output, num_sampled, num_rejected = self.sample(
             hidden_states, input_batch, sampling_metadata, grammar_output
         )
         prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch)
@@ -979,7 +992,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         async_output = AsyncOutput(
             model_runner_output=model_runner_output,
             sampler_output=sampler_output,
-            num_sampled_tokens=num_sampled_tokens,
+            num_sampled_tokens=num_sampled,
             copy_stream=self.output_copy_stream,
             copy_event=self.output_copy_event,
         )
@@ -990,7 +1003,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # This sequencing may slightly reduce latency as async D2H copy does not
         # need to wait for the postprocess to finish.
         self.postprocess(
-            input_batch, sampler_output.sampled_token_ids, num_sampled_tokens
+            input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
         )
         if self.do_spec_decode:
             _ = self.propose_draft(
@@ -998,7 +1011,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 sampling_metadata,
                 hidden_states,
                 None,  # aux_hidden_states
-                num_sampled_tokens,
+                num_sampled,
+                num_rejected,
             )
         if self.use_async_scheduling:
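Putting the hunks together, the execute path after this change reads roughly as follows (paraphrased from the diff, not a runnable excerpt; surrounding code elided):

```python
sampler_output, num_sampled, num_rejected = self.sample(
    hidden_states, input_batch, sampling_metadata, grammar_output
)
self.postprocess(
    input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
)
if self.do_spec_decode:
    self.propose_draft(
        input_batch,
        sampling_metadata,
        hidden_states,
        None,  # aux_hidden_states
        num_sampled,
        num_rejected,
    )
```

The single `(num_sampled, num_rejected)` pair computed once in `sample` is reused by both downstream consumers, which is the bookkeeping simplification the commit title refers to.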