[V1] [Spec Decode] Support random sampling for spec decode (#13933)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@@ -37,6 +37,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
+from vllm.v1.spec_decode.utils import is_spec_decode_supported
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -1020,15 +1021,26 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 sampling_metadata=sampling_metadata,
             )
         else:
-            target_probs = self.model.sampler.compute_probs(
-                logits, sampling_metadata)
             draft_token_ids = [
                 scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
                 for req_id in self.input_batch.req_ids
             ]
-            sampler_output = self.rejection_sampler(draft_token_ids,
-                                                    target_probs,
-                                                    sampling_metadata)
+            sample_lens = [len(tokens) + 1 for tokens in draft_token_ids]
+            recover_logits_idx = np.cumsum(sample_lens) - 1
+            target_probs = self.rejection_sampler.compute_probs(
+                logits, sampling_metadata, sample_lens)
+            sampler_output = self.model.sample(
+                logits=logits[recover_logits_idx, :],
+                sampling_metadata=sampling_metadata,
+            )
+            bonus_token_ids = sampler_output.sampled_token_ids
+            output_token_ids = self.rejection_sampler(
+                draft_token_ids,
+                None,  # draft_probs
+                bonus_token_ids,
+                target_probs,
+                sampling_metadata)
+            sampler_output.sampled_token_ids = output_token_ids
 
         # TODO(woosuk): The following loop can be slow since it iterates over
         # the requests one by one. Optimize.
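For context on the new verification path (an explanatory sketch, not part of the diff): each request contributes len(draft_tokens) + 1 rows to the flattened logits, i.e. one row per draft token plus one bonus position, so np.cumsum(sample_lens) - 1 lands on the last (bonus) row of every request. The bonus token is drawn by the regular sampler from those rows, while the rejection sampler verifies the draft tokens against target_probs (draft_probs is None here, since the ngram proposer supplies no draft distribution). A minimal numeric illustration of the indexing, with made-up draft lengths:

import numpy as np

# Hypothetical draft proposals for three requests.
draft_token_ids = [[7, 7], [3], []]

# One row per draft token plus one bonus position per request.
sample_lens = [len(tokens) + 1 for tokens in draft_token_ids]  # [3, 2, 1]

# Index of the last (bonus) row of each request in the flattened logits.
recover_logits_idx = np.cumsum(sample_lens) - 1  # [2, 4, 5]

print(sample_lens, recover_logits_idx.tolist())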
@@ -1075,7 +1087,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             spec_token_ids = None
         else:
             spec_token_ids = self.generate_draft_token_ids(
-                valid_sampled_token_ids)
+                valid_sampled_token_ids, sampling_metadata)
 
         return ModelRunnerOutput(
             req_ids=self.input_batch.req_ids,
@@ -1089,6 +1101,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     def generate_draft_token_ids(
         self,
         sampled_token_ids: list[list[int]],
+        sampling_metadata: SamplingMetadata,
     ) -> list[list[int]]:
         # TODO(woosuk): Optimize.
         draft_token_ids: list[list[int]] = []
@@ -1099,6 +1112,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 draft_token_ids.append([])
                 continue
 
+            # Skip requests that require top-p, top-k, etc.
+            req_id = self.input_batch.req_ids[i]
+            if not is_spec_decode_supported(req_id, self.input_batch):
+                draft_token_ids.append([])
+                continue
+
             # Add sampled_token_ids to token_ids_cpu.
             start_idx = self.input_batch.num_tokens_no_spec[i]
             end_idx = start_idx + num_sampled_ids
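For illustration, a self-contained sketch of the per-request gate added above; the names below are hypothetical stand-ins, not vLLM's actual InputBatch API, and the real check in vllm/v1/spec_decode/utils.py may reject further parameter combinations:

from dataclasses import dataclass, field

@dataclass
class BatchSamplingInfo:
    # Hypothetical stand-in for the batch's per-request sampling-param sets.
    top_p_reqs: set[str] = field(default_factory=set)
    top_k_reqs: set[str] = field(default_factory=set)

def spec_decode_supported(req_id: str, batch: BatchSamplingInfo) -> bool:
    # Mirrors the diff's comment: skip requests that require top-p, top-k,
    # etc., so drafting only runs where rejection sampling stays exact.
    return req_id not in batch.top_p_reqs and req_id not in batch.top_k_reqs

# Usage: requests with plain greedy/random sampling pass the gate.
batch = BatchSamplingInfo(top_p_reqs={"req-1"})
assert not spec_decode_supported("req-1", batch)
assert spec_decode_supported("req-2", batch)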