[V0 Deprecation] Remove V0 logic from get_input_embeddings interface (#25242)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

Author:    Cyrus Leung
Date:      2025-09-19 19:10:52 +08:00
Committed: GitHub
Parent:    a3d087adec
Commit:    5089fd749c

4 changed files with 21 additions and 83 deletions


@@ -46,7 +46,8 @@ from vllm.sequence import IntermediateTensors
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
+from .utils import (AutoWeightsLoader, init_vllm_registered_model,
+                    maybe_prefix, merge_multimodal_embeddings)
 from .vision import get_vision_encoder_info

 EOT = "<|endofturn|>"
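The widened import brings in the shared merge_multimodal_embeddings helper that replaces the hand-rolled masking in the next hunk. As a rough sketch of the behaviour this change relies on (illustrative only: the name merge_multimodal_embeddings_sketch and the body below are not vLLM code, and the sketch assumes the embeddings already arrive as one flat tensor, whereas the removed code shows they may come in as a tuple that had to be concatenated first):

from typing import Union

import torch


def merge_multimodal_embeddings_sketch(
        input_ids: torch.Tensor,              # (num_tokens, )
        inputs_embeds: torch.Tensor,          # (num_tokens, hidden_size)
        multimodal_embeddings: torch.Tensor,  # (num_mm_tokens, hidden_size)
        placeholder_token_id: Union[int, list[int]],
) -> torch.Tensor:
    # Mark every position holding an image/video placeholder token.
    if isinstance(placeholder_token_id, int):
        placeholder_token_id = [placeholder_token_id]
    is_placeholder = torch.isin(
        input_ids,
        torch.tensor(placeholder_token_id, device=input_ids.device))
    # One vision embedding per placeholder token (mirrors the removed assert).
    assert int(is_placeholder.sum()) == multimodal_embeddings.shape[0]
    # Scatter the vision embeddings into those positions in order,
    # matching the text embeddings' dtype and device.
    inputs_embeds[is_placeholder] = multimodal_embeddings.to(
        device=inputs_embeds.device, dtype=inputs_embeds.dtype)
    return inputs_embeds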
@@ -740,33 +741,20 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         self,
         input_ids: torch.Tensor,
         multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
-        **kwargs,
     ) -> torch.Tensor:
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
-        if (kwargs.get("pixel_values_images") is not None
-                or kwargs.get("pixel_values_videos")
-                is not None):  # v0 compatibility
-            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
-        if multimodal_embeddings is not None:
-            multimodal_embeddings = torch.cat(multimodal_embeddings, dim=0)
-            _mask_image = input_ids == self.config.image_token_id
-            _mask_video = input_ids == self.config.video_token_id
-            assert _mask_image.sum() + _mask_video.sum() == len(
-                multimodal_embeddings)
-            if multimodal_embeddings.dtype != inputs_embeds.dtype:
-                multimodal_embeddings = multimodal_embeddings.to(
-                    dtype=inputs_embeds.dtype)
-            if multimodal_embeddings.device != inputs_embeds.device:
-                multimodal_embeddings = multimodal_embeddings.to(
-                    device=inputs_embeds.device)
-            if _mask_image.sum() > 0:
-                inputs_embeds[
-                    _mask_image] = multimodal_embeddings[:sum(_mask_image)]
-            if _mask_video.sum() > 0:
-                inputs_embeds[_mask_video] = multimodal_embeddings[
-                    -sum(_mask_video):]
+        if multimodal_embeddings is not None \
+            and len(multimodal_embeddings) != 0:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids,
+                inputs_embeds,
+                multimodal_embeddings,
+                placeholder_token_id=[
+                    self.config.image_token_id,
+                    self.config.video_token_id,
+                ],
+            )
         return inputs_embeds

     def forward(
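With **kwargs and the pixel-values fallback removed, the vision embeddings must be produced before this method is called. A sketch of the intended calling convention, assuming model is an HCXVisionForCausalLM instance and mm_kwargs holds the processed pixel_values_images / pixel_values_videos batch (both names are illustrative):

# Old V0-style call: raw pixel values smuggled in through **kwargs.
# inputs_embeds = model.get_input_embeddings(
#     input_ids, pixel_values_images=pixel_values_images)

# New call: vision embeddings are computed explicitly, then merged.
multimodal_embeddings = model.get_multimodal_embeddings(**mm_kwargs)
inputs_embeds = model.get_input_embeddings(input_ids, multimodal_embeddings)

The forward() hunk below applies exactly this two-step flow to the remaining V0-compatibility branch.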
@@ -783,8 +771,9 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         # NOTE: In v1, inputs_embeds is always generated at model runner, this
         # condition is for v0 compatibility.
         elif inputs_embeds is None:
-            inputs_embeds = self.get_input_embeddings(input_ids=input_ids,
-                                                      **kwargs)
+            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
+            inputs_embeds = self.get_input_embeddings(input_ids,
+                                                      multimodal_embeddings)
             input_ids = None
         hidden_states = self.language_model.model(input_ids,
                                                   positions,
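As the NOTE above indicates, this elif exists only for V0 compatibility: under the V1 engine the model runner presumably builds inputs_embeds itself (via the same get_multimodal_embeddings / get_input_embeddings two-step shown earlier) and passes it into forward directly, which is why get_input_embeddings no longer needs to accept raw multimodal kwargs at all.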