[Core] Subclass ModelRunner to support cross-attention & encoder sequences (towards eventual encoder/decoder model support) (#4942)

Co-authored-by: Andrew Feldman <afeld2012@gmail.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Authored by afeldman-nm on 2024-08-06 16:51:47 -04:00; committed by GitHub
parent 660470e5a3
commit fd95e026e0
33 changed files with 3957 additions and 333 deletions


@@ -7,10 +7,11 @@ from array import array
 from collections import defaultdict
 from dataclasses import dataclass, field
 from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
-                    Union)
+                    Union, cast)

 import torch

+from vllm.inputs import is_valid_encoder_decoder_llm_inputs
 from vllm.lora.request import LoRARequest
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
@@ -244,24 +245,38 @@ class SequenceData:
 class Sequence:
     """Stores the data, status, and block information of a sequence.

+    The sequence is constructed from the LLMInputs instance passed
+    in through the `inputs` constructor argument.
+
+    For encoder/decoder models, LLMInputs encapsulates both a
+    decoder and an encoder prompt, creating an ambiguity about which
+    prompt to construct the sequence from. The `from_decoder_prompt`
+    constructor argument signals whether to construct the Sequence
+    from the LLMInputs decoder prompt or the encoder prompt.
+
     Args:
         seq_id: The ID of the sequence.
         inputs: The inputs of the sequence.
         block_size: The block size of the sequence. Should be the same as the
             block size used by the block manager and cache engine.
         eos_token_id: The end-of-sequence (EOS) token id recognized by this LLM.
         lora_request: LoRA request.
         prompt_adapter_request: Prompt Adapter request.
+        from_decoder_prompt: Construct the Sequence from the LLMInputs decoder
+            prompt (True) or the encoder prompt (False). Must be True for
+            decoder-only models.
     """

     def __init__(
-        self,
-        seq_id: int,
-        inputs: "LLMInputs",
-        block_size: int,
-        eos_token_id: Optional[int] = None,
-        lora_request: Optional[LoRARequest] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+        self,
+        seq_id: int,
+        inputs: "LLMInputs",
+        block_size: int,
+        eos_token_id: Optional[int] = None,
+        lora_request: Optional[LoRARequest] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        from_decoder_prompt: bool = True,
     ) -> None:
         self.seq_id = seq_id
         self.inputs = inputs
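
For context, here is a minimal usage sketch of the new constructor argument. Only the `from_decoder_prompt` keyword and the LLMInputs keys (`prompt`, `prompt_token_ids`, `encoder_prompt`, `encoder_prompt_token_ids`) come from this diff; the concrete token ids, prompts, and block size are invented for illustration.

from vllm.sequence import Sequence

# LLMInputs is a mapping carrying both decoder- and encoder-side fields.
# All concrete values here are hypothetical.
inputs = {
    "prompt": "hello",
    "prompt_token_ids": [0],
    "encoder_prompt": "hello",
    "encoder_prompt_token_ids": [101, 102, 103],
}

# Default: build the sequence from the decoder prompt.
decoder_seq = Sequence(seq_id=0, inputs=inputs, block_size=16)

# Same inputs, but build an encoder sequence instead.
encoder_seq = Sequence(seq_id=1, inputs=inputs, block_size=16,
                       from_decoder_prompt=False)

assert decoder_seq.prompt_token_ids == [0]
assert encoder_seq.prompt_token_ids == [101, 102, 103]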
@@ -269,6 +284,36 @@ class Sequence:
         self.eos_token_id = eos_token_id
         self.lora_request = lora_request
         self.prompt_adapter_request = prompt_adapter_request
+        self.from_decoder_prompt = from_decoder_prompt
+        self._prompt: Optional[str] = None
+        self._prompt_token_ids: Optional[List[int]] = None
+
+        # For decoder-only models, a Sequence is constructed
+        # from an LLMInputs instance (the `inputs` arg).
+        #
+        # For encoder/decoder models, the same `inputs`
+        # instance could be utilized to construct either an
+        # encoder sequence or a decoder sequence, because
+        # `LLMInputs` has both decoder- and encoder-oriented
+        # member variables (i.e. it encapsulates both an encoder
+        # and a decoder prompt). The decision of which type of
+        # sequence to construct is determined by the
+        # `from_decoder_prompt` argument.
+        #
+        # When constructing an encoder sequence
+        # (`from_decoder_prompt` is False), the `LLMInputs`
+        # instance stored in `inputs` must be valid in the sense
+        # that its encoder-related member variables are populated;
+        # below, an exception is raised if this is not the case.
+        #
+        # When constructing a decoder sequence (`from_decoder_prompt`
+        # is True), it does not matter whether `inputs` has its
+        # encoder-related member variables populated.
+        if not (from_decoder_prompt
+                or is_valid_encoder_decoder_llm_inputs(inputs)):
+            raise ValueError("Cannot extract encoder input prompt from "
+                             f"invalid input {inputs}; did you forget the "
+                             "encoder input prompt fields?")
+
         self.data = SequenceData(self.prompt_token_ids)
         self.output_logprobs: SampleLogprobs = []
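
The guard above delegates to `is_valid_encoder_decoder_llm_inputs` from `vllm.inputs`, whose implementation is not shown in this diff. The sketch below only captures the property the constructor relies on, namely that the encoder-side fields are populated; the `_sketch` names are illustrative, not vLLM API.

from typing import List, Optional, TypedDict

class LLMInputsSketch(TypedDict, total=False):
    prompt: Optional[str]
    prompt_token_ids: List[int]
    encoder_prompt: Optional[str]
    encoder_prompt_token_ids: List[int]

def is_valid_enc_dec_inputs_sketch(inputs: LLMInputsSketch) -> bool:
    # An encoder sequence can only be built if the encoder-side
    # token ids were actually populated by input preprocessing.
    return inputs.get("encoder_prompt_token_ids") is not None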
@@ -289,11 +334,35 @@ class Sequence:

     @property
     def prompt(self) -> Optional[str]:
-        return self.inputs.get("prompt")
+        if self._prompt is not None:
+            # Reuse precomputed prompt string
+            return self._prompt
+
+        # Select decoder or encoder input prompt str,
+        # as appropriate
+        prompt_key: str = ("prompt"
+                           if self.from_decoder_prompt else "encoder_prompt")
+
+        # Cache prompt
+        self._prompt = cast(Optional[str], self.inputs.get(prompt_key))
+        return self._prompt

     @property
     def prompt_token_ids(self) -> List[int]:
-        return self.inputs["prompt_token_ids"]
+        if self._prompt_token_ids is not None:
+            # Reuse precomputed prompt token ids
+            return self._prompt_token_ids
+
+        # Select decoder or encoder input prompt
+        # token ids, as appropriate
+        prompt_token_ids_key: str = ("prompt_token_ids"
+                                     if self.from_decoder_prompt else
+                                     "encoder_prompt_token_ids")
+
+        # Cache computed prompt token ids
+        self._prompt_token_ids = cast(List[int],
+                                      self.inputs.get(prompt_token_ids_key))
+        return self._prompt_token_ids

     @property
     def multi_modal_data(self) -> "MultiModalDataDict":
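
Both properties follow the same compute-once idiom: select the decoder- or encoder-side key, read it once, and cache the result so later accesses skip the selection. A standalone sketch of the idiom (the class and field names here are illustrative, not vLLM API):

from typing import List, Mapping, Optional

class CachedPromptView:
    """Illustrative stand-in for the caching idiom used above."""

    def __init__(self, inputs: Mapping[str, object],
                 from_decoder_prompt: bool = True) -> None:
        self.inputs = inputs
        self.from_decoder_prompt = from_decoder_prompt
        self._prompt_token_ids: Optional[List[int]] = None

    @property
    def prompt_token_ids(self) -> List[int]:
        # First access selects the key and caches the value; later
        # accesses return the cached list directly.
        if self._prompt_token_ids is None:
            key = ("prompt_token_ids" if self.from_decoder_prompt
                   else "encoder_prompt_token_ids")
            self._prompt_token_ids = list(self.inputs.get(key) or [])
        return self._prompt_token_ids

The caching is only sound if `inputs` is not mutated after construction, which is the usage pattern the Sequence class assumes.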
@@ -472,6 +541,22 @@ class SequenceGroup:
         # We use the prompt of an arbitrary sequence.
         return self.seqs[0].prompt_token_ids

+    @property
+    def encoder_prompt(self) -> Optional[str]:
+        # There are either 0 or 1 encoder sequences.
+        # If one is present, its prompt is distinct
+        # from the decoder's.
+        return (self.encoder_seq.prompt
+                if self.encoder_seq is not None else None)
+
+    @property
+    def encoder_prompt_token_ids(self) -> Optional[List[int]]:
+        # There are either 0 or 1 encoder sequences.
+        # If one is present, its prompt token ids are
+        # distinct from the decoder's.
+        return (self.encoder_seq.prompt_token_ids
+                if self.encoder_seq is not None else None)
+
     @property
     def multi_modal_data(self) -> "MultiModalDataDict":
         # All sequences in the group should have the same multi-modal data.
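
A sketch of how callers can consume the new SequenceGroup properties. The helper function below is hypothetical and `seq_group` construction is elided; per the diff, `encoder_seq` is None for decoder-only models, so both new properties return None in that case.

from typing import Optional
from vllm.sequence import SequenceGroup

def format_request(seq_group: SequenceGroup) -> str:
    # Hypothetical helper: branch on whether an encoder prompt exists.
    enc: Optional[str] = seq_group.encoder_prompt
    if enc is None:
        return f"decoder-only prompt: {seq_group.prompt!r}"
    return (f"encoder prompt: {enc!r}, "
            f"decoder prompt: {seq_group.prompt!r}")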