[ModelRunner V2] Support spec decode with structured outputs (#33374)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Nick Hill
2026-02-01 16:19:59 -08:00
committed by GitHub
parent e535d90deb
commit cf0a99f84d
3 changed files with 60 additions and 1 deletions

View File

@@ -75,6 +75,9 @@ class InputBatch:
cu_num_logits: torch.Tensor
cu_num_logits_np: np.ndarray
# Whether any requests in batch use structured output.
has_structured_output_reqs: bool
@classmethod
def make_dummy(
cls,
@@ -139,6 +142,7 @@ class InputBatch:
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,
cu_num_logits_np=cu_num_logits_np,
has_structured_output_reqs=False,
)

View File

@@ -20,7 +20,7 @@ from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.gpu.async_utils import AsyncOutput
from vllm.v1.worker.gpu.attn_utils import (
build_attn_metadata,
@@ -59,6 +59,7 @@ from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
from vllm.v1.worker.gpu.sample.sampler import Sampler
from vllm.v1.worker.gpu.spec_decode import init_speculator
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
from vllm.v1.worker.gpu.spec_decode.utils import DraftTokensHandler
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.gpu.structured_outputs import StructuredOutputsWorker
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -167,6 +168,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# LoRA-related workers.
self.lora_state = LoraState(max_num_reqs=self.max_num_reqs)
# Draft tokens propagation - for spec-dec + struct outputs.
self.draft_tokens_handler = DraftTokensHandler(self.device)
# KV Connector if configured.
self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
def update_max_model_len(self, max_model_len: int) -> None:
@@ -638,6 +643,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,
cu_num_logits_np=cu_num_logits_np,
has_structured_output_reqs=scheduler_output.has_structured_output_requests,
)
@torch.inference_mode()
@@ -938,7 +944,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_rejected,
)
self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens)
if self.use_async_scheduling:
return async_output
return async_output.get_output()
def take_draft_token_ids(self) -> DraftTokenIds | None:
return self.draft_tokens_handler.get_draft_tokens()

View File

@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.v1.outputs import DraftTokenIds
from vllm.v1.worker.gpu.async_utils import async_copy_to_np
from vllm.v1.worker.gpu.input_batch import InputBatch
class DraftTokensHandler:
def __init__(self, device: torch.device | None = None):
self.device = device
self.copy_stream = torch.cuda.Stream(device)
self.copy_event = torch.cuda.Event()
self.req_ids: list[str] = []
self.draft_tokens_np: np.ndarray | None = None
def set_draft_tokens(
self, input_batch: InputBatch, draft_tokens: torch.Tensor
) -> None:
if not input_batch.has_structured_output_reqs:
# No draft token validation needs to be performed by
# the scheduler for this batch.
if self.req_ids:
self.req_ids = []
self.draft_tokens_np = None
return
# For spec decoding + structured outputs, we must transfer the
# draft tokens back to the scheduler for grammar validation.
self.req_ids = input_batch.req_ids
current_stream = torch.cuda.current_stream(self.device)
self.copy_stream.wait_stream(current_stream)
with torch.cuda.stream(self.copy_stream):
self.draft_tokens_np = async_copy_to_np(draft_tokens)
self.copy_event.record()
def get_draft_tokens(self) -> DraftTokenIds | None:
if self.draft_tokens_np is None:
return None
self.copy_event.synchronize()
return DraftTokenIds(self.req_ids, self.draft_tokens_np.tolist())