[Async][Feat] support apply penalty or bad_words for async + spec (#30495)

Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Signed-off-by: izhuhaoran <izhuhaoran@qq.com>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
zhrrr
2026-01-09 10:31:50 +08:00
committed by GitHub
parent a4ec0c5595
commit 8ff4a99566
4 changed files with 70 additions and 34 deletions

View File

@@ -51,6 +51,14 @@ def test_without_spec_decoding(
dict(logprobs=2),
dict(logprobs=2, presence_penalty=-1.0),
dict(structured_outputs=struct_outputs),
dict(
structured_outputs=struct_outputs,
logprobs=2,
),
dict(
structured_outputs=struct_outputs,
presence_penalty=-1.0,
),
dict(
structured_outputs=struct_outputs,
logprobs=2,
@@ -105,11 +113,15 @@ def test_with_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch)
test_sampling_params = [
dict(),
dict(presence_penalty=-1.0),
dict(bad_words=["the", " the"]),
dict(logprobs=2),
dict(logprobs=2, presence_penalty=-1.0),
dict(structured_outputs=struct_outputs),
dict(
structured_outputs=struct_outputs,
logprobs=2,
presence_penalty=-1.0,
),
]

View File

@@ -165,23 +165,6 @@ class InputProcessor:
"are not yet supported with speculative decoding."
)
# Async scheduling + spec decode currently incompatible with some
# sampling parameters.
if (
self.vllm_config.speculative_config is not None
and self.vllm_config.scheduler_config.async_scheduling
and (
params.frequency_penalty != 0.0
or params.presence_penalty != 0.0
or params.repetition_penalty != 1.0
or params.bad_words_token_ids
)
):
raise ValueError(
"async scheduling with spec decoding doesn't yet support "
"penalties or bad words in sampling parameters."
)
def _validate_params(
self,
params: SamplingParams | PoolingParams,

View File

@@ -965,9 +965,40 @@ class InputBatch:
if sampled_token_ids is None:
assert self.async_copy_ready_event is not None
self.async_copy_ready_event.synchronize()
sampled_token_ids = self.sampled_token_ids_cpu.squeeze(-1).tolist()
# Replace placeholder token id with actual sampled id.
req_output_token_ids[-1] = sampled_token_ids[prev_index]
sampled_token_ids = self.sampled_token_ids_cpu.tolist()
# Replace placeholder token id(s) with actual sampled id(s).
new_ids: list[int] = sampled_token_ids[prev_index]
if not new_ids:
continue
num_sampled_ids = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1)
# Also account for case where there may be a smaller number of
# output placeholders (tokens can be discarded after a kv-load failure).
first_placeholder = req_output_token_ids.index(-1)
num_placeholders = len(req_output_token_ids) - first_placeholder
num_to_replace = min(num_sampled_ids, num_placeholders)
del new_ids[num_to_replace:]
end_index = first_placeholder + num_to_replace
req_output_token_ids[first_placeholder:end_index] = new_ids
def update_async_spec_token_ids(self, draft_token_ids: list[list[int]]) -> None:
    """
    Patch placeholder spec token ids with the real draft token ids produced
    in the prior step (async scheduling case). Invoked immediately before the
    rejection sampler needs them for penalty/bad_words computation.
    """
    # Nothing to patch without drafts or a prior-step request-index mapping.
    if not draft_token_ids or not self.prev_req_id_to_index:
        return
    spec_token_ids = self.sampling_metadata.spec_token_ids
    if spec_token_ids is None:
        return
    for req_id, spec_ids in zip(self.req_ids, spec_token_ids):
        if not spec_ids:
            continue
        prev_index = self.prev_req_id_to_index.get(req_id)
        if prev_index is None:
            # Request was not present in the previous step's batch.
            continue
        draft_ids = draft_token_ids[prev_index]
        if not draft_ids:
            continue
        # Drop surplus drafts (intentionally mutates the caller's list so the
        # lengths stay in sync), then overwrite the placeholder spec ids in
        # place with the real draft token ids.
        del draft_ids[len(spec_ids):]
        spec_ids[:] = draft_ids
@property
def num_reqs(self) -> int:

View File

@@ -2721,15 +2721,21 @@ class GPUModelRunner(
) -> SamplerOutput:
# Sample the next token and get logprobs if needed.
sampling_metadata = self.input_batch.sampling_metadata
# Update output token ids with tokens sampled in last step
# if async scheduling and required by current sampling params.
self.input_batch.update_async_output_token_ids()
if spec_decode_metadata is None:
# Update output token ids with tokens sampled in last step
# if async scheduling and required by current sampling params.
self.input_batch.update_async_output_token_ids()
return self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
)
# Update spec_token_ids with real draft tokens from the previous step, but
# only when output_token_ids is needed (penalties or bad_words are in use).
if self.use_async_scheduling and self._draft_token_req_ids is not None:
draft_token_ids_cpu, _ = self._get_draft_token_ids_cpu()
self.input_batch.update_async_spec_token_ids(draft_token_ids_cpu)
sampler_output = self.rejection_sampler(
spec_decode_metadata,
None, # draft_probs
@@ -3352,8 +3358,6 @@ class GPUModelRunner(
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
kv_connector_output = self.kv_connector_output
self.kv_connector_output = None
self._draft_token_ids = None
self._draft_token_req_ids = None
if self.execute_model_state is None:
# Nothing to do (PP non-final rank case), output isn't used.
@@ -3393,6 +3397,8 @@ class GPUModelRunner(
with record_function_or_nullcontext("gpu_model_runner: sample"):
sampler_output = self._sample(logits, spec_decode_metadata)
self._draft_token_ids = None
self._draft_token_req_ids = None
self.input_batch.prev_sampled_token_ids = None
def propose_draft_token_ids(sampled_token_ids):
@@ -3517,17 +3523,18 @@ class GPUModelRunner(
def take_draft_token_ids(self) -> DraftTokenIds | None:
if not self.num_spec_tokens or not self._draft_token_req_ids:
return None
req_ids = self._draft_token_req_ids
draft_token_ids = self._get_draft_token_ids_cpu(len(req_ids))
draft_token_ids, req_ids = self._get_draft_token_ids_cpu()
return DraftTokenIds(req_ids, draft_token_ids)
def _copy_draft_token_ids_to_cpu(
self, scheduler_output: "SchedulerOutput", zeros_only: bool = False
) -> None:
struct_output = scheduler_output.has_structured_output_requests
if self.use_async_scheduling and not struct_output:
# Draft tokens don't need to be copied to the CPU if async
# scheduling is in use and there are no structured output reqs.
# Check if we need to copy draft tokens to CPU. In async scheduling,
# we only copy when needed for structured output, penalties or bad_words.
if self.use_async_scheduling and not (
scheduler_output.has_structured_output_requests
or self.input_batch.sampling_metadata.output_token_ids
):
return
# We must also set the corresponding request ids.
self._draft_token_req_ids = self.input_batch.req_ids.copy()
@@ -3552,13 +3559,16 @@ class GPUModelRunner(
self.draft_token_ids_cpu[:num_reqs] = 0
self.draft_token_ids_event.record()
def _get_draft_token_ids_cpu(self, num_reqs: int) -> list[list[int]]:
def _get_draft_token_ids_cpu(self) -> tuple[list[list[int]], list[str]]:
if isinstance(self._draft_token_ids, list):
return self._draft_token_ids
return self._draft_token_ids, self.input_batch.req_ids
req_ids = self._draft_token_req_ids
if req_ids is None:
return [], []
assert self.draft_token_ids_event is not None
assert self.draft_token_ids_cpu is not None
self.draft_token_ids_event.synchronize()
return self.draft_token_ids_cpu[:num_reqs].tolist()
return self.draft_token_ids_cpu[: len(req_ids)].tolist(), req_ids
def _copy_valid_sampled_token_count(
self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor