[Core] [Bugfix] Add Input Embeddings (#15428)
Signed-off-by: Andrew Sansom <andrew@protopia.ai> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: 临景 <linjing.yx@alibaba-inc.com> Co-authored-by: Bryce1010 <bryceyx@gmail.com> Co-authored-by: Nan2018 <nan@protopia.ai> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -166,6 +166,9 @@ class SequenceData(msgspec.Struct,
|
||||
_output_token_ids: array = msgspec.field(
|
||||
default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))
|
||||
|
||||
_prompt_embeds: Optional[torch.Tensor] = None
|
||||
_output_embeds: Optional[torch.Tensor] = None
|
||||
|
||||
### The below fields should not be passed as an argument ###
|
||||
_cumulative_logprob: float = 0.0
|
||||
_prompt_token_ids_tuple: tuple[int,
|
||||
@@ -176,6 +179,7 @@ class SequenceData(msgspec.Struct,
|
||||
_num_cached_tokens: int = 0
|
||||
_stage: SequenceStage = SequenceStage.PREFILL
|
||||
_cached_all_token_ids: list[int] = msgspec.field(default_factory=list)
|
||||
_cached_all_token_embeds: Optional[torch.Tensor] = None
|
||||
|
||||
# It is used to get delta input. It is reset when `get_delta_and_reset`
|
||||
# is called.
|
||||
@@ -208,6 +212,8 @@ class SequenceData(msgspec.Struct,
|
||||
def from_seqs(
|
||||
prompt_token_ids: GenericSequence[int],
|
||||
output_token_ids: Optional[GenericSequence[int]] = None,
|
||||
*,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
) -> "SequenceData":
|
||||
"""
|
||||
Construct a :class:`SequenceData` instance from prompt and output
|
||||
@@ -217,13 +223,15 @@ class SequenceData(msgspec.Struct,
|
||||
prompt_token_ids)
|
||||
|
||||
if output_token_ids is None:
|
||||
return SequenceData(prompt_token_ids_arr)
|
||||
return SequenceData(prompt_token_ids_arr,
|
||||
_prompt_embeds=prompt_embeds)
|
||||
|
||||
output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
output_token_ids)
|
||||
|
||||
return SequenceData(prompt_token_ids_arr,
|
||||
_output_token_ids=output_token_ids_arr)
|
||||
_output_token_ids=output_token_ids_arr,
|
||||
_prompt_embeds=prompt_embeds)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
assert self._prompt_token_ids.typecode == "l"
|
||||
@@ -231,6 +239,8 @@ class SequenceData(msgspec.Struct,
|
||||
self._prompt_token_ids_tuple: tuple[int, ...] = tuple(
|
||||
self._prompt_token_ids)
|
||||
self._update_cached_all_tokens()
|
||||
if self._prompt_embeds is not None:
|
||||
self._update_cached_all_token_embeds()
|
||||
|
||||
def _update_cached_all_tokens(self):
|
||||
assert isinstance(self._prompt_token_ids, array)
|
||||
@@ -238,6 +248,13 @@ class SequenceData(msgspec.Struct,
|
||||
self._cached_all_token_ids: list[int] = list(self._prompt_token_ids +
|
||||
self._output_token_ids)
|
||||
|
||||
def _update_cached_all_token_embeds(self):
|
||||
assert isinstance(self._prompt_embeds, torch.Tensor)
|
||||
self._cached_all_token_embeds: torch.Tensor = self._prompt_embeds
|
||||
if self._output_embeds is not None:
|
||||
self._cached_all_token_embeds = torch.cat(
|
||||
(self._cached_all_token_embeds, self._output_embeds), dim=0)
|
||||
|
||||
@property
|
||||
def cumulative_logprob(self) -> float:
|
||||
return self._cumulative_logprob
|
||||
@@ -270,6 +287,15 @@ class SequenceData(msgspec.Struct,
|
||||
new_output_token_ids)
|
||||
self._update_cached_all_tokens()
|
||||
|
||||
@property
|
||||
def output_embeds(self) -> Optional[torch.Tensor]:
|
||||
return self._output_embeds
|
||||
|
||||
@output_embeds.setter
|
||||
def output_embeds(self, new_output_token_embeds: torch.Tensor) -> None:
|
||||
self._output_token_embeds = new_output_token_embeds
|
||||
self._update_cached_all_token_embeds()
|
||||
|
||||
@property
|
||||
def output_token_ids_array(self) -> array:
|
||||
"""Return the prompt token ids in array type.
|
||||
@@ -280,6 +306,15 @@ class SequenceData(msgspec.Struct,
|
||||
assert isinstance(self._output_token_ids, array)
|
||||
return self._output_token_ids
|
||||
|
||||
@property
|
||||
def prompt_embeds(self) -> Optional[torch.Tensor]:
|
||||
return self._prompt_embeds
|
||||
|
||||
@prompt_embeds.setter
|
||||
def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None:
|
||||
self._prompt_embeds = prompt_embeds
|
||||
self._update_cached_all_token_embeds()
|
||||
|
||||
@property
|
||||
def mrope_position_delta(self) -> Optional[int]:
|
||||
return self._mrope_position_delta
|
||||
@@ -288,11 +323,28 @@ class SequenceData(msgspec.Struct,
|
||||
def mrope_position_delta(self, new_mrope_position_delta):
|
||||
self._mrope_position_delta = new_mrope_position_delta
|
||||
|
||||
def append_token_id(self, token_id: int, logprob: float) -> None:
|
||||
def append_token_id(self,
|
||||
token_id: int,
|
||||
logprob: float,
|
||||
token_embed: Optional[torch.Tensor] = None) -> None:
|
||||
self._output_token_ids.append(token_id)
|
||||
self._new_appended_tokens.append(token_id)
|
||||
self._cached_all_token_ids.append(token_id)
|
||||
self._cumulative_logprob += logprob
|
||||
if token_embed is not None:
|
||||
# Do not pass in with batch or sequence dimensions
|
||||
assert token_embed.ndim == 1
|
||||
token_embed = token_embed.detach().cpu().unsqueeze(0)
|
||||
if self._output_embeds is None:
|
||||
self._output_embeds = token_embed
|
||||
else:
|
||||
self._output_embeds = torch.cat(
|
||||
(self._output_embeds, token_embed), dim=0)
|
||||
assert self._cached_all_token_embeds is not None
|
||||
self._cached_all_token_embeds = torch.cat(
|
||||
(self._cached_all_token_embeds,
|
||||
token_embed.to(device=self._cached_all_token_embeds.device)),
|
||||
dim=0)
|
||||
|
||||
def get_len(self) -> int:
|
||||
return len(self._output_token_ids) + len(self._prompt_token_ids)
|
||||
@@ -306,6 +358,9 @@ class SequenceData(msgspec.Struct,
|
||||
def get_token_ids(self) -> list[int]:
|
||||
return self._cached_all_token_ids
|
||||
|
||||
def get_token_embeddings(self) -> Optional[torch.Tensor]:
|
||||
return self._cached_all_token_embeds
|
||||
|
||||
def get_prefix_token_ids(
|
||||
self, num_tokens: int
|
||||
) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]:
|
||||
@@ -387,6 +442,8 @@ class SequenceData(msgspec.Struct,
|
||||
def __repr__(self) -> str:
|
||||
return (f"SequenceData("
|
||||
f"prompt_token_ids={self._prompt_token_ids}, "
|
||||
f"prompt_embeds.shape="
|
||||
f"{getattr(self._prompt_embeds, 'shape', None)}, "
|
||||
f"output_token_ids={self.output_token_ids}, "
|
||||
f"cumulative_logprob={self.cumulative_logprob}, "
|
||||
f"get_num_computed_tokens={self.get_num_computed_tokens()})")
|
||||
@@ -425,7 +482,10 @@ class Sequence:
|
||||
self.lora_request = lora_request
|
||||
self.prompt_adapter_request = prompt_adapter_request
|
||||
|
||||
self.data = SequenceData.from_seqs(self.prompt_token_ids)
|
||||
self.data = SequenceData.from_seqs(
|
||||
self.prompt_token_ids,
|
||||
prompt_embeds=self.inputs["prompt_embeds"]
|
||||
if self.inputs["type"] == "embeds" else None)
|
||||
self.output_logprobs: SampleLogprobs = []
|
||||
self.output_text = ""
|
||||
|
||||
@@ -448,14 +508,20 @@ class Sequence:
|
||||
|
||||
@property
|
||||
def prompt(self) -> Optional[str]:
|
||||
if self.inputs["type"] == "embeds":
|
||||
return None
|
||||
return self.inputs.get("prompt")
|
||||
|
||||
@property
|
||||
def prompt_token_ids(self) -> list[int]:
|
||||
if self.inputs["type"] == "embeds":
|
||||
return [0] * len(self.inputs["prompt_embeds"])
|
||||
return self.inputs["prompt_token_ids"]
|
||||
|
||||
@property
|
||||
def token_type_ids(self) -> list[int]:
|
||||
if self.inputs["type"] == "embeds":
|
||||
return []
|
||||
return self.inputs.get("token_type_ids", [])
|
||||
|
||||
@property
|
||||
@@ -554,11 +620,14 @@ class Sequence:
|
||||
"""Reset the sequence states for recomputation."""
|
||||
self.data.reset_state_for_recompute()
|
||||
|
||||
def append_token_id(self, token_id: int, logprobs: dict[int,
|
||||
Logprob]) -> None:
|
||||
def append_token_id(self,
|
||||
token_id: int,
|
||||
logprobs: dict[int, Logprob],
|
||||
token_embed: Optional[torch.Tensor] = None) -> None:
|
||||
assert token_id in logprobs
|
||||
self.output_logprobs.append(logprobs)
|
||||
self.data.append_token_id(token_id, logprobs[token_id].logprob)
|
||||
self.data.append_token_id(token_id, logprobs[token_id].logprob,
|
||||
token_embed)
|
||||
|
||||
def get_len(self) -> int:
|
||||
return self.data.get_len()
|
||||
@@ -889,6 +958,10 @@ class SequenceGroup:
|
||||
f"sampling_params={self.sampling_params}, "
|
||||
f"num_seqs={len(self.seqs)})")
|
||||
|
||||
def uses_prompt_embeds(self) -> bool:
|
||||
"""Returns True if the sequence group uses input embeds."""
|
||||
return any(seq.data.prompt_embeds is not None for seq in self.seqs)
|
||||
|
||||
|
||||
class SequenceGroupMetadataDelta(
|
||||
msgspec.Struct,
|
||||
@@ -1043,10 +1116,14 @@ class SequenceOutput(
|
||||
parent_seq_id: int
|
||||
output_token: int
|
||||
logprobs: dict[int, Logprob]
|
||||
output_embed: Optional[torch.Tensor] = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
output_embed_shape = \
|
||||
self.output_embed.shape if self.output_embed is not None else None
|
||||
return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
|
||||
f"output_token={self.output_token}, "
|
||||
f"output_embed.shape={output_embed_shape}"
|
||||
f"logprobs={self.logprobs})")
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
|
||||
Reference in New Issue
Block a user