[Core] Subclass ModelRunner to support cross-attention & encoder sequences (towards eventual encoder/decoder model support) (#4942)
Co-authored-by: Andrew Feldman <afeld2012@gmail.com> Co-authored-by: Nick Hill <nickhill@us.ibm.com>
This commit is contained in:
105
vllm/sequence.py
105
vllm/sequence.py
@@ -7,10 +7,11 @@ from array import array
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
|
||||
Union)
|
||||
Union, cast)
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.inputs import is_valid_encoder_decoder_llm_inputs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
@@ -244,24 +245,38 @@ class SequenceData:
|
||||
class Sequence:
|
||||
"""Stores the data, status, and block information of a sequence.
|
||||
|
||||
The sequence is constructed from the LLMInputs instance passed
|
||||
in through the `inputs` constructor argument.
|
||||
|
||||
For encoder/decoder models, LLMInputs encapsulates both a
|
||||
decoder and encoder prompt, creating an ambiguity about which
|
||||
prompt to construct the sequence from. The `from_decoder_prompt`
|
||||
constructor argument signals whether to construct the Sequence
|
||||
from the LLMInputs decoder prompt, or encoder prompt.
|
||||
|
||||
Args:
|
||||
seq_id: The ID of the sequence.
|
||||
inputs: The inputs of the sequence.
|
||||
block_size: The block size of the sequence. Should be the same as the
|
||||
block size used by the block manager and cache engine.
|
||||
eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
|
||||
lora_request: LoRA request.
|
||||
prompt_adapter_request: Prompt Adapter request.
|
||||
from_decoder_prompt: Construct Sequence from LLMInputs decoder prompt
|
||||
(True) or encoder prompt (False.) Must be True
|
||||
for decoder-only model.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
seq_id: int,
|
||||
inputs: "LLMInputs",
|
||||
block_size: int,
|
||||
eos_token_id: Optional[int] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
||||
self,
|
||||
seq_id: int,
|
||||
inputs: "LLMInputs",
|
||||
block_size: int,
|
||||
eos_token_id: Optional[int] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
from_decoder_prompt: bool = True,
|
||||
) -> None:
|
||||
self.seq_id = seq_id
|
||||
self.inputs = inputs
|
||||
@@ -269,6 +284,36 @@ class Sequence:
|
||||
self.eos_token_id = eos_token_id
|
||||
self.lora_request = lora_request
|
||||
self.prompt_adapter_request = prompt_adapter_request
|
||||
self.from_decoder_prompt = from_decoder_prompt
|
||||
self._prompt: Optional[str] = None
|
||||
self._prompt_token_ids: Optional[List[int]] = None
|
||||
|
||||
# For decoder-only models, a Sequence is constructed
|
||||
# from an LLMInputs instance (the `inputs` arg.)
|
||||
#
|
||||
# For encoder/decoder models the same `inputs`
|
||||
# instance could be utilized to construct either an
|
||||
# encoder sequence or a decoder sequence, because
|
||||
# `LLMInputs` has both decoder- and encoder-oriented
|
||||
# member variables (i.e. it encapsulates both an encoder
|
||||
# and a decoder prompt.) The decision of which type of sequence
|
||||
# to generate is determined by the `from_decoder_prompt` argument.
|
||||
#
|
||||
# When constructing a encoder sequence
|
||||
# (`from_decoder_prompt` False) it matters that
|
||||
# the `LLMInputs` instance stored in `inputs` is valid
|
||||
# in the sense that its encoder-related member variables are
|
||||
# populated; below, an exception is raised if this is
|
||||
# not the case.
|
||||
#
|
||||
# When constructing a decoder sequence (`from_decoder_prompt` True)
|
||||
# it does not matter whether `inputs` has its encoder-related
|
||||
# member variables populated.
|
||||
if not (from_decoder_prompt
|
||||
or is_valid_encoder_decoder_llm_inputs(inputs)):
|
||||
raise ValueError("Cannot extract encoder input prompt from "
|
||||
f"invalid input {inputs}; did you forget the "
|
||||
"encoder input prompt fields?")
|
||||
|
||||
self.data = SequenceData(self.prompt_token_ids)
|
||||
self.output_logprobs: SampleLogprobs = []
|
||||
@@ -289,11 +334,35 @@ class Sequence:
|
||||
|
||||
@property
|
||||
def prompt(self) -> Optional[str]:
|
||||
return self.inputs.get("prompt")
|
||||
if self._prompt is not None:
|
||||
# Reuse precomputed prompt string
|
||||
return self._prompt
|
||||
|
||||
# Select decoder or encoder input prompt str,
|
||||
# as appropriate
|
||||
prompt_key: str = ("prompt"
|
||||
if self.from_decoder_prompt else "encoder_prompt")
|
||||
|
||||
# Cache prompt
|
||||
self._prompt = cast(Optional[str], self.inputs.get(prompt_key))
|
||||
return self._prompt
|
||||
|
||||
@property
|
||||
def prompt_token_ids(self) -> List[int]:
|
||||
return self.inputs["prompt_token_ids"]
|
||||
if self._prompt_token_ids is not None:
|
||||
# Reuse precomputed prompt token ids
|
||||
return self._prompt_token_ids
|
||||
|
||||
# Select decoder or encoder input prompt
|
||||
# token ids, as appropriate
|
||||
prompt_token_ids_key: str = ("prompt_token_ids"
|
||||
if self.from_decoder_prompt else
|
||||
"encoder_prompt_token_ids")
|
||||
|
||||
# Cache computed prompt token ids
|
||||
self._prompt_token_ids = cast(List[int],
|
||||
self.inputs.get(prompt_token_ids_key))
|
||||
return self._prompt_token_ids
|
||||
|
||||
@property
|
||||
def multi_modal_data(self) -> "MultiModalDataDict":
|
||||
@@ -472,6 +541,22 @@ class SequenceGroup:
|
||||
# We use the prompt of an arbitrary sequence.
|
||||
return self.seqs[0].prompt_token_ids
|
||||
|
||||
@property
|
||||
def encoder_prompt(self) -> Optional[str]:
|
||||
# There are either 0 or 1 encoder sequences
|
||||
# If one is present, its prompt is distinct
|
||||
# from the decoder's.
|
||||
return (self.encoder_seq.prompt
|
||||
if self.encoder_seq is not None else None)
|
||||
|
||||
@property
|
||||
def encoder_prompt_token_ids(self) -> Optional[List[int]]:
|
||||
# There are either 0 or 1 encoder sequences
|
||||
# If one is present, its prompt token ids are
|
||||
# distinct from the decoder's.
|
||||
return (self.encoder_seq.prompt_token_ids
|
||||
if self.encoder_seq is not None else None)
|
||||
|
||||
@property
|
||||
def multi_modal_data(self) -> "MultiModalDataDict":
|
||||
# All sequences in the group should have the same multi-modal data.
|
||||
|
||||
Reference in New Issue
Block a user