[Spec Decode] Make propose_draft_token_ids non-blocking for lower TTFT (#23041)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -65,8 +65,8 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
|
||||
FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheSpec, MambaSpec,
|
||||
SlidingWindowSpec)
|
||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
|
||||
ModelRunnerOutput)
|
||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
|
||||
LogprobsTensors, ModelRunnerOutput)
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
@@ -348,6 +348,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
|
||||
self.reorder_batch_threshold: Optional[int] = None
|
||||
|
||||
# Cached outputs.
|
||||
self._draft_token_ids: Optional[Union[list[list[int]],
|
||||
torch.Tensor]] = None
|
||||
|
||||
def _init_model_kwargs(self, num_tokens: int):
|
||||
model_kwargs = dict[str, Any]()
|
||||
num_reqs = self.input_batch.num_reqs
|
||||
@@ -1493,7 +1497,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
req_ids=self.input_batch.req_ids,
|
||||
req_id_to_index=self.input_batch.req_id_to_index,
|
||||
sampled_token_ids=[],
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=pooler_output,
|
||||
@@ -1764,12 +1767,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
req_state = self.requests[req_id]
|
||||
req_state.output_token_ids.extend(sampled_ids)
|
||||
|
||||
if not self.speculative_config:
|
||||
# Speculative decoding is not enabled.
|
||||
spec_token_ids = None
|
||||
else:
|
||||
if self.speculative_config:
|
||||
assert spec_decode_common_attn_metadata is not None
|
||||
spec_token_ids = self.propose_draft_token_ids(
|
||||
self._draft_token_ids = self.propose_draft_token_ids(
|
||||
scheduler_output,
|
||||
valid_sampled_token_ids,
|
||||
sampling_metadata,
|
||||
@@ -1786,7 +1786,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
req_ids=self.input_batch.req_ids,
|
||||
req_id_to_index=self.input_batch.req_id_to_index,
|
||||
sampled_token_ids=valid_sampled_token_ids,
|
||||
spec_token_ids=spec_token_ids,
|
||||
logprobs=logprobs_lists,
|
||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||
pooler_output=[],
|
||||
@@ -1794,6 +1793,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
num_nans_in_logits=num_nans_in_logits,
|
||||
)
|
||||
|
||||
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
|
||||
if self._draft_token_ids is None:
|
||||
return None
|
||||
req_ids = self.input_batch.req_ids
|
||||
if isinstance(self._draft_token_ids, torch.Tensor):
|
||||
draft_token_ids = self._draft_token_ids.tolist()
|
||||
else:
|
||||
draft_token_ids = self._draft_token_ids
|
||||
self._draft_token_ids = None
|
||||
return DraftTokenIds(req_ids, draft_token_ids)
|
||||
|
||||
def propose_draft_token_ids(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
@@ -1804,11 +1814,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
aux_hidden_states: Optional[torch.Tensor],
|
||||
spec_decode_metadata: Optional[SpecDecodeMetadata],
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
) -> list[list[int]]:
|
||||
) -> Union[list[list[int]], torch.Tensor]:
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
if self.speculative_config.method == "ngram":
|
||||
assert isinstance(self.drafter, NgramProposer)
|
||||
spec_token_ids = self.propose_ngram_draft_token_ids(
|
||||
draft_token_ids = self.propose_ngram_draft_token_ids(
|
||||
sampled_token_ids)
|
||||
elif self.speculative_config.method == "medusa":
|
||||
assert isinstance(self.drafter, MedusaProposer)
|
||||
@@ -1826,7 +1836,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
indices = torch.tensor(indices, device=self.device)
|
||||
hidden_states = sample_hidden_states[indices]
|
||||
|
||||
spec_token_ids = self.drafter.propose(
|
||||
draft_token_ids = self.drafter.propose(
|
||||
target_hidden_states=hidden_states,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
@@ -1897,8 +1907,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
mm_embeds=mm_embeds,
|
||||
)
|
||||
spec_token_ids = draft_token_ids.tolist()
|
||||
return spec_token_ids
|
||||
return draft_token_ids
|
||||
|
||||
def propose_ngram_draft_token_ids(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user