[Model][VLM] Initialize support for Mono-InternVL model (#9528)

This commit is contained in:
Isotr0py
2024-10-23 00:01:46 +08:00
committed by GitHub
parent 9dbcce84a7
commit bb392ea2d2
6 changed files with 253 additions and 27 deletions

View File

@@ -21,7 +21,8 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.models.intern_vit import InternVisionModel
from vllm.model_executor.models.intern_vit import (InternVisionModel,
InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs
@@ -427,13 +428,9 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self.downsample_ratio = config.downsample_ratio
self.ps_version = config.ps_version
vision_feature_layer = self.select_layer
if vision_feature_layer < 0:
num_hidden_layers = config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
self.vision_model = self._init_vision_model(config, num_hidden_layers)
self.llm_arch_name = config.text_config.architectures[0]
self.is_mono = self.llm_arch_name == 'InternLM2VEForCausalLM'
self.vision_model = self._init_vision_model(config, self.is_mono)
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)
@@ -451,10 +448,19 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
return Sampler()
def _init_vision_model(self, config: PretrainedConfig,
num_hidden_layers: int):
return InternVisionModel(config.vision_config,
num_hidden_layers_override=num_hidden_layers)
def _init_vision_model(self, config: PretrainedConfig, is_mono: bool):
if not is_mono:
vision_feature_layer = self.select_layer
if vision_feature_layer < 0:
num_hidden_layers = config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
return InternVisionModel(
config.vision_config,
num_hidden_layers_override=num_hidden_layers)
else:
return InternVisionPatchModel(config.vision_config)
def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
vit_hidden_size = config.vision_config.hidden_size
@@ -562,6 +568,14 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
return image_embeds
def _get_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
if self.is_mono:
visual_token_mask = (
input_ids == self.img_context_token_id).reshape(-1, 1)
else:
visual_token_mask = None
return visual_token_mask
def forward(
self,
input_ids: torch.Tensor,
@@ -574,6 +588,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
visual_token_mask = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
@@ -583,16 +598,24 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.img_context_token_id)
visual_token_mask = self._get_visual_token_mask(input_ids)
input_ids = None
else:
inputs_embeds = None
visual_token_mask = None
hidden_states = self.language_model.model(input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors,
inputs_embeds=inputs_embeds)
forward_kwargs = {
"input_ids": input_ids,
"positions": positions,
"kv_caches": kv_caches,
"attn_metadata": attn_metadata,
"intermediate_tensors": intermediate_tensors,
"inputs_embeds": inputs_embeds,
}
if self.is_mono:
forward_kwargs.update({"visual_token_mask": visual_token_mask})
hidden_states = self.language_model.model(**forward_kwargs)
return hidden_states
def compute_logits(