[Model Runner V2] Fix draft logits not populated during cudagraph replay (#37639)
Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
This commit is contained in:
@@ -195,7 +195,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_speculative_steps=self.num_speculative_steps,
|
||||
vocab_size=self.vocab_size,
|
||||
device=self.device,
|
||||
cache_draft_logits=not use_strict_rejection_sampling,
|
||||
)
|
||||
self.input_buffers = InputBuffers(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
@@ -446,7 +445,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
next_prefill_tokens=self.req_states.next_prefill_tokens,
|
||||
temperature=self.sampler.sampling_states.temperature.gpu,
|
||||
seeds=self.sampler.sampling_states.seeds.gpu,
|
||||
draft_logits_out=self.req_states.draft_logits,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
dummy_run=True,
|
||||
skip_attn_for_dummy_run=skip_attn,
|
||||
@@ -815,11 +813,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
else:
|
||||
# Rejection sampling for spec decoding.
|
||||
assert self.rejection_sampler is not None
|
||||
assert self.speculator is not None
|
||||
sampler_output = self.rejection_sampler(
|
||||
logits,
|
||||
input_batch,
|
||||
# Draft logits are needed for probabilistic rejection sampling.
|
||||
self.req_states.draft_logits,
|
||||
self.speculator.draft_logits,
|
||||
)
|
||||
|
||||
# Get the number of sampled and rejected tokens.
|
||||
@@ -1145,7 +1144,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.req_states.next_prefill_tokens,
|
||||
self.sampler.sampling_states.temperature.gpu,
|
||||
self.sampler.sampling_states.seeds.gpu,
|
||||
self.req_states.draft_logits,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
)
|
||||
self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
|
||||
|
||||
@@ -76,6 +76,17 @@ class EagleSpeculator:
|
||||
device=device,
|
||||
)
|
||||
|
||||
cache_draft_logits = self.speculative_config.rejection_sample_method != "strict"
|
||||
self.draft_logits: torch.Tensor | None = None
|
||||
if cache_draft_logits:
|
||||
self.draft_logits = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
self.num_speculative_steps,
|
||||
self.vocab_size,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# currently we don't support PIECEWISE for Eagle.
|
||||
cudagraph_mode = vllm_config.compilation_config.cudagraph_mode
|
||||
if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL:
|
||||
@@ -158,7 +169,6 @@ class EagleSpeculator:
|
||||
slot_mappings: dict[str, torch.Tensor] | None,
|
||||
num_tokens_across_dp: torch.Tensor | None,
|
||||
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
draft_logits_out: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
pos = self.input_buffers.positions[:num_reqs]
|
||||
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
|
||||
@@ -185,8 +195,8 @@ class EagleSpeculator:
|
||||
self.seeds,
|
||||
pos + 1,
|
||||
apply_temperature=True,
|
||||
processed_logits_out=draft_logits_out[:, step]
|
||||
if draft_logits_out is not None
|
||||
processed_logits_out=self.draft_logits[:, step]
|
||||
if self.draft_logits is not None
|
||||
else None,
|
||||
)
|
||||
self.draft_tokens[:num_reqs, step] = draft_tokens
|
||||
@@ -241,8 +251,6 @@ class EagleSpeculator:
|
||||
temperature: torch.Tensor,
|
||||
# [max_num_reqs]
|
||||
seeds: torch.Tensor,
|
||||
# [max_num_reqs, num_speculative_steps, vocab_size]
|
||||
draft_logits_out: torch.Tensor | None,
|
||||
num_tokens_across_dp: torch.Tensor | None = None,
|
||||
dummy_run: bool = False,
|
||||
skip_attn_for_dummy_run: bool = False,
|
||||
@@ -308,8 +316,8 @@ class EagleSpeculator:
|
||||
self.seeds,
|
||||
pos + 1,
|
||||
apply_temperature=True,
|
||||
processed_logits_out=draft_logits_out[:, 0]
|
||||
if draft_logits_out is not None
|
||||
processed_logits_out=self.draft_logits[:, 0]
|
||||
if self.draft_logits is not None
|
||||
else None,
|
||||
)
|
||||
|
||||
@@ -394,7 +402,6 @@ class EagleSpeculator:
|
||||
slot_mappings_updated,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
cudagraph_runtime_mode=batch_desc.cg_mode,
|
||||
draft_logits_out=draft_logits_out,
|
||||
)
|
||||
return self.draft_tokens[:num_reqs]
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ class RequestState:
|
||||
num_speculative_steps: int,
|
||||
vocab_size: int,
|
||||
device: torch.device,
|
||||
cache_draft_logits: bool,
|
||||
):
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_model_len = max_model_len
|
||||
@@ -71,18 +70,6 @@ class RequestState:
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
# Draft token logits.
|
||||
# NOTE: This tensor maintains the "processed" logits after applying temperature,
|
||||
# top-p, etc.
|
||||
self.draft_logits: torch.Tensor | None = None
|
||||
if cache_draft_logits:
|
||||
self.draft_logits = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
self.num_speculative_steps,
|
||||
self.vocab_size,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
self.next_prefill_tokens = torch.zeros(
|
||||
self.max_num_reqs, dtype=torch.int32, device=device
|
||||
|
||||
Reference in New Issue
Block a user