[Core] Factor out common code in SequenceData and Sequence (#8675)
This commit is contained in:
@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
|
||||
from array import array
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property, reduce
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Set, Tuple, Union, cast
|
||||
@@ -169,6 +170,35 @@ class SequenceData(msgspec.Struct,
|
||||
# It is used to compute mrope_position_ids.
|
||||
_mrope_position_delta: Optional[int] = None
|
||||
|
||||
@staticmethod
def from_counts(counts_by_token: Mapping[int, int]) -> "SequenceData":
    """Build a SequenceData whose prompt repeats each token ID ``count`` times.

    Args:
        counts_by_token: Mapping from token ID to the number of consecutive
            copies of that token to place in the prompt. Runs are emitted
            in the mapping's iteration order. A count of zero (or less)
            contributes no tokens.

    Returns:
        A SequenceData constructed from the resulting prompt token array.
        An empty mapping produces an empty prompt array.
    """
    prompt_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE)

    for token_id, count in counts_by_token.items():
        # Grow a single array in place. Concatenating per-token arrays with
        # ``+`` would copy the accumulated prefix once per distinct token.
        prompt_token_ids_arr.extend(
            array(VLLM_TOKEN_ID_ARRAY_TYPE, [token_id]) * count)

    return SequenceData(prompt_token_ids_arr)
|
||||
|
||||
@staticmethod
def from_seqs(
    prompt_token_ids: GenericSequence[int],
    output_token_ids: Optional[GenericSequence[int]] = None,
) -> "SequenceData":
    """Build a SequenceData from plain sequences of token IDs.

    Args:
        prompt_token_ids: Prompt token IDs; copied into a token-ID array.
        output_token_ids: Optional output token IDs. When given, they are
            copied into a token-ID array and passed as
            ``_output_token_ids``; when ``None``, that field is left unset.

    Returns:
        The newly constructed SequenceData.
    """
    prompt_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, prompt_token_ids)

    if output_token_ids is not None:
        output_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE, output_token_ids)
        return SequenceData(prompt_arr, _output_token_ids=output_arr)

    # No outputs supplied: construct from the prompt alone.
    return SequenceData(prompt_arr)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
assert self._prompt_token_ids.typecode == "l"
|
||||
assert self._output_token_ids.typecode == "l"
|
||||
@@ -370,8 +400,6 @@ class Sequence:
|
||||
self.lora_request = lora_request
|
||||
self.prompt_adapter_request = prompt_adapter_request
|
||||
self.from_decoder_prompt = from_decoder_prompt
|
||||
self._prompt: Optional[str] = None
|
||||
self._prompt_token_ids: Optional[List[int]] = None
|
||||
|
||||
# For decoder-only models, a Sequence is constructed
|
||||
# from an LLMInputs instance (the `inputs` arg.)
|
||||
@@ -400,8 +428,7 @@ class Sequence:
|
||||
f"invalid input {inputs}; did you forget the "
|
||||
"encoder input prompt fields?")
|
||||
|
||||
self.data = SequenceData(
|
||||
array(VLLM_TOKEN_ID_ARRAY_TYPE, self.prompt_token_ids))
|
||||
self.data = SequenceData.from_seqs(self.prompt_token_ids)
|
||||
self.output_logprobs: SampleLogprobs = []
|
||||
self.output_text = ""
|
||||
|
||||
@@ -422,37 +449,23 @@ class Sequence:
|
||||
def n_blocks(self) -> int:
    """Return ``get_len()`` divided by ``block_size``, rounded up."""
    num_tokens = self.get_len()
    block_size = self.block_size
    # Adding (block_size - 1) before floor division rounds the quotient up.
    return (num_tokens + block_size - 1) // block_size
|
||||
|
||||
@property
|
||||
@cached_property
|
||||
def prompt(self) -> Optional[str]:
|
||||
if self._prompt is not None:
|
||||
# Reuse precomputed prompt string
|
||||
return self._prompt
|
||||
|
||||
# Select decoder or encoder input prompt str,
|
||||
# as appropriate
|
||||
# Select decoder or encoder input prompt str, as appropriate
|
||||
prompt_key: str = ("prompt"
|
||||
if self.from_decoder_prompt else "encoder_prompt")
|
||||
|
||||
# Cache prompt
|
||||
self._prompt = cast(Optional[str], self.inputs.get(prompt_key))
|
||||
return self._prompt
|
||||
return cast(Optional[str], self.inputs.get(prompt_key))
|
||||
|
||||
@property
|
||||
@cached_property
|
||||
def prompt_token_ids(self) -> List[int]:
|
||||
if self._prompt_token_ids is not None:
|
||||
# Reuse precomputed prompt token ids
|
||||
return self._prompt_token_ids
|
||||
|
||||
# Select decoder or encoder input prompt
|
||||
# token ids, as appropriate
|
||||
# Select decoder or encoder input prompt token ids, as appropriate
|
||||
prompt_token_ids_key: str = ("prompt_token_ids"
|
||||
if self.from_decoder_prompt else
|
||||
"encoder_prompt_token_ids")
|
||||
|
||||
# Cache computed prompt token ids
|
||||
self._prompt_token_ids = cast(List[int],
|
||||
self.inputs.get(prompt_token_ids_key))
|
||||
return self._prompt_token_ids
|
||||
return cast(List[int], self.inputs.get(prompt_token_ids_key))
|
||||
|
||||
@property
|
||||
def multi_modal_data(self) -> "MultiModalDataDict":
|
||||
|
||||
Reference in New Issue
Block a user