[ModelRunner V2] Support spec decode with structured outputs (#33374)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -75,6 +75,9 @@ class InputBatch:
|
||||
cu_num_logits: torch.Tensor
|
||||
cu_num_logits_np: np.ndarray
|
||||
|
||||
# Whether any requests in batch use structured output.
|
||||
has_structured_output_reqs: bool
|
||||
|
||||
@classmethod
|
||||
def make_dummy(
|
||||
cls,
|
||||
@@ -139,6 +142,7 @@ class InputBatch:
|
||||
logits_indices=logits_indices,
|
||||
cu_num_logits=cu_num_logits,
|
||||
cu_num_logits_np=cu_num_logits_np,
|
||||
has_structured_output_reqs=False,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
||||
from vllm.v1.worker.gpu.async_utils import AsyncOutput
|
||||
from vllm.v1.worker.gpu.attn_utils import (
|
||||
build_attn_metadata,
|
||||
@@ -59,6 +59,7 @@ from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
|
||||
from vllm.v1.worker.gpu.sample.sampler import Sampler
|
||||
from vllm.v1.worker.gpu.spec_decode import init_speculator
|
||||
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
|
||||
from vllm.v1.worker.gpu.spec_decode.utils import DraftTokensHandler
|
||||
from vllm.v1.worker.gpu.states import RequestState
|
||||
from vllm.v1.worker.gpu.structured_outputs import StructuredOutputsWorker
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
@@ -167,6 +168,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
# LoRA-related workers.
|
||||
self.lora_state = LoraState(max_num_reqs=self.max_num_reqs)
|
||||
|
||||
# Draft tokens propagation - for spec-dec + struct outputs.
|
||||
self.draft_tokens_handler = DraftTokensHandler(self.device)
|
||||
|
||||
# KV Connector if configured.
|
||||
self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
|
||||
|
||||
def update_max_model_len(self, max_model_len: int) -> None:
|
||||
@@ -638,6 +643,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
logits_indices=logits_indices,
|
||||
cu_num_logits=cu_num_logits,
|
||||
cu_num_logits_np=cu_num_logits_np,
|
||||
has_structured_output_reqs=scheduler_output.has_structured_output_requests,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
@@ -938,7 +944,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_rejected,
|
||||
)
|
||||
self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
|
||||
self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens)
|
||||
|
||||
if self.use_async_scheduling:
|
||||
return async_output
|
||||
return async_output.get_output()
|
||||
|
||||
def take_draft_token_ids(self) -> DraftTokenIds | None:
    """Return the draft tokens currently held by the draft-token handler.

    This is a thin delegate: the result of
    ``self.draft_tokens_handler.get_draft_tokens()`` is passed through
    unchanged, including ``None`` when the handler has nothing to report.
    """
    handler = self.draft_tokens_handler
    return handler.get_draft_tokens()
|
||||
|
||||
45
vllm/v1/worker/gpu/spec_decode/utils.py
Normal file
45
vllm/v1/worker/gpu/spec_decode/utils.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.v1.outputs import DraftTokenIds
|
||||
from vllm.v1.worker.gpu.async_utils import async_copy_to_np
|
||||
from vllm.v1.worker.gpu.input_batch import InputBatch
|
||||
|
||||
|
||||
class DraftTokensHandler:
    """Stage draft tokens on the host for spec decoding + structured outputs.

    When a batch contains structured-output requests, the draft tokens are
    copied to a NumPy array on a dedicated CUDA stream so they can be handed
    back to the scheduler for grammar validation. For batches without such
    requests nothing is staged and ``get_draft_tokens`` returns ``None``.
    """

    def __init__(self, device: torch.device | None = None):
        """Create the copy stream/event and start with nothing staged.

        Args:
            device: CUDA device used for the copy stream and for looking up
                the current stream; ``None`` is forwarded to torch as-is.
        """
        self.device = device
        # Dedicated stream + event so the copy can be waited on separately
        # from the stream that produced the draft tokens.
        self.copy_stream = torch.cuda.Stream(device)
        self.copy_event = torch.cuda.Event()

        # Request ids and host-side tokens of the most recent staged batch.
        self.req_ids: list[str] = []
        self.draft_tokens_np: np.ndarray | None = None

    def set_draft_tokens(
        self, input_batch: InputBatch, draft_tokens: torch.Tensor
    ) -> None:
        """Record the draft tokens produced for ``input_batch``.

        Args:
            input_batch: Batch the draft tokens belong to. Only
                ``has_structured_output_reqs`` and ``req_ids`` are read.
            draft_tokens: Draft tokens for the batch; only touched when the
                batch has structured-output requests.
        """
        if not input_batch.has_structured_output_reqs:
            # No draft token validation needs to be performed by
            # the scheduler for this batch.
            # Always drop previously staged state so a later
            # get_draft_tokens() can never return tokens from an older
            # batch. Rebind instead of .clear(): the old list object was
            # taken by reference from a previous InputBatch.
            self.req_ids = []
            self.draft_tokens_np = None
            return

        # For spec decoding + structured outputs, we must transfer the
        # draft tokens back to the scheduler for grammar validation.
        self.req_ids = input_batch.req_ids
        current_stream = torch.cuda.current_stream(self.device)
        # Work queued on the copy stream must not start before the work
        # already queued on the current stream (which produces the tokens).
        self.copy_stream.wait_stream(current_stream)
        with torch.cuda.stream(self.copy_stream):
            # NOTE(review): async_copy_to_np is defined elsewhere; presumably
            # it enqueues a non-blocking device-to-host copy on the current
            # stream and returns the destination array - confirm.
            self.draft_tokens_np = async_copy_to_np(draft_tokens)
            # Record explicitly on the copy stream so that synchronize() in
            # get_draft_tokens() waits for the copy enqueued above.
            self.copy_event.record(self.copy_stream)

    def get_draft_tokens(self) -> DraftTokenIds | None:
        """Return the staged draft tokens, or ``None`` if nothing is staged.

        Blocks on the copy event before reading the host array, then returns
        a ``DraftTokenIds`` built from the stored request ids and the array
        converted with ``tolist()``.
        """
        if self.draft_tokens_np is None:
            return None

        self.copy_event.synchronize()
        return DraftTokenIds(self.req_ids, self.draft_tokens_np.tolist())
|
||||
Reference in New Issue
Block a user