[V0 Deprecation] Remove V0 logic from get_input_embeddings interface (#25242)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -46,7 +46,8 @@ from vllm.sequence import IntermediateTensors
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
+from .utils import (AutoWeightsLoader, init_vllm_registered_model,
+                    maybe_prefix, merge_multimodal_embeddings)
 from .vision import get_vision_encoder_info
 
 EOT = "<|endofturn|>"
@@ -740,33 +741,20 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         self,
         input_ids: torch.Tensor,
         multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
-        **kwargs,
     ) -> torch.Tensor:
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
-        if (kwargs.get("pixel_values_images") is not None
-                or kwargs.get("pixel_values_videos")
-                is not None):  # v0 compatibility
-            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
-        if multimodal_embeddings is not None:
-            multimodal_embeddings = torch.cat(multimodal_embeddings, dim=0)
-            _mask_image = input_ids == self.config.image_token_id
-            _mask_video = input_ids == self.config.video_token_id
-            assert _mask_image.sum() + _mask_video.sum() == len(
-                multimodal_embeddings)
-
-            if multimodal_embeddings.dtype != inputs_embeds.dtype:
-                multimodal_embeddings = multimodal_embeddings.to(
-                    dtype=inputs_embeds.dtype)
-            if multimodal_embeddings.device != inputs_embeds.device:
-                multimodal_embeddings = multimodal_embeddings.to(
-                    device=inputs_embeds.device)
-
-            if _mask_image.sum() > 0:
-                inputs_embeds[
-                    _mask_image] = multimodal_embeddings[:sum(_mask_image)]
-            if _mask_video.sum() > 0:
-                inputs_embeds[_mask_video] = multimodal_embeddings[
-                    -sum(_mask_video):]
+        if multimodal_embeddings is not None \
+            and len(multimodal_embeddings) != 0:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                multimodal_embeddings,
+                placeholder_token_id=[
+                    self.config.image_token_id,
+                    self.config.video_token_id,
+                ],
+            )
         return inputs_embeds
 
     def forward(
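The rewritten method delegates placeholder substitution to `merge_multimodal_embeddings` from `.utils`, which replaces the hand-rolled image/video masking, dtype/device casts, and slicing that the v0 path needed. As a rough illustration (a minimal sketch assuming 1-D `input_ids` and a list of per-item embedding tensors, not vLLM's actual implementation), the merge amounts to scattering the flattened multimodal rows into the placeholder token positions:

    import torch

    def merge_multimodal_embeddings_sketch(
            input_ids: torch.Tensor,  # (num_tokens,)
            inputs_embeds: torch.Tensor,  # (num_tokens, hidden_size)
            multimodal_embeddings: list[torch.Tensor],  # (n_i, hidden) each
            placeholder_token_id: list[int],
    ) -> torch.Tensor:
        # Flatten per-item embeddings into one (sum(n_i), hidden_size) tensor.
        flattened = torch.cat(list(multimodal_embeddings), dim=0)
        # Positions holding any placeholder token (image or video).
        is_placeholder = torch.isin(
            input_ids,
            torch.tensor(placeholder_token_id, device=input_ids.device))
        # Exactly one placeholder slot is expected per embedding row.
        assert int(is_placeholder.sum()) == flattened.shape[0]
        # Scatter the rows into the slots in order of appearance, matching
        # the dtype/device of the text embeddings.
        inputs_embeds[is_placeholder] = flattened.to(
            dtype=inputs_embeds.dtype, device=inputs_embeds.device)
        return inputs_embeds

    # Toy usage: two placeholders with id 32000 receive the two image rows.
    ids = torch.tensor([1, 32000, 32000, 2])
    out = merge_multimodal_embeddings_sketch(
        ids, torch.zeros(4, 8), [torch.ones(2, 8)], [32000])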
@@ -783,8 +771,9 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         # NOTE: In v1, inputs_embeds is always generated at model runner, this
         # condition is for v0 compatibility.
         elif inputs_embeds is None:
-            inputs_embeds = self.get_input_embeddings(input_ids=input_ids,
-                                                      **kwargs)
+            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
+            inputs_embeds = self.get_input_embeddings(input_ids,
+                                                      multimodal_embeddings)
             input_ids = None
         hidden_states = self.language_model.model(input_ids,
                                                   positions,
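For context on the `forward` change above: the v0-compatibility branch now follows the same two-step contract as the v1 model runner, computing multimodal embeddings first and then handing them to `get_input_embeddings`. A self-contained toy showing that call order (`MiniModel` and its constants are stand-ins, not vLLM code):

    import torch

    class MiniModel:
        IMAGE_TOKEN_ID = 32000  # assumed placeholder id, illustration only
        HIDDEN = 8

        def get_multimodal_embeddings(self, **kwargs):
            # Stand-in for running the vision tower on pixel_values_* inputs.
            return [torch.ones(2, self.HIDDEN)]

        def get_input_embeddings(self, input_ids, multimodal_embeddings=None):
            inputs_embeds = torch.zeros(input_ids.shape[0], self.HIDDEN)
            if multimodal_embeddings:
                flat = torch.cat(list(multimodal_embeddings), dim=0)
                inputs_embeds[input_ids == self.IMAGE_TOKEN_ID] = flat
            return inputs_embeds

    model = MiniModel()
    input_ids = torch.tensor([1, 32000, 32000, 2])
    # New contract: the caller computes the embeddings, then merges them.
    mm = model.get_multimodal_embeddings(pixel_values_images=None)
    inputs_embeds = model.get_input_embeddings(input_ids, mm)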