[Core] Factor out common code in SequenceData and Sequence (#8675)

This commit is contained in:
Cyrus Leung
2024-09-21 10:30:39 +08:00
committed by GitHub
parent d4bf085ad0
commit 0455c46ed4
8 changed files with 64 additions and 97 deletions

View File

@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
from array import array
from collections import defaultdict
from dataclasses import dataclass
from functools import cached_property, reduce
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Union, cast
@@ -169,6 +170,35 @@ class SequenceData(msgspec.Struct,
# It is used to compute mrope_position_ids.
_mrope_position_delta: Optional[int] = None
@staticmethod
def from_counts(counts_by_token: Mapping[int, int]) -> "SequenceData":
    """Build a ``SequenceData`` whose prompt repeats each token ``count`` times.

    Tokens appear in the mapping's iteration order, each repeated ``count``
    times; a non-positive count contributes no tokens. An empty mapping
    yields a ``SequenceData`` with an empty prompt token array.

    Args:
        counts_by_token: Mapping from token ID to its repetition count.

    Returns:
        A ``SequenceData`` constructed from the concatenated token array.
    """
    # Grow a single array in place instead of folding with
    # ``array.__add__``, which allocates a new array and re-copies the
    # accumulated prefix for every distinct token.
    prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE)
    for token_id, count in counts_by_token.items():
        prompt_token_ids_arr.extend(
            array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count)
    return SequenceData(prompt_token_ids_arr)
@staticmethod
def from_seqs(
    prompt_token_ids: GenericSequence[int],
    output_token_ids: Optional[GenericSequence[int]] = None,
) -> "SequenceData":
    """Create a ``SequenceData`` from plain sequences of token IDs.

    Both inputs are copied into ``array`` objects of type
    ``VLLM_TOKEN_ID_ARRAY_TYPE`` before being handed to the constructor.

    Args:
        prompt_token_ids: Token IDs of the prompt.
        output_token_ids: Token IDs of the output, if any. When ``None``,
            the ``_output_token_ids`` argument is not passed at all.

    Returns:
        The newly constructed ``SequenceData``.
    """
    prompt_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids)

    if output_token_ids is not None:
        output_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, output_token_ids)
        return SequenceData(prompt_arr, _output_token_ids=output_arr)

    return SequenceData(prompt_arr)
def __post_init__(self) -> None:
assert self._prompt_token_ids.typecode == "l"
assert self._output_token_ids.typecode == "l"
@@ -370,8 +400,6 @@ class Sequence:
self.lora_request = lora_request
self.prompt_adapter_request = prompt_adapter_request
self.from_decoder_prompt = from_decoder_prompt
self._prompt: Optional[str] = None
self._prompt_token_ids: Optional[List[int]] = None
# For decoder-only models, a Sequence is constructed
# from an LLMInputs instance (the `inputs` arg.)
@@ -400,8 +428,7 @@ class Sequence:
f"invalid input {inputs}; did you forget the "
"encoder input prompt fields?")
self.data = SequenceData(
array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids))
self.data = SequenceData.from_seqs(self.prompt_token_ids)
self.output_logprobs: SampleLogprobs = []
self.output_text = ""
@@ -422,37 +449,23 @@ class Sequence:
def n_blocks(self) -> int:
    """Return ``get_len()`` divided by ``block_size``, rounded up.

    Uses integer arithmetic only; the rounding-up reading holds for a
    positive ``block_size``.
    """
    seq_len = self.get_len()
    block_size = self.block_size
    # Adding (block_size - 1) before floor division rounds up.
    return (seq_len + block_size - 1) // block_size
@cached_property
def prompt(self) -> Optional[str]:
    """Prompt string looked up from ``self.inputs``.

    The key is ``"prompt"`` when ``self.from_decoder_prompt`` is truthy and
    ``"encoder_prompt"`` otherwise. The value is computed on first access
    and cached on the instance by ``cached_property``, so no hand-written
    ``_prompt`` cache is needed.

    Returns:
        The value stored under the selected key, or ``None`` if the key is
        absent (``cast`` only informs the type checker).
    """
    # Select decoder or encoder input prompt str, as appropriate
    prompt_key: str = ("prompt"
                       if self.from_decoder_prompt else "encoder_prompt")
    return cast(Optional[str], self.inputs.get(prompt_key))
@cached_property
def prompt_token_ids(self) -> List[int]:
    """Prompt token IDs looked up from ``self.inputs``.

    The key is ``"prompt_token_ids"`` when ``self.from_decoder_prompt`` is
    truthy and ``"encoder_prompt_token_ids"`` otherwise. The value is
    computed on first access and cached on the instance by
    ``cached_property``, so no hand-written ``_prompt_token_ids`` cache is
    needed.

    Returns:
        The value stored under the selected key. ``cast`` only informs the
        type checker; a missing key yields ``None`` at runtime.
    """
    # Select decoder or encoder input prompt token ids, as appropriate
    prompt_token_ids_key: str = ("prompt_token_ids"
                                 if self.from_decoder_prompt else
                                 "encoder_prompt_token_ids")
    return cast(List[int], self.inputs.get(prompt_token_ids_key))
@property
def multi_modal_data(self) -> "MultiModalDataDict":