[Bugfix] Merge MM embeddings by index instead of token IDs (#16229)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
@@ -25,7 +25,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalKwargsItems, NestedTensors)
|
||||
MultiModalKwargsItems)
|
||||
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
@@ -37,8 +37,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
init_vllm_registered_model, maybe_prefix,
|
||||
merge_multimodal_embeddings)
|
||||
init_vllm_registered_model, maybe_prefix)
|
||||
from .vision import run_dp_sharded_vision_model
|
||||
|
||||
|
||||
@@ -996,10 +995,13 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
1 else cur_feature[0])
|
||||
return merged_image_features
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
|
||||
def get_language_model(self) -> torch.nn.Module:
|
||||
return self.language_model
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs) -> MultiModalEmbeddings:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return None
|
||||
return []
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
return vision_embeddings
|
||||
|
||||
@@ -1007,24 +1009,21 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
*,
|
||||
is_multimodal: Optional[torch.Tensor] = None,
|
||||
# Multi-modal token ID may exceed vocab size
|
||||
handle_oov_mm_token: bool = True,
|
||||
) -> torch.Tensor:
|
||||
if multimodal_embeddings is None:
|
||||
inputs_embeds = self.language_model.model.get_input_embeddings(
|
||||
input_ids)
|
||||
else:
|
||||
is_text = input_ids != self.config.image_token_id
|
||||
text_ids = input_ids[is_text]
|
||||
text_embeds = self.language_model.model.get_input_embeddings(
|
||||
text_ids)
|
||||
inputs_embeds = torch.empty(input_ids.shape[0],
|
||||
text_embeds.shape[-1],
|
||||
dtype=text_embeds.dtype,
|
||||
device=text_embeds.device)
|
||||
inputs_embeds[is_text] = text_embeds
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, multimodal_embeddings,
|
||||
self.config.image_token_id)
|
||||
return inputs_embeds
|
||||
# This is to satisfy the type checker for each overload
|
||||
if multimodal_embeddings is None or is_multimodal is None:
|
||||
return super().get_input_embeddings(input_ids)
|
||||
|
||||
return super().get_input_embeddings(
|
||||
input_ids,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
is_multimodal=is_multimodal,
|
||||
handle_oov_mm_token=handle_oov_mm_token,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -1038,10 +1037,11 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
inputs_embeds = None
|
||||
elif inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
# always pass the input via `inputs_embeds`
|
||||
# to make sure the computation graph is consistent
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
inputs_embeds = self.get_input_embeddings(
|
||||
input_ids,
|
||||
vision_embeddings,
|
||||
is_multimodal=input_ids == self.config.image_token_id,
|
||||
)
|
||||
input_ids = None
|
||||
|
||||
hidden_states = self.language_model(input_ids,
|
||||
|
||||
Reference in New Issue
Block a user