[V1] Support VLMs with fine-grained scheduling (#9871)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Woosuk Kwon
2024-11-12 20:53:13 -08:00
committed by GitHub
parent 0d4ea3fb5c
commit bbd3e86926
12 changed files with 542 additions and 96 deletions

View File

@@ -216,9 +216,11 @@ class GPT2Model(nn.Module):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor],
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
inputs_embeds = self.wte(input_ids)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
else:
@@ -263,6 +265,9 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Map token IDs to embeddings using the transformer's token-embedding table (wte)."""
    token_embeddings = self.transformer.wte(input_ids)
    return token_embeddings
def forward(
self,
input_ids: torch.Tensor,
@@ -270,9 +275,11 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states
def compute_logits(