[Model] Add Support for Multimodal Granite Models (#10291)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Author: Alex Brooks
Date: 2024-11-21 03:46:20 -07:00 (committed by GitHub)
Parent: f0e0238016
Commit: 1cfde82ffd
6 changed files with 191 additions and 35 deletions


@@ -21,7 +21,8 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal.utils import (cached_get_tokenizer,
                                    consecutive_placeholder_ranges,
-                                   repeat_and_pad_placeholder_tokens)
+                                   repeat_and_pad_placeholder_tokens,
+                                   resolve_visual_encoder_outputs)
 from vllm.sequence import SequenceData
 
 from .utils import get_vit_attn_backend
@@ -389,12 +390,20 @@ class CLIPEncoder(nn.Module):
             for layer_idx in range(num_hidden_layers)
         ])
 
-    def forward(self, inputs_embeds: torch.Tensor):
+    def forward(
+        self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool
+    ) -> Union[torch.Tensor, list[torch.Tensor]]:
+        hidden_states_pool = []
         hidden_states = inputs_embeds
+
         for encoder_layer in self.layers:
             hidden_states = encoder_layer(hidden_states)
+            if return_all_hidden_states:
+                hidden_states_pool.append(hidden_states)
+
+        # If we have multiple feature sample layers, we return all hidden
+        # states in order and grab the ones we need by index.
+        if return_all_hidden_states:
+            return hidden_states_pool
         return hidden_states
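
For intuition, here is a minimal sketch (not part of this diff) of how a caller can pull specific layers out of the pooled hidden states that CLIPEncoder.forward now returns when return_all_hidden_states is True. The tensor shapes and layer indices below are placeholders, not values taken from the Granite models.

```python
import torch

# Stand-in for the list returned by CLIPEncoder.forward when
# return_all_hidden_states=True: one tensor per encoder layer, in order.
hidden_states_pool = [torch.randn(1, 577, 1024) for _ in range(24)]

# Hypothetical feature_sample_layers; negative indices count back from the
# final layer, mirroring Python list indexing.
feature_sample_layers = [-24, -20, -12, -1]

# Grab the requested layers by index and concatenate on the hidden dimension.
features = torch.cat(
    [hidden_states_pool[idx] for idx in feature_sample_layers], dim=-1)
print(features.shape)  # torch.Size([1, 577, 4096])
```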
@@ -419,6 +428,7 @@ class CLIPVisionTransformer(nn.Module):
         # NOTE: This typo of "layrnorm" is not fixed on purpose to match
         # the original transformers code and name of the model weights.
         self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
         self.encoder = CLIPEncoder(
             config=config,
             quant_config=quant_config,
@@ -446,16 +456,26 @@ class CLIPVisionTransformer(nn.Module):
def forward(
self,
pixel_values: torch.Tensor,
feature_sample_layers: Optional[list[int]] = None,
) -> torch.Tensor:
hidden_states = self.embeddings(pixel_values)
hidden_states = self.pre_layrnorm(hidden_states)
hidden_states = self.encoder(inputs_embeds=hidden_states)
if self.post_layernorm is None:
return hidden_states
return_all_hidden_states = feature_sample_layers is not None
return self.post_layernorm(hidden_states)
# Produces either the last layer output or all of the hidden states,
# depending on if we have feature_sample_layers or not
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
return_all_hidden_states=return_all_hidden_states)
# Handle post-norm (if applicable) and stacks feature layers if needed
encoder_outputs = resolve_visual_encoder_outputs(
encoder_outputs, feature_sample_layers, self.post_layernorm,
self.config.num_hidden_layers)
return encoder_outputs
class CLIPVisionModel(nn.Module):
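
resolve_visual_encoder_outputs itself lives in vllm.multimodal.utils and is not part of this hunk. As a rough, non-authoritative sketch of the behavior implied by the call site above (apply the post-norm on the single-tensor path, otherwise select and stack the requested layers), it might look like the following; the real helper also has to reconcile negative layer indices with truncated vision towers and decide how the post-norm interacts with the sampled layers.

```python
from typing import Optional, Union

import torch
import torch.nn as nn


def resolve_visual_encoder_outputs_sketch(
    encoder_outputs: Union[torch.Tensor, list[torch.Tensor]],
    feature_sample_layers: Optional[list[int]],
    post_layer_norm: Optional[nn.LayerNorm],
    max_possible_layers: int,
) -> torch.Tensor:
    """Illustrative stand-in; not the actual vLLM implementation."""
    if feature_sample_layers is None:
        # Single-tensor path: mirror the old behavior, i.e. apply the
        # post-norm to the last hidden state when the model has one.
        if post_layer_norm is not None:
            return post_layer_norm(encoder_outputs)
        return encoder_outputs

    # Multi-layer path: negative indices are relative to the full encoder
    # depth, so offset them if fewer layers were actually loaded.
    offset = max_possible_layers - len(encoder_outputs)
    hs_pool = [
        encoder_outputs[idx] if idx >= 0 else encoder_outputs[idx + offset]
        for idx in feature_sample_layers
    ]
    # Stack the selected layers along the hidden dimension.
    return torch.cat(hs_pool, dim=-1)
```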
@@ -478,11 +498,14 @@ class CLIPVisionModel(nn.Module):
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
             require_post_norm=require_post_norm,
-            prefix=f"{prefix}.vision_model",
-        )
+            prefix=f"{prefix}.vision_model")
 
-    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
-        return self.vision_model(pixel_values)
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        feature_sample_layers: Optional[list[int]] = None,
+    ) -> torch.Tensor:
+        return self.vision_model(pixel_values, feature_sample_layers)
 
     @property
     def device(self):
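
A hedged usage sketch of the extended CLIPVisionModel.forward signature, assuming the class shown here is the one in vllm.model_executor.models.clip; the config values, input shape, and sampled layer indices are illustrative placeholders, not settings pulled from a Granite checkpoint.

```python
import torch
from transformers import CLIPVisionConfig

from vllm.model_executor.models.clip import CLIPVisionModel

# Placeholder config roughly shaped like a CLIP ViT-L/14 vision tower.
config = CLIPVisionConfig(hidden_size=1024, num_hidden_layers=24,
                          num_attention_heads=16, image_size=336,
                          patch_size=14)
vision_model = CLIPVisionModel(config)

pixel_values = torch.randn(1, 3, 336, 336)

# Default path: only the final hidden state (post-norm applied if present).
last_hidden = vision_model(pixel_values)

# Multi-layer path: features sampled from several encoder layers, e.g. for a
# projector that concatenates them; the indices here are illustrative only.
multi_layer = vision_model(pixel_values,
                           feature_sample_layers=[-24, -20, -12, -1])
```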