[Model Runner V2] Fix draft logits not populated during cudagraph replay (#37639)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
This commit is contained in:
Giancarlo Delfin
2026-03-20 00:43:47 -07:00
committed by GitHub
parent bd8c4c0752
commit dcee9be95a
3 changed files with 17 additions and 25 deletions

View File

@@ -195,7 +195,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_speculative_steps=self.num_speculative_steps,
vocab_size=self.vocab_size,
device=self.device,
cache_draft_logits=not use_strict_rejection_sampling,
)
self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs,
@@ -446,7 +445,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
next_prefill_tokens=self.req_states.next_prefill_tokens,
temperature=self.sampler.sampling_states.temperature.gpu,
seeds=self.sampler.sampling_states.seeds.gpu,
draft_logits_out=self.req_states.draft_logits,
num_tokens_across_dp=num_tokens_across_dp,
dummy_run=True,
skip_attn_for_dummy_run=skip_attn,
@@ -815,11 +813,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else:
# Rejection sampling for spec decoding.
assert self.rejection_sampler is not None
assert self.speculator is not None
sampler_output = self.rejection_sampler(
logits,
input_batch,
# Draft logits are needed for probabilistic rejection sampling.
self.req_states.draft_logits,
self.speculator.draft_logits,
)
# Get the number of sampled and rejected tokens.
@@ -1145,7 +1144,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.req_states.next_prefill_tokens,
self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu,
self.req_states.draft_logits,
num_tokens_across_dp=num_tokens_across_dp,
)
self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens

View File

@@ -76,6 +76,17 @@ class EagleSpeculator:
device=device,
)
cache_draft_logits = self.speculative_config.rejection_sample_method != "strict"
self.draft_logits: torch.Tensor | None = None
if cache_draft_logits:
self.draft_logits = torch.zeros(
self.max_num_reqs,
self.num_speculative_steps,
self.vocab_size,
dtype=torch.float32,
device=device,
)
# currently we don't support PIECEWISE for Eagle.
cudagraph_mode = vllm_config.compilation_config.cudagraph_mode
if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL:
@@ -158,7 +169,6 @@ class EagleSpeculator:
slot_mappings: dict[str, torch.Tensor] | None,
num_tokens_across_dp: torch.Tensor | None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
draft_logits_out: torch.Tensor | None = None,
) -> None:
pos = self.input_buffers.positions[:num_reqs]
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
@@ -185,8 +195,8 @@ class EagleSpeculator:
self.seeds,
pos + 1,
apply_temperature=True,
processed_logits_out=draft_logits_out[:, step]
if draft_logits_out is not None
processed_logits_out=self.draft_logits[:, step]
if self.draft_logits is not None
else None,
)
self.draft_tokens[:num_reqs, step] = draft_tokens
@@ -241,8 +251,6 @@ class EagleSpeculator:
temperature: torch.Tensor,
# [max_num_reqs]
seeds: torch.Tensor,
# [max_num_reqs, num_speculative_steps, vocab_size]
draft_logits_out: torch.Tensor | None,
num_tokens_across_dp: torch.Tensor | None = None,
dummy_run: bool = False,
skip_attn_for_dummy_run: bool = False,
@@ -308,8 +316,8 @@ class EagleSpeculator:
self.seeds,
pos + 1,
apply_temperature=True,
processed_logits_out=draft_logits_out[:, 0]
if draft_logits_out is not None
processed_logits_out=self.draft_logits[:, 0]
if self.draft_logits is not None
else None,
)
@@ -394,7 +402,6 @@ class EagleSpeculator:
slot_mappings_updated,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=batch_desc.cg_mode,
draft_logits_out=draft_logits_out,
)
return self.draft_tokens[:num_reqs]

View File

@@ -15,7 +15,6 @@ class RequestState:
num_speculative_steps: int,
vocab_size: int,
device: torch.device,
cache_draft_logits: bool,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
@@ -71,18 +70,6 @@ class RequestState:
dtype=torch.int64,
device=device,
)
# Draft token logits.
# NOTE: This tensor maintains the "processed" logits after applying temperature,
# top-p, etc.
self.draft_logits: torch.Tensor | None = None
if cache_draft_logits:
self.draft_logits = torch.zeros(
self.max_num_reqs,
self.num_speculative_steps,
self.vocab_size,
dtype=torch.float32,
device=device,
)
self.next_prefill_tokens = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=device