[V1] Refactor model executable interface for multimodal models (#10570)
Signed-off-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
@@ -33,7 +33,8 @@ from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import MultiModalData, MultiModalKwargs
|
||||
from vllm.multimodal.inputs import (MultiModalData, MultiModalKwargs,
|
||||
NestedTensors)
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||
SequenceData)
|
||||
@@ -545,6 +546,30 @@ class ChatGLMModel(nn.Module):
|
||||
""")
|
||||
return GLMImagePixelInputs(pixel_values=pixel_values)
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input["pixel_values"] is None:
|
||||
return None
|
||||
pixel_values = image_input["pixel_values"].to(
|
||||
dtype=self.config.torch_dtype)
|
||||
vision_embeddings = self.vision(pixel_values)
|
||||
return vision_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[NestedTensors] = None,
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.embedding(input_ids)
|
||||
if multimodal_embeddings is not None:
|
||||
inputs_embeds = merge_glm_vision_embeddings(
|
||||
input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
vision_embeddings=multimodal_embeddings,
|
||||
boi_token_id=self.config.boi_token_id,
|
||||
eoi_token_id=self.config.eoi_token_id)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -552,26 +577,17 @@ class ChatGLMModel(nn.Module):
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs: object,
|
||||
) -> torch.Tensor:
|
||||
if intermediate_tensors is None:
|
||||
inputs_embeds = self.embedding(input_ids)
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
|
||||
if image_input["pixel_values"] is not None:
|
||||
pixel_values = image_input["pixel_values"].to(
|
||||
dtype=inputs_embeds.dtype)
|
||||
image_embeds = self.vision(pixel_values)
|
||||
|
||||
boi_token_id = self.config.boi_token_id
|
||||
eoi_token_id = self.config.eoi_token_id
|
||||
|
||||
inputs_embeds = merge_glm_vision_embeddings(
|
||||
input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
vision_embeddings=image_embeds,
|
||||
boi_token_id=boi_token_id,
|
||||
eoi_token_id=eoi_token_id)
|
||||
# NOTE: In v1, inputs_embeds is always generated at model runner, this
|
||||
# condition is for v0 compatibility.
|
||||
if intermediate_tensors is None and inputs_embeds is None:
|
||||
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
|
||||
inputs_embeds = self.get_input_embeddings(input_ids,
|
||||
vision_embeddings)
|
||||
input_ids = None
|
||||
else:
|
||||
inputs_embeds = intermediate_tensors["hidden_states"]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user