[V1][Core] Support for Structured Outputs (#12388)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz> Signed-off-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -25,7 +25,8 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
|
||||
from vllm.sampling_params import SamplingType
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||
LayerBlockType, cdiv, is_pin_memory_available)
|
||||
LayerBlockType, LazyLoader, cdiv,
|
||||
is_pin_memory_available)
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
|
||||
from vllm.v1.engine.mm_input_cache import MMInputCacheClient
|
||||
@@ -40,7 +41,11 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import xgrammar as xgr
|
||||
|
||||
from vllm.v1.core.scheduler_output import SchedulerOutput
|
||||
else:
|
||||
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -860,6 +865,53 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model
|
||||
|
||||
def apply_grammar_bitmask(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
logits: torch.Tensor,
|
||||
):
|
||||
# Serialization of np.ndarray is much more efficient than a tensor,
|
||||
# so we receive it in that format.
|
||||
grammar_bitmask = scheduler_output.grammar_bitmask
|
||||
if grammar_bitmask is None:
|
||||
return
|
||||
|
||||
# We receive the structured output bitmask from the scheduler, but the
|
||||
# indices of the requests in the batch may not match the indices of
|
||||
# the bitmask since the scheduler doesn't know how the gpu runner is
|
||||
# ordering the requests in the batch. We need to sort the bitmask to
|
||||
# match the order of the requests used here.
|
||||
struct_out_req_batch_indices: dict[str, int] = {}
|
||||
indices_match = True
|
||||
for req_id in self.input_batch.req_ids:
|
||||
mask_index = scheduler_output.structured_output_request_ids.get(
|
||||
req_id)
|
||||
if mask_index is None:
|
||||
# not a structured output request
|
||||
continue
|
||||
batch_index = self.input_batch.req_id_to_index[req_id]
|
||||
if batch_index != mask_index:
|
||||
indices_match = False
|
||||
struct_out_req_batch_indices[req_id] = batch_index
|
||||
|
||||
if not indices_match:
|
||||
# Sort the bitmask to match the order of the requests
|
||||
sorted_bitmask = np.zeros_like(grammar_bitmask)
|
||||
for req_id, batch_index in struct_out_req_batch_indices.items():
|
||||
orig_index = scheduler_output.structured_output_request_ids[
|
||||
req_id]
|
||||
sorted_bitmask[batch_index] = grammar_bitmask[orig_index]
|
||||
grammar_bitmask = sorted_bitmask
|
||||
|
||||
grammar_bitmask = torch.from_numpy(grammar_bitmask)
|
||||
|
||||
# TODO: compatibility with spec decode
|
||||
xgr.apply_token_bitmask_inplace(
|
||||
logits,
|
||||
grammar_bitmask.to(self.device, non_blocking=True),
|
||||
indices=list(struct_out_req_batch_indices.values()),
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
@@ -945,6 +997,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
sample_hidden_states = hidden_states[logits_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||
|
||||
# Apply structured output bitmasks if present
|
||||
if scheduler_output.grammar_bitmask is not None:
|
||||
self.apply_grammar_bitmask(scheduler_output, logits)
|
||||
|
||||
# Sample the next token and get logprobs if needed.
|
||||
sampling_metadata = self.input_batch.sampling_metadata
|
||||
if not self.use_spec_decode:
|
||||
|
||||
Reference in New Issue
Block a user