Files
vllm/vllm/v1/spec_decode/utils.py
Lily Liu 8d6cf89526
Some checks failed
Create Release / Create Release (push) Has been cancelled
[V1] [Spec Decode] Support random sampling for spec decode (#13933)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
2025-03-16 22:00:20 -07:00

23 lines
897 B
Python

# SPDX-License-Identifier: Apache-2.0
from vllm.v1.sample.ops.topk_topp_sampler import random_sample # noqa
from vllm.v1.worker.gpu_input_batch import InputBatch
def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool:
    """Report whether a request is eligible for speculative decoding.

    The request is rejected as soon as its id is found in any of the
    batch's per-feature request collections checked below; the checks run
    in a fixed order and stop at the first hit.

    Args:
        req_id: Id of the request to check.
        input_batch: Batch whose per-feature request collections are
            consulted through membership tests.

    Returns:
        ``False`` if the request uses top-k/top-p sampling, min-p sampling,
        a frequency/presence/repetition penalty, or logprobs; ``True``
        otherwise.
    """
    batch = input_batch
    return (
        # Spec decode doesn't support top_p/top_k sampling.
        req_id not in batch.top_k_reqs
        and req_id not in batch.top_p_reqs
        # Spec decode doesn't support min_p sampling.
        and req_id not in batch.min_p_reqs
        # Spec decode doesn't support penalties.
        and req_id not in batch.frequency_penalties_reqs
        and req_id not in batch.presence_penalties_reqs
        and req_id not in batch.repetition_penalties_reqs
        # Spec decode doesn't support logprobs.
        and req_id not in batch.num_logprobs
    )