[Async][Feat] support apply penalty or bad_words for async + spec (#30495)
Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com> Signed-off-by: izhuhaoran <izhuhaoran@qq.com> Signed-off-by: Nick Hill <nickhill123@gmail.com> Co-authored-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -51,6 +51,14 @@ def test_without_spec_decoding(
|
||||
dict(logprobs=2),
|
||||
dict(logprobs=2, presence_penalty=-1.0),
|
||||
dict(structured_outputs=struct_outputs),
|
||||
dict(
|
||||
structured_outputs=struct_outputs,
|
||||
logprobs=2,
|
||||
),
|
||||
dict(
|
||||
structured_outputs=struct_outputs,
|
||||
presence_penalty=-1.0,
|
||||
),
|
||||
dict(
|
||||
structured_outputs=struct_outputs,
|
||||
logprobs=2,
|
||||
@@ -105,11 +113,15 @@ def test_with_spec_decoding(sample_json_schema, monkeypatch: pytest.MonkeyPatch)
|
||||
|
||||
test_sampling_params = [
|
||||
dict(),
|
||||
dict(presence_penalty=-1.0),
|
||||
dict(bad_words=["the", " the"]),
|
||||
dict(logprobs=2),
|
||||
dict(logprobs=2, presence_penalty=-1.0),
|
||||
dict(structured_outputs=struct_outputs),
|
||||
dict(
|
||||
structured_outputs=struct_outputs,
|
||||
logprobs=2,
|
||||
presence_penalty=-1.0,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@@ -165,23 +165,6 @@ class InputProcessor:
|
||||
"are not yet supported with speculative decoding."
|
||||
)
|
||||
|
||||
# Async scheduling + spec decode currently incompatible with some
|
||||
# sampling parameters.
|
||||
if (
|
||||
self.vllm_config.speculative_config is not None
|
||||
and self.vllm_config.scheduler_config.async_scheduling
|
||||
and (
|
||||
params.frequency_penalty != 0.0
|
||||
or params.presence_penalty != 0.0
|
||||
or params.repetition_penalty != 1.0
|
||||
or params.bad_words_token_ids
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
"async scheduling with spec decoding doesn't yet support "
|
||||
"penalties or bad words in sampling parameters."
|
||||
)
|
||||
|
||||
def _validate_params(
|
||||
self,
|
||||
params: SamplingParams | PoolingParams,
|
||||
|
||||
@@ -965,9 +965,40 @@ class InputBatch:
|
||||
if sampled_token_ids is None:
|
||||
assert self.async_copy_ready_event is not None
|
||||
self.async_copy_ready_event.synchronize()
|
||||
sampled_token_ids = self.sampled_token_ids_cpu.squeeze(-1).tolist()
|
||||
# Replace placeholder token id with actual sampled id.
|
||||
req_output_token_ids[-1] = sampled_token_ids[prev_index]
|
||||
sampled_token_ids = self.sampled_token_ids_cpu.tolist()
|
||||
# Replace placeholder token id(s) with actual sampled id(s).
|
||||
new_ids: list[int] = sampled_token_ids[prev_index]
|
||||
if not new_ids:
|
||||
continue
|
||||
num_sampled_ids = len(new_ids) if new_ids[-1] != -1 else new_ids.index(-1)
|
||||
# Also account for case where there may be a smaller number of
|
||||
# output placeholders (tokens can be discarded after a kv-load failure).
|
||||
first_placeholder = req_output_token_ids.index(-1)
|
||||
num_placeholders = len(req_output_token_ids) - first_placeholder
|
||||
num_to_replace = min(num_sampled_ids, num_placeholders)
|
||||
del new_ids[num_to_replace:]
|
||||
end_index = first_placeholder + num_to_replace
|
||||
req_output_token_ids[first_placeholder:end_index] = new_ids
|
||||
|
||||
def update_async_spec_token_ids(self, draft_token_ids: list[list[int]]) -> None:
|
||||
"""
|
||||
In async scheduling case, update spec_token_ids in sampling metadata with
|
||||
real draft token ids from prior step. This is called right before they are
|
||||
needed by the rejection sampler for penalty/bad_words computation.
|
||||
"""
|
||||
if not draft_token_ids or not self.prev_req_id_to_index:
|
||||
return
|
||||
|
||||
if (spec_token_ids := self.sampling_metadata.spec_token_ids) is not None:
|
||||
for req_id, spec_ids in zip(self.req_ids, spec_token_ids):
|
||||
if spec_ids:
|
||||
prev_index = self.prev_req_id_to_index.get(req_id)
|
||||
if prev_index is not None:
|
||||
draft_ids = draft_token_ids[prev_index]
|
||||
if draft_ids:
|
||||
del draft_ids[len(spec_ids) :]
|
||||
spec_ids.clear()
|
||||
spec_ids.extend(draft_ids)
|
||||
|
||||
@property
|
||||
def num_reqs(self) -> int:
|
||||
|
||||
@@ -2721,15 +2721,21 @@ class GPUModelRunner(
|
||||
) -> SamplerOutput:
|
||||
# Sample the next token and get logprobs if needed.
|
||||
sampling_metadata = self.input_batch.sampling_metadata
|
||||
# Update output token ids with tokens sampled in last step
|
||||
# if async scheduling and required by current sampling params.
|
||||
self.input_batch.update_async_output_token_ids()
|
||||
if spec_decode_metadata is None:
|
||||
# Update output token ids with tokens sampled in last step
|
||||
# if async scheduling and required by current sampling params.
|
||||
self.input_batch.update_async_output_token_ids()
|
||||
return self.sampler(
|
||||
logits=logits,
|
||||
sampling_metadata=sampling_metadata,
|
||||
)
|
||||
|
||||
# Update spec_token_ids with real draft tokens from pre step only when
|
||||
# output_token_ids is needed (penalties or bad_words are in use).
|
||||
if self.use_async_scheduling and self._draft_token_req_ids is not None:
|
||||
draft_token_ids_cpu, _ = self._get_draft_token_ids_cpu()
|
||||
self.input_batch.update_async_spec_token_ids(draft_token_ids_cpu)
|
||||
|
||||
sampler_output = self.rejection_sampler(
|
||||
spec_decode_metadata,
|
||||
None, # draft_probs
|
||||
@@ -3352,8 +3358,6 @@ class GPUModelRunner(
|
||||
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
|
||||
kv_connector_output = self.kv_connector_output
|
||||
self.kv_connector_output = None
|
||||
self._draft_token_ids = None
|
||||
self._draft_token_req_ids = None
|
||||
|
||||
if self.execute_model_state is None:
|
||||
# Nothing to do (PP non-final rank case), output isn't used.
|
||||
@@ -3393,6 +3397,8 @@ class GPUModelRunner(
|
||||
with record_function_or_nullcontext("gpu_model_runner: sample"):
|
||||
sampler_output = self._sample(logits, spec_decode_metadata)
|
||||
|
||||
self._draft_token_ids = None
|
||||
self._draft_token_req_ids = None
|
||||
self.input_batch.prev_sampled_token_ids = None
|
||||
|
||||
def propose_draft_token_ids(sampled_token_ids):
|
||||
@@ -3517,17 +3523,18 @@ class GPUModelRunner(
|
||||
def take_draft_token_ids(self) -> DraftTokenIds | None:
|
||||
if not self.num_spec_tokens or not self._draft_token_req_ids:
|
||||
return None
|
||||
req_ids = self._draft_token_req_ids
|
||||
draft_token_ids = self._get_draft_token_ids_cpu(len(req_ids))
|
||||
draft_token_ids, req_ids = self._get_draft_token_ids_cpu()
|
||||
return DraftTokenIds(req_ids, draft_token_ids)
|
||||
|
||||
def _copy_draft_token_ids_to_cpu(
|
||||
self, scheduler_output: "SchedulerOutput", zeros_only: bool = False
|
||||
) -> None:
|
||||
struct_output = scheduler_output.has_structured_output_requests
|
||||
if self.use_async_scheduling and not struct_output:
|
||||
# Draft tokens don't need to be copied to the CPU if async
|
||||
# scheduling is in use and there are no structured output reqs.
|
||||
# Check if we need to copy draft tokens to CPU. In async scheduling,
|
||||
# we only copy when needed for structured output, penalties or bad_words.
|
||||
if self.use_async_scheduling and not (
|
||||
scheduler_output.has_structured_output_requests
|
||||
or self.input_batch.sampling_metadata.output_token_ids
|
||||
):
|
||||
return
|
||||
# We must also set the corresponding request ids.
|
||||
self._draft_token_req_ids = self.input_batch.req_ids.copy()
|
||||
@@ -3552,13 +3559,16 @@ class GPUModelRunner(
|
||||
self.draft_token_ids_cpu[:num_reqs] = 0
|
||||
self.draft_token_ids_event.record()
|
||||
|
||||
def _get_draft_token_ids_cpu(self, num_reqs: int) -> list[list[int]]:
|
||||
def _get_draft_token_ids_cpu(self) -> tuple[list[list[int]], list[str]]:
|
||||
if isinstance(self._draft_token_ids, list):
|
||||
return self._draft_token_ids
|
||||
return self._draft_token_ids, self.input_batch.req_ids
|
||||
req_ids = self._draft_token_req_ids
|
||||
if req_ids is None:
|
||||
return [], []
|
||||
assert self.draft_token_ids_event is not None
|
||||
assert self.draft_token_ids_cpu is not None
|
||||
self.draft_token_ids_event.synchronize()
|
||||
return self.draft_token_ids_cpu[:num_reqs].tolist()
|
||||
return self.draft_token_ids_cpu[: len(req_ids)].tolist(), req_ids
|
||||
|
||||
def _copy_valid_sampled_token_count(
|
||||
self, next_token_ids: torch.Tensor, valid_sampled_tokens_count: torch.Tensor
|
||||
|
||||
Reference in New Issue
Block a user