[V1] Refactor model executable interface for multimodal models (#10570)
Signed-off-by: Roger Wang <ywang@roblox.com>
@@ -19,6 +19,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import NestedTensors
 from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.utils import is_list_of

@@ -565,6 +566,30 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
             for i, patch_features_batch in enumerate(patch_embeddings)
         ]

+    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
+        image_input = self._parse_and_validate_image_input(**kwargs)
+        if image_input is None:
+            return None
+        vision_embeddings = self._process_image_input(image_input)
+        return vision_embeddings
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[NestedTensors] = None,
+    ) -> torch.Tensor:
+
+        if multimodal_embeddings is None:
+            return self.language_model.get_input_embeddings(input_ids)
+
+        inputs_embeds = embed_multimodal(
+            input_ids,
+            self.config.image_token_index,
+            self.language_model.model.get_input_embeddings,
+            multimodal_embeddings,
+        )
+        return inputs_embeds
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -572,6 +597,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         """Run forward pass for LlaVA-NeXT.
@@ -620,24 +646,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
         """
         if intermediate_tensors is not None:
             inputs_embeds = None
-        else:
-            image_input = self._parse_and_validate_image_input(**kwargs)
-
-            if image_input is not None:
-                inputs_embeds = embed_multimodal(
-                    input_ids,
-                    self.config.image_token_index,
-                    self.language_model.model.get_input_embeddings,
-                    lambda _: self._process_image_input(image_input),
-                )
-            else:
-                inputs_embeds = self.language_model.model.get_input_embeddings(
-                    input_ids)
-
-            # always pass the input via `inputs_embeds`
-            # to make sure the computation graph is consistent
-            # for `torch.compile` integration
-            input_ids = None
+
+        # NOTE: In v1, inputs_embeds is always generated at model runner, this
+        # condition is for v0 compatibility.
+        elif inputs_embeds is None:
+            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
+            inputs_embeds = self.get_input_embeddings(input_ids,
+                                                      vision_embeddings)
+            input_ids = None

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
@@ -645,7 +661,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                                                   attn_metadata,
                                                   intermediate_tensors,
                                                   inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
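
To make the intent of the refactor concrete, below is a minimal, self-contained sketch of the call pattern the new interface enables: the multimodal embeddings and the merged input embeddings are produced outside the model's forward pass, and the language model is then always driven through inputs_embeds. The class name, placeholder token id, shapes, and helper logic are illustrative assumptions for this sketch, not vLLM code.

from typing import Optional

import torch
import torch.nn as nn


class ToyMultiModalModel(nn.Module):
    """Illustrative stand-in for a vLLM multimodal model (not real vLLM code)."""

    IMAGE_TOKEN_ID = 32000  # hypothetical placeholder token id

    def __init__(self, vocab_size: int = 32064, hidden_size: int = 16):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.backbone = nn.Linear(hidden_size, hidden_size)

    def get_multimodal_embeddings(self, pixel_values=None) -> Optional[torch.Tensor]:
        # Stand-in for the vision encoder + multimodal projector.
        if pixel_values is None:
            return None
        return torch.randn(pixel_values.shape[0], self.embed_tokens.embedding_dim)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        inputs_embeds = self.embed_tokens(input_ids)
        if multimodal_embeddings is not None:
            # Write the image embeddings into the placeholder token positions.
            mask = input_ids == self.IMAGE_TOKEN_ID
            inputs_embeds[mask] = multimodal_embeddings.to(inputs_embeds.dtype)
        return inputs_embeds

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        return self.backbone(inputs_embeds)


# Runner-side call pattern: embeddings are always built before forward(), so the
# language model sees a single, consistent inputs_embeds path (the property the
# `torch.compile` comment in the removed v0 code was protecting by hand).
model = ToyMultiModalModel()
input_ids = torch.tensor([1, 5, model.IMAGE_TOKEN_ID, model.IMAGE_TOKEN_ID, 9])
pixel_values = torch.randn(2, 3)  # two image patches in this toy setup

mm_embeds = model.get_multimodal_embeddings(pixel_values=pixel_values)
inputs_embeds = model.get_input_embeddings(input_ids, mm_embeds)
hidden_states = model(inputs_embeds=inputs_embeds)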
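
get_multimodal_embeddings is typed as Optional[NestedTensors], so a model may return one embedding tensor per image, each with a different number of patch tokens. Below is a hedged sketch of flattening such a nested result before it is scattered into the text embeddings; the SimpleNestedTensors alias and the helper function are assumptions for illustration, not vLLM APIs.

from typing import List, Union

import torch

# Simplified stand-in for vllm.multimodal.inputs.NestedTensors (an assumption;
# the real type is more general).
SimpleNestedTensors = Union[torch.Tensor, List[torch.Tensor]]


def flatten_multimodal_embeddings(mm: SimpleNestedTensors) -> torch.Tensor:
    # Concatenate per-image embeddings into one (num_mm_tokens, hidden_size)
    # tensor so they can be written into the placeholder positions in order.
    if isinstance(mm, torch.Tensor):
        return mm.flatten(0, -2) if mm.dim() > 2 else mm
    return torch.cat([t.flatten(0, -2) if t.dim() > 2 else t for t in mm], dim=0)


# e.g. two images contributing different numbers of patch embeddings
per_image = [torch.randn(576, 16), torch.randn(1176, 16)]
flat = flatten_multimodal_embeddings(per_image)
assert flat.shape == (576 + 1176, 16)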