[V1] [Spec Decode] Support random sampling for spec decode (#13933)
Some checks failed
Create Release / Create Release (push) Has been cancelled

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Lily Liu
2025-03-16 22:00:20 -07:00
committed by GitHub
parent 583a9778e0
commit 8d6cf89526
5 changed files with 568 additions and 194 deletions

View File

@@ -37,6 +37,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.utils import is_spec_decode_supported
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -1020,15 +1021,26 @@ class GPUModelRunner(LoRAModelRunnerMixin):
sampling_metadata=sampling_metadata,
)
else:
target_probs = self.model.sampler.compute_probs(
logits, sampling_metadata)
draft_token_ids = [
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
for req_id in self.input_batch.req_ids
]
sampler_output = self.rejection_sampler(draft_token_ids,
target_probs,
sampling_metadata)
sample_lens = [len(tokens) + 1 for tokens in draft_token_ids]
recover_logits_idx = np.cumsum(sample_lens) - 1
target_probs = self.rejection_sampler.compute_probs(
logits, sampling_metadata, sample_lens)
sampler_output = self.model.sample(
logits=logits[recover_logits_idx, :],
sampling_metadata=sampling_metadata,
)
bonus_token_ids = sampler_output.sampled_token_ids
output_token_ids = self.rejection_sampler(
draft_token_ids,
None, # draft_probs
bonus_token_ids,
target_probs,
sampling_metadata)
sampler_output.sampled_token_ids = output_token_ids
# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
@@ -1075,7 +1087,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_token_ids = None
else:
spec_token_ids = self.generate_draft_token_ids(
valid_sampled_token_ids)
valid_sampled_token_ids, sampling_metadata)
return ModelRunnerOutput(
req_ids=self.input_batch.req_ids,
@@ -1089,6 +1101,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def generate_draft_token_ids(
self,
sampled_token_ids: list[list[int]],
sampling_metadata: SamplingMetadata,
) -> list[list[int]]:
# TODO(woosuk): Optimize.
draft_token_ids: list[list[int]] = []
@@ -1099,6 +1112,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
draft_token_ids.append([])
continue
# Skip requests that require top-p, top-k, etc.
req_id = self.input_batch.req_ids[i]
if not is_spec_decode_supported(req_id, self.input_batch):
draft_token_ids.append([])
continue
# Add sampled_token_ids to token_ids_cpu.
start_idx = self.input_batch.num_tokens_no_spec[i]
end_idx = start_idx + num_sampled_ids