[Core] [Bugfix] Add Input Embeddings (#15428)

Signed-off-by: Andrew Sansom <andrew@protopia.ai>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: 临景 <linjing.yx@alibaba-inc.com>
Co-authored-by: Bryce1010 <bryceyx@gmail.com>
Co-authored-by: Nan2018 <nan@protopia.ai>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Andrew Sansom
2025-05-02 03:06:39 -05:00
committed by GitHub
parent 9e2de9b9e9
commit cc2a77d7f1
22 changed files with 691 additions and 113 deletions

@@ -166,6 +166,9 @@ class SequenceData(msgspec.Struct,
_output_token_ids: array = msgspec.field(
default_factory=lambda: array(VLLM_TOKEN_ID_ARRAY_TYPE, []))
_prompt_embeds: Optional[torch.Tensor] = None
_output_embeds: Optional[torch.Tensor] = None
### The below fields should not be passed as an argument ###
_cumulative_logprob: float = 0.0
_prompt_token_ids_tuple: tuple[int,
@@ -176,6 +179,7 @@ class SequenceData(msgspec.Struct,
_num_cached_tokens: int = 0
_stage: SequenceStage = SequenceStage.PREFILL
_cached_all_token_ids: list[int] = msgspec.field(default_factory=list)
_cached_all_token_embeds: Optional[torch.Tensor] = None
# It is used to get delta input. It is reset when `get_delta_and_reset`
# is called.
@@ -208,6 +212,8 @@ class SequenceData(msgspec.Struct,
def from_seqs(
prompt_token_ids: GenericSequence[int],
output_token_ids: Optional[GenericSequence[int]] = None,
*,
prompt_embeds: Optional[torch.Tensor] = None,
) -> "SequenceData":
"""
Construct a :class:`SequenceData` instance from prompt and output
@@ -217,13 +223,15 @@ class SequenceData(msgspec.Struct,
prompt_token_ids)
if output_token_ids is None:
return SequenceData(prompt_token_ids_arr)
return SequenceData(prompt_token_ids_arr,
_prompt_embeds=prompt_embeds)
output_token_ids_arr = array(VLLM_TOKEN_ID_ARRAY_TYPE,
output_token_ids)
return SequenceData(prompt_token_ids_arr,
_output_token_ids=output_token_ids_arr)
_output_token_ids=output_token_ids_arr,
_prompt_embeds=prompt_embeds)
def __post_init__(self) -> None:
assert self._prompt_token_ids.typecode == "l"
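
A rough usage sketch of the new keyword-only `prompt_embeds` argument (assuming `SequenceData` is importable from `vllm.sequence` as in this diff, and an arbitrary hidden size of 4096): one embedding row is supplied per prompt token.

```python
import torch
from vllm.sequence import SequenceData

# Hypothetical values: 3 prompt tokens, hidden size 4096.
prompt_token_ids = [101, 2023, 102]
prompt_embeds = torch.randn(len(prompt_token_ids), 4096)

# prompt_embeds is keyword-only, so it cannot be confused with the
# positional output_token_ids argument.
data = SequenceData.from_seqs(prompt_token_ids, prompt_embeds=prompt_embeds)
assert data.prompt_embeds is not None
assert data.prompt_embeds.shape == (3, 4096)
```
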
@@ -231,6 +239,8 @@ class SequenceData(msgspec.Struct,
self._prompt_token_ids_tuple: tuple[int, ...] = tuple(
self._prompt_token_ids)
self._update_cached_all_tokens()
if self._prompt_embeds is not None:
self._update_cached_all_token_embeds()
def _update_cached_all_tokens(self):
assert isinstance(self._prompt_token_ids, array)
@@ -238,6 +248,13 @@ class SequenceData(msgspec.Struct,
self._cached_all_token_ids: list[int] = list(self._prompt_token_ids +
self._output_token_ids)
def _update_cached_all_token_embeds(self):
assert isinstance(self._prompt_embeds, torch.Tensor)
self._cached_all_token_embeds: torch.Tensor = self._prompt_embeds
if self._output_embeds is not None:
self._cached_all_token_embeds = torch.cat(
(self._cached_all_token_embeds, self._output_embeds), dim=0)
@property
def cumulative_logprob(self) -> float:
return self._cumulative_logprob
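
The cached tensor is simply the prompt embeddings with any output embeddings stacked underneath along the token dimension. A minimal standalone sketch of that invariant (the hidden size of 8 is arbitrary):

```python
import torch

hidden_size = 8
prompt_embeds = torch.randn(3, hidden_size)   # 3 prompt tokens
output_embeds = torch.randn(2, hidden_size)   # 2 generated tokens

# Mirrors _update_cached_all_token_embeds: prompt rows first, then output rows.
cached_all_token_embeds = torch.cat((prompt_embeds, output_embeds), dim=0)
assert cached_all_token_embeds.shape == (5, hidden_size)
```
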
@@ -270,6 +287,15 @@ class SequenceData(msgspec.Struct,
new_output_token_ids)
self._update_cached_all_tokens()
@property
def output_embeds(self) -> Optional[torch.Tensor]:
return self._output_embeds
@output_embeds.setter
def output_embeds(self, new_output_token_embeds: torch.Tensor) -> None:
self._output_embeds = new_output_token_embeds
self._update_cached_all_token_embeds()
@property
def output_token_ids_array(self) -> array:
"""Return the prompt token ids in array type.
@@ -280,6 +306,15 @@ class SequenceData(msgspec.Struct,
assert isinstance(self._output_token_ids, array)
return self._output_token_ids
@property
def prompt_embeds(self) -> Optional[torch.Tensor]:
return self._prompt_embeds
@prompt_embeds.setter
def prompt_embeds(self, prompt_embeds: torch.Tensor) -> None:
self._prompt_embeds = prompt_embeds
self._update_cached_all_token_embeds()
@property
def mrope_position_delta(self) -> Optional[int]:
return self._mrope_position_delta
@@ -288,11 +323,28 @@ class SequenceData(msgspec.Struct,
def mrope_position_delta(self, new_mrope_position_delta):
self._mrope_position_delta = new_mrope_position_delta
def append_token_id(self, token_id: int, logprob: float) -> None:
def append_token_id(self,
token_id: int,
logprob: float,
token_embed: Optional[torch.Tensor] = None) -> None:
self._output_token_ids.append(token_id)
self._new_appended_tokens.append(token_id)
self._cached_all_token_ids.append(token_id)
self._cumulative_logprob += logprob
if token_embed is not None:
# Do not pass in with batch or sequence dimensions
assert token_embed.ndim == 1
token_embed = token_embed.detach().cpu().unsqueeze(0)
if self._output_embeds is None:
self._output_embeds = token_embed
else:
self._output_embeds = torch.cat(
(self._output_embeds, token_embed), dim=0)
assert self._cached_all_token_embeds is not None
self._cached_all_token_embeds = torch.cat(
(self._cached_all_token_embeds,
token_embed.to(device=self._cached_all_token_embeds.device)),
dim=0)
def get_len(self) -> int:
return len(self._output_token_ids) + len(self._prompt_token_ids)
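
A sketch of appending a sampled token together with its embedding (again assuming `vllm.sequence.SequenceData` and an arbitrary hidden size of 4096). The embedding must be passed as a 1-D tensor; `append_token_id` adds the sequence dimension itself via `unsqueeze(0)` and moves the tensor onto the CPU-side cache:

```python
import torch
from vllm.sequence import SequenceData

data = SequenceData.from_seqs([101, 2023, 102],
                              prompt_embeds=torch.randn(3, 4096))

new_token_id = 2003
new_token_embed = torch.randn(4096)  # 1-D: no batch or sequence dimension

data.append_token_id(new_token_id, logprob=-0.25, token_embed=new_token_embed)

# The cache now holds 3 prompt rows plus 1 output row.
assert data.get_token_embeddings().shape == (4, 4096)
assert data.output_embeds.shape == (1, 4096)
```
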
@@ -306,6 +358,9 @@ class SequenceData(msgspec.Struct,
def get_token_ids(self) -> list[int]:
return self._cached_all_token_ids
def get_token_embeddings(self) -> Optional[torch.Tensor]:
return self._cached_all_token_embeds
def get_prefix_token_ids(
self, num_tokens: int
) -> tuple[tuple[int, ...], Optional[tuple[int, ...]]]:
@@ -387,6 +442,8 @@ class SequenceData(msgspec.Struct,
def __repr__(self) -> str:
return (f"SequenceData("
f"prompt_token_ids={self._prompt_token_ids}, "
f"prompt_embeds.shape="
f"{getattr(self._prompt_embeds, 'shape', None)}, "
f"output_token_ids={self.output_token_ids}, "
f"cumulative_logprob={self.cumulative_logprob}, "
f"get_num_computed_tokens={self.get_num_computed_tokens()})")
@@ -425,7 +482,10 @@ class Sequence:
self.lora_request = lora_request
self.prompt_adapter_request = prompt_adapter_request
self.data = SequenceData.from_seqs(self.prompt_token_ids)
self.data = SequenceData.from_seqs(
self.prompt_token_ids,
prompt_embeds=self.inputs["prompt_embeds"]
if self.inputs["type"] == "embeds" else None)
self.output_logprobs: SampleLogprobs = []
self.output_text = ""
@@ -448,14 +508,20 @@ class Sequence:
@property
def prompt(self) -> Optional[str]:
if self.inputs["type"] == "embeds":
return None
return self.inputs.get("prompt")
@property
def prompt_token_ids(self) -> list[int]:
if self.inputs["type"] == "embeds":
return [0] * len(self.inputs["prompt_embeds"])
return self.inputs["prompt_token_ids"]
@property
def token_type_ids(self) -> list[int]:
if self.inputs["type"] == "embeds":
return []
return self.inputs.get("token_type_ids", [])
@property
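
For embeds-only inputs there is no prompt text and there are no real token IDs, so the properties above return `None`, a list of placeholder zeros sized to the embedding tensor, and an empty `token_type_ids` list, respectively. A tiny sketch of the placeholder behaviour:

```python
import torch

prompt_embeds = torch.randn(5, 4096)  # 5 "tokens" supplied as embeddings

# Mirrors Sequence.prompt_token_ids for "embeds"-type inputs: one dummy
# token id (0) per embedding row, so length-based bookkeeping still works.
placeholder_token_ids = [0] * len(prompt_embeds)
assert placeholder_token_ids == [0, 0, 0, 0, 0]
```
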
@@ -554,11 +620,14 @@ class Sequence:
"""Reset the sequence states for recomputation."""
self.data.reset_state_for_recompute()
def append_token_id(self, token_id: int, logprobs: dict[int,
Logprob]) -> None:
def append_token_id(self,
token_id: int,
logprobs: dict[int, Logprob],
token_embed: Optional[torch.Tensor] = None) -> None:
assert token_id in logprobs
self.output_logprobs.append(logprobs)
self.data.append_token_id(token_id, logprobs[token_id].logprob)
self.data.append_token_id(token_id, logprobs[token_id].logprob,
token_embed)
def get_len(self) -> int:
return self.data.get_len()
@@ -889,6 +958,10 @@ class SequenceGroup:
f"sampling_params={self.sampling_params}, "
f"num_seqs={len(self.seqs)})")
def uses_prompt_embeds(self) -> bool:
"""Returns True if the sequence group uses input embeds."""
return any(seq.data.prompt_embeds is not None for seq in self.seqs)
class SequenceGroupMetadataDelta(
msgspec.Struct,
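
A brief sketch of using `uses_prompt_embeds` as a batching predicate. The helper and the duck-typed stand-in below are hypothetical, introduced only to illustrate the method; nothing beyond `uses_prompt_embeds()` is assumed about `SequenceGroup`:

```python
class _FakeGroup:
    """Duck-typed stand-in for SequenceGroup, for illustration only."""
    def __init__(self, uses_embeds: bool):
        self._uses_embeds = uses_embeds

    def uses_prompt_embeds(self) -> bool:
        return self._uses_embeds

def partition_by_prompt_embeds(groups):
    # Hypothetical helper: keep embeds-backed and token-id-only groups apart
    # so each batch is homogeneous in how its prompts are represented.
    with_embeds = [g for g in groups if g.uses_prompt_embeds()]
    without_embeds = [g for g in groups if not g.uses_prompt_embeds()]
    return with_embeds, without_embeds

embeds_groups, token_groups = partition_by_prompt_embeds(
    [_FakeGroup(True), _FakeGroup(False), _FakeGroup(True)])
assert len(embeds_groups) == 2 and len(token_groups) == 1
```
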
@@ -1043,10 +1116,14 @@ class SequenceOutput(
parent_seq_id: int
output_token: int
logprobs: dict[int, Logprob]
output_embed: Optional[torch.Tensor] = None
def __repr__(self) -> str:
output_embed_shape = \
self.output_embed.shape if self.output_embed is not None else None
return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, "
f"output_token={self.output_token}, "
f"output_embed.shape={output_embed_shape}"
f"logprobs={self.logprobs})")
def __eq__(self, other: object) -> bool:
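
Finally, a sketch of constructing a `SequenceOutput` that carries the sampled token's embedding (assuming `SequenceOutput` and `Logprob` are importable from `vllm.sequence`, as in the rest of this diff):

```python
import torch
from vllm.sequence import Logprob, SequenceOutput

token_id = 2003
output = SequenceOutput(
    parent_seq_id=0,
    output_token=token_id,
    logprobs={token_id: Logprob(logprob=-0.25)},
    output_embed=torch.randn(4096),  # embedding of the sampled token
)

# repr now reports the embedding shape alongside the token and logprobs.
print(output)
```
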