diff --git a/vllm/v1/worker/gpu/input_batch.py b/vllm/v1/worker/gpu/input_batch.py
index b3ab15178..1b18768b3 100644
--- a/vllm/v1/worker/gpu/input_batch.py
+++ b/vllm/v1/worker/gpu/input_batch.py
@@ -75,6 +75,9 @@ class InputBatch:
     cu_num_logits: torch.Tensor
     cu_num_logits_np: np.ndarray
 
+    # Whether any requests in batch use structured output.
+    has_structured_output_reqs: bool
+
     @classmethod
     def make_dummy(
         cls,
@@ -139,6 +142,7 @@ class InputBatch:
             logits_indices=logits_indices,
             cu_num_logits=cu_num_logits,
             cu_num_logits_np=cu_num_logits_np,
+            has_structured_output_reqs=False,
         )
 
 
diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 5889abded..43e26f3c9 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -20,7 +20,7 @@ from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 from vllm.v1.worker.gpu.async_utils import AsyncOutput
 from vllm.v1.worker.gpu.attn_utils import (
     build_attn_metadata,
@@ -59,6 +59,7 @@ from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
 from vllm.v1.worker.gpu.sample.sampler import Sampler
 from vllm.v1.worker.gpu.spec_decode import init_speculator
 from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
+from vllm.v1.worker.gpu.spec_decode.utils import DraftTokensHandler
 from vllm.v1.worker.gpu.states import RequestState
 from vllm.v1.worker.gpu.structured_outputs import StructuredOutputsWorker
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -167,6 +168,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         # LoRA-related workers.
         self.lora_state = LoraState(max_num_reqs=self.max_num_reqs)
 
+        # Draft tokens propagation - for spec-dec + struct outputs.
+        self.draft_tokens_handler = DraftTokensHandler(self.device)
+
+        # KV Connector if configured.
         self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
 
     def update_max_model_len(self, max_model_len: int) -> None:
@@ -638,6 +643,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             logits_indices=logits_indices,
             cu_num_logits=cu_num_logits,
             cu_num_logits_np=cu_num_logits_np,
+            has_structured_output_reqs=scheduler_output.has_structured_output_requests,
         )
 
     @torch.inference_mode()
@@ -938,7 +944,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                 num_rejected,
             )
             self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
+            self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens)
 
         if self.use_async_scheduling:
             return async_output
         return async_output.get_output()
+
+    def take_draft_token_ids(self) -> DraftTokenIds | None:
+        return self.draft_tokens_handler.get_draft_tokens()
diff --git a/vllm/v1/worker/gpu/spec_decode/utils.py b/vllm/v1/worker/gpu/spec_decode/utils.py
new file mode 100644
index 000000000..ddeb99a71
--- /dev/null
+++ b/vllm/v1/worker/gpu/spec_decode/utils.py
@@ -0,0 +1,62 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import numpy as np
+import torch
+
+from vllm.v1.outputs import DraftTokenIds
+from vllm.v1.worker.gpu.async_utils import async_copy_to_np
+from vllm.v1.worker.gpu.input_batch import InputBatch
+
+
+class DraftTokensHandler:
+    """Hands draft tokens back to the scheduler for grammar validation.
+
+    The tokens are only transferred for batches that contain
+    structured-output requests; for any other batch nothing is staged and
+    `get_draft_tokens()` returns None.
+    """
+
+    def __init__(self, device: torch.device | None = None):
+        self.device = device
+        # Stream and event used for the draft token transfer.
+        self.copy_stream = torch.cuda.Stream(device)
+        self.copy_event = torch.cuda.Event()
+
+        # State staged by the most recent `set_draft_tokens()` call.
+        self.req_ids: list[str] = []
+        self.draft_tokens_np: np.ndarray | None = None
+
+    def set_draft_tokens(
+        self, input_batch: InputBatch, draft_tokens: torch.Tensor
+    ) -> None:
+        """Stage the draft tokens of `input_batch` for the scheduler.
+
+        Any previously staged tokens are dropped when the batch has no
+        structured-output requests.
+        """
+        if not input_batch.has_structured_output_reqs:
+            # No draft token validation needs to be performed by
+            # the scheduler for this batch.
+            self.req_ids = []
+            self.draft_tokens_np = None
+            return
+
+        # For spec decoding + structured outputs, we must transfer the
+        # draft tokens back to the scheduler for grammar validation.
+        self.req_ids = input_batch.req_ids
+        current_stream = torch.cuda.current_stream(self.device)
+        self.copy_stream.wait_stream(current_stream)
+        with torch.cuda.stream(self.copy_stream):
+            self.draft_tokens_np = async_copy_to_np(draft_tokens)
+        # Record on the copy stream explicitly so that synchronizing the
+        # event always covers the transfer issued above.
+        self.copy_event.record(self.copy_stream)
+
+    def get_draft_tokens(self) -> DraftTokenIds | None:
+        """Return the staged draft tokens, or None if nothing is staged."""
+        if self.draft_tokens_np is None:
+            return None
+
+        # Block until the transfer started in `set_draft_tokens()` is done.
+        self.copy_event.synchronize()
+        return DraftTokenIds(self.req_ids, self.draft_tokens_np.tolist())