[Spec Decode] Make propose_draft_token_ids non-blocking for lower TTFT (#23041)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Author: Woosuk Kwon
Date: 2025-08-18 17:20:38 -07:00
Committed by: GitHub
Parent: 0dd3f4f5ab
Commit: c9b38be8aa

13 changed files with 100 additions and 64 deletions

vllm/v1/worker/gpu_model_runner.py

@@ -65,8 +65,8 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
                                         FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec, MambaSpec,
                                         SlidingWindowSpec)
-from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
-                             ModelRunnerOutput)
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
+                             LogprobsTensors, ModelRunnerOutput)
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
 from vllm.v1.sample.metadata import SamplingMetadata
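
The new import pulls in DraftTokenIds, the container that now carries draft tokens out of the runner as a separate message from the sampled tokens. Its real definition lives in vllm/v1/outputs.py and is not shown in this hunk; inferring only from how it is constructed in take_draft_token_ids() below, a minimal sketch of its shape would be:

    # Illustrative sketch only -- the real class is defined in vllm/v1/outputs.py.
    # Field names and types are inferred from the constructor call below.
    from dataclasses import dataclass

    @dataclass
    class DraftTokenIds:
        req_ids: list[str]                # one entry per request in the batch
        draft_token_ids: list[list[int]]  # proposed draft tokens, per request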
@@ -348,6 +348,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.reorder_batch_threshold: Optional[int] = None
 
+        # Cached outputs.
+        self._draft_token_ids: Optional[Union[list[list[int]],
+                                              torch.Tensor]] = None
+
     def _init_model_kwargs(self, num_tokens: int):
         model_kwargs = dict[str, Any]()
         num_reqs = self.input_batch.num_reqs
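
The cache holds either a plain Python list or a raw torch.Tensor, depending on which proposer produced the drafts. Deferring the conversion is the point of the change: .tolist() on a CUDA tensor performs a blocking device-to-host copy, stalling the host until all GPU work queued on the stream drains, whereas keeping the tensor as-is costs nothing. A minimal timing sketch of that difference (illustrative only, assumes a CUDA device; not vLLM code):

    # Why caching the raw tensor is cheap while .tolist() is not:
    # .tolist() blocks the host until all queued GPU work has finished.
    import time

    import torch

    if torch.cuda.is_available():
        ids = torch.randint(0, 32000, (256, 8), device="cuda")
        x = torch.randn(2048, 2048, device="cuda")
        for _ in range(100):
            x = x @ x  # queue plenty of asynchronous GPU work

        t0 = time.perf_counter()
        cached = ids  # new behavior: keep the tensor, no sync
        t_cache = time.perf_counter() - t0

        t0 = time.perf_counter()
        as_list = ids.tolist()  # old behavior: blocks until the queue drains
        t_list = time.perf_counter() - t0

        print(f"cache: {t_cache * 1e6:.0f}us  tolist: {t_list * 1e6:.0f}us")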
@@ -1493,7 +1497,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             req_ids=self.input_batch.req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
             sampled_token_ids=[],
-            spec_token_ids=None,
             logprobs=None,
             prompt_logprobs_dict={},
             pooler_output=pooler_output,
@@ -1764,12 +1767,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             req_state = self.requests[req_id]
             req_state.output_token_ids.extend(sampled_ids)
 
-        if not self.speculative_config:
-            # Speculative decoding is not enabled.
-            spec_token_ids = None
-        else:
+        if self.speculative_config:
             assert spec_decode_common_attn_metadata is not None
-            spec_token_ids = self.propose_draft_token_ids(
+            self._draft_token_ids = self.propose_draft_token_ids(
                 scheduler_output,
                 valid_sampled_token_ids,
                 sampling_metadata,
@@ -1786,7 +1786,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             req_ids=self.input_batch.req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
             sampled_token_ids=valid_sampled_token_ids,
-            spec_token_ids=spec_token_ids,
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
             pooler_output=[],
@@ -1794,6 +1793,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             num_nans_in_logits=num_nans_in_logits,
         )
 
+    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
+        if self._draft_token_ids is None:
+            return None
+        req_ids = self.input_batch.req_ids
+        if isinstance(self._draft_token_ids, torch.Tensor):
+            draft_token_ids = self._draft_token_ids.tolist()
+        else:
+            draft_token_ids = self._draft_token_ids
+        self._draft_token_ids = None
+        return DraftTokenIds(req_ids, draft_token_ids)
+
     def propose_draft_token_ids(
         self,
         scheduler_output: "SchedulerOutput",
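
With the draft IDs held back in the cache, a caller can ship the sampled tokens first and collect the drafts afterwards. The real call sites live in the v1 worker/executor layers (presumably among the other files in this commit); the loop below is a simplified, hypothetical stand-in that shows the intended ordering:

    # Hypothetical caller sketch, not vLLM code; send_to_scheduler is a
    # stand-in for the real transport back to the scheduler.
    def send_to_scheduler(msg) -> None:
        print("sent:", type(msg).__name__)

    def run_step(runner, scheduler_output):
        # ModelRunnerOutput no longer carries spec_token_ids, so it can be
        # returned as soon as sampling finishes -- the TTFT-critical path.
        output = runner.execute_model(scheduler_output)
        send_to_scheduler(output)

        # Only now pay for the tensor -> list conversion (if any) and send
        # the drafts as a separate message.
        draft = runner.take_draft_token_ids()
        if draft is not None:
            send_to_scheduler(draft)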
@@ -1804,11 +1814,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         aux_hidden_states: Optional[torch.Tensor],
         spec_decode_metadata: Optional[SpecDecodeMetadata],
         common_attn_metadata: CommonAttentionMetadata,
-    ) -> list[list[int]]:
+    ) -> Union[list[list[int]], torch.Tensor]:
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         if self.speculative_config.method == "ngram":
             assert isinstance(self.drafter, NgramProposer)
-            spec_token_ids = self.propose_ngram_draft_token_ids(
+            draft_token_ids = self.propose_ngram_draft_token_ids(
                 sampled_token_ids)
         elif self.speculative_config.method == "medusa":
             assert isinstance(self.drafter, MedusaProposer)
@@ -1826,7 +1836,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             indices = torch.tensor(indices, device=self.device)
             hidden_states = sample_hidden_states[indices]
 
-            spec_token_ids = self.drafter.propose(
+            draft_token_ids = self.drafter.propose(
                 target_hidden_states=hidden_states,
                 sampling_metadata=sampling_metadata,
             )
@@ -1897,8 +1907,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 common_attn_metadata=common_attn_metadata,
                 mm_embeds=mm_embeds,
             )
-            spec_token_ids = draft_token_ids.tolist()
-        return spec_token_ids
+        return draft_token_ids
 
     def propose_ngram_draft_token_ids(
         self,
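
With the .tolist() removed here, propose_draft_token_ids now hands the proposer's tensor back untouched; the single tensor-to-list conversion lives in take_draft_token_ids, which also clears the cache so each proposal is consumed exactly once. A toy check of that consume-once contract, using a stub with the same take logic (list case only; hypothetical, not the real class):

    # Toy check of the consume-once contract implemented by the diff above.
    class StubRunner:
        def __init__(self):
            self._draft_token_ids = None

        def take_draft_token_ids(self):
            if self._draft_token_ids is None:
                return None
            out = self._draft_token_ids
            self._draft_token_ids = None  # clear so drafts are taken once
            return out

    r = StubRunner()
    assert r.take_draft_token_ids() is None      # nothing proposed yet
    r._draft_token_ids = [[1, 2], [3]]
    assert r.take_draft_token_ids() == [[1, 2], [3]]
    assert r.take_draft_token_ids() is None      # cache cleared after take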