[Model][VLM] Initialize support for Mono-InternVL model (#9528)
This commit is contained in:
@@ -21,7 +21,8 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
|
||||
token_inputs)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.models.intern_vit import InternVisionModel
|
||||
from vllm.model_executor.models.intern_vit import (InternVisionModel,
|
||||
InternVisionPatchModel)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.base import MultiModalInputs
|
||||
@@ -427,13 +428,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self.downsample_ratio = config.downsample_ratio
|
||||
self.ps_version = config.ps_version
|
||||
|
||||
vision_feature_layer = self.select_layer
|
||||
if vision_feature_layer < 0:
|
||||
num_hidden_layers = config.vision_config.num_hidden_layers \
|
||||
+ vision_feature_layer + 1
|
||||
else:
|
||||
num_hidden_layers = vision_feature_layer + 1
|
||||
self.vision_model = self._init_vision_model(config, num_hidden_layers)
|
||||
self.llm_arch_name = config.text_config.architectures[0]
|
||||
self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'
|
||||
self.vision_model = self._init_vision_model(config, self.is_mono)
|
||||
|
||||
self.language_model = init_vllm_registered_model(
|
||||
config.text_config, cache_config, quant_config)
|
||||
@@ -451,10 +448,19 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
return Sampler()
|
||||
|
||||
def _init_vision_model(self, config: PretrainedConfig,
|
||||
num_hidden_layers: int):
|
||||
return InternVisionModel(config.vision_config,
|
||||
num_hidden_layers_override=num_hidden_layers)
|
||||
def _init_vision_model(self, config: PretrainedConfig, is_mono: bool):
|
||||
if not is_mono:
|
||||
vision_feature_layer = self.select_layer
|
||||
if vision_feature_layer < 0:
|
||||
num_hidden_layers = config.vision_config.num_hidden_layers \
|
||||
+ vision_feature_layer + 1
|
||||
else:
|
||||
num_hidden_layers = vision_feature_layer + 1
|
||||
return InternVisionModel(
|
||||
config.vision_config,
|
||||
num_hidden_layers_override=num_hidden_layers)
|
||||
else:
|
||||
return InternVisionPatchModel(config.vision_config)
|
||||
|
||||
def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
|
||||
vit_hidden_size = config.vision_config.hidden_size
|
||||
@@ -562,6 +568,14 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
return image_embeds
|
||||
|
||||
def _get_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
if self.is_mono:
|
||||
visual_token_mask = (
|
||||
input_ids == self.img_context_token_id).reshape(-1, 1)
|
||||
else:
|
||||
visual_token_mask = None
|
||||
return visual_token_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -574,6 +588,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
if intermediate_tensors is not None:
|
||||
input_ids = None
|
||||
inputs_embeds = None
|
||||
visual_token_mask = None
|
||||
else:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is not None:
|
||||
@@ -583,16 +598,24 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, vision_embeddings,
|
||||
self.img_context_token_id)
|
||||
visual_token_mask = self._get_visual_token_mask(input_ids)
|
||||
input_ids = None
|
||||
else:
|
||||
inputs_embeds = None
|
||||
visual_token_mask = None
|
||||
|
||||
hidden_states = self.language_model.model(input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
intermediate_tensors,
|
||||
inputs_embeds=inputs_embeds)
|
||||
forward_kwargs = {
|
||||
"input_ids": input_ids,
|
||||
"positions": positions,
|
||||
"kv_caches": kv_caches,
|
||||
"attn_metadata": attn_metadata,
|
||||
"intermediate_tensors": intermediate_tensors,
|
||||
"inputs_embeds": inputs_embeds,
|
||||
}
|
||||
if self.is_mono:
|
||||
forward_kwargs.update({"visual_token_mask": visual_token_mask})
|
||||
|
||||
hidden_states = self.language_model.model(**forward_kwargs)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
|
||||
Reference in New Issue
Block a user