[V1] Refactor model executable interface for multimodal models (#10570)

Signed-off-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Roger Wang
2024-11-26 12:46:11 -08:00
committed by GitHub
parent 7576cd38df
commit 2f0a0a17a4
18 changed files with 568 additions and 293 deletions

View File

@@ -29,6 +29,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges,
repeat_and_pad_placeholder_tokens)
@@ -38,7 +39,7 @@ from vllm.utils import print_warning_once
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
maybe_prefix, merge_multimodal_embeddings)
# These configs are not part of the model config but the preprocessor
# and processor files, so we hardcode them in the model file for now.
@@ -987,6 +988,29 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
data=self._validate_pixel_values(pixel_values),
)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
assert self.model.vqmodel is not None
image_tokens = self.model.get_image_tokens(image_input["data"].to(
self.config.torch_dtype))
vision_embeddings = self.model.get_input_embeddings(image_tokens)
return vision_embeddings
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[NestedTensors] = None,
) -> torch.Tensor:
inputs_embeds = self.model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.model.vocabulary_mapping.image_token_id)
return inputs_embeds
def forward(
self,
input_ids: torch.Tensor,
@@ -994,27 +1018,27 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, IntermediateTensors]:
if intermediate_tensors is not None:
inputs_embeds = None
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
vision_embeddings)
input_ids = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
assert self.model.vqmodel is not None
image_tokens = self.model.get_image_tokens(
image_input["data"].to(self.config.torch_dtype))
image_token_id = self.model.vocabulary_mapping.image_token_id
special_image_mask = input_ids == image_token_id
image_tokens = image_tokens.to(input_ids.device,
input_ids.dtype)
input_ids = input_ids.masked_scatter(special_image_mask,
image_tokens)
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors)
hidden_states = self.model(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
return hidden_states
def compute_logits(