[Model] Add Support for Multimodal Granite Models (#10291)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Author: Alex Brooks
Date: 2024-11-21 03:46:20 -07:00
Commit: 1cfde82ffd (parent: f0e0238016)
6 changed files with 191 additions and 35 deletions


@@ -25,7 +25,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal.utils import (cached_get_tokenizer,
                                    consecutive_placeholder_ranges,
-                                   repeat_and_pad_placeholder_tokens)
+                                   repeat_and_pad_placeholder_tokens,
+                                   resolve_visual_encoder_outputs)
 from vllm.sequence import SequenceData
 from .utils import get_vit_attn_backend
@@ -450,11 +451,19 @@ class SiglipEncoder(nn.Module):
     def forward(
         self,
         inputs_embeds: torch.Tensor,
-    ) -> torch.Tensor:
+        return_all_hidden_states: bool,
+    ) -> Union[torch.Tensor, list[torch.Tensor]]:
+        hidden_states_pool = []
         hidden_states = inputs_embeds
         for encoder_layer in self.layers:
             hidden_states, _ = encoder_layer(hidden_states)
+            if return_all_hidden_states:
+                hidden_states_pool.append(hidden_states)
+        # If we have multiple feature sample layers, we return all hidden
+        # states in order and grab the ones we need by index.
+        if return_all_hidden_states:
+            return hidden_states_pool
         return hidden_states
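
For reference, the new `return_all_hidden_states` flag simply pools every layer's output so that specific feature layers can be picked out later by index, as the comment above says. A minimal, self-contained sketch of that pooling pattern (toy module and sizes, not vLLM's API):

```python
import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    """Illustrative stand-in for the pooling pattern above (not vLLM code)."""

    def __init__(self, num_layers: int = 4, dim: int = 8) -> None:
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))

    def forward(self, x: torch.Tensor, return_all_hidden_states: bool):
        hidden_states_pool = []
        for layer in self.layers:
            x = layer(x)
            if return_all_hidden_states:
                hidden_states_pool.append(x)
        # Either one hidden state per layer, or just the final output.
        return hidden_states_pool if return_all_hidden_states else x


enc = ToyEncoder()
outs = enc(torch.randn(1, 3, 8), return_all_hidden_states=True)
assert len(outs) == 4  # one tensor per encoder layer
```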
@@ -509,6 +518,7 @@ class SiglipVisionTransformer(nn.Module):
         embed_dim = config.hidden_size
         self.embeddings = SiglipVisionEmbeddings(config)
         self.encoder = SiglipEncoder(
             config,
             quant_config=quant_config,
@@ -546,23 +556,33 @@ class SiglipVisionTransformer(nn.Module):
         self,
         pixel_values: torch.Tensor,
         interpolate_pos_encoding: bool = True,
+        feature_sample_layers: Optional[list[int]] = None,
     ) -> torch.Tensor:
         hidden_states = self.embeddings(
             pixel_values,
             interpolate_pos_encoding=interpolate_pos_encoding,
         )
-        encoder_outputs = self.encoder(inputs_embeds=hidden_states)
+        return_all_hidden_states = feature_sample_layers is not None
-        if self.post_layernorm is None:
-            return encoder_outputs
+        # Produces either the last layer output or all of the hidden states,
+        # depending on if we have feature_sample_layers or not
+        encoder_outputs = self.encoder(
+            inputs_embeds=hidden_states,
+            return_all_hidden_states=return_all_hidden_states,
+        )
-        last_hidden_state = self.post_layernorm(encoder_outputs)
-        # TODO: add this back when pooled_output is used in inference
+        # Handle post-norm (if applicable) and stacks feature layers if needed
+        encoder_outputs = resolve_visual_encoder_outputs(
+            encoder_outputs, feature_sample_layers, self.post_layernorm,
+            self.config.num_hidden_layers)
+        # TODO: add this back when pooled_output is used in inference.
         # if self.use_head:
-        # pooled_output = self.head(last_hidden_state)
+        # pooled_output = self.head(encoder_outputs)
-        return last_hidden_state
+        return encoder_outputs

 class SiglipVisionModel(nn.Module):
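
Per the diff comments, `resolve_visual_encoder_outputs` replaces the hand-rolled post-norm handling: it applies the post-layernorm when only the final hidden state is needed, and otherwise picks and stacks the requested `feature_sample_layers` from the pooled per-layer outputs. The sketch below is an assumption about those semantics for illustration (hypothetical helper name, not the actual vLLM implementation); the negative-index handling and the concatenation axis are guesses consistent with LLaVA-style multi-layer feature sampling, and post-norm handling on the multi-layer path is omitted for brevity:

```python
from typing import Optional, Union

import torch
import torch.nn as nn


def select_feature_layers(
    outputs: Union[torch.Tensor, list],
    feature_sample_layers: Optional[list],
    post_layernorm: Optional[nn.Module],
    num_hidden_layers: int,
) -> torch.Tensor:
    """Hypothetical helper mirroring how the diff calls resolve_visual_encoder_outputs."""
    if feature_sample_layers is None:
        # Single-tensor path: only the final hidden state, post-norm if present.
        return post_layernorm(outputs) if post_layernorm is not None else outputs

    # Multi-layer path: `outputs` is assumed to hold one hidden state per
    # encoder layer; negative indices count back from the full encoder depth.
    picked = [
        outputs[idx if idx >= 0 else idx + num_hidden_layers]
        for idx in feature_sample_layers
    ]
    # Combine the sampled layers along the hidden dimension (one common choice).
    return torch.cat(picked, dim=-1)
```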
@@ -595,10 +615,12 @@ class SiglipVisionModel(nn.Module):
         self,
         pixel_values: torch.Tensor,
         interpolate_pos_encoding: bool = False,
+        feature_sample_layers: Optional[list[int]] = None,
     ) -> torch.Tensor:
         return self.vision_model(
             pixel_values=pixel_values,
             interpolate_pos_encoding=interpolate_pos_encoding,
+            feature_sample_layers=feature_sample_layers,
         )

     def load_weights(self, weights: Iterable[Tuple[str,
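
End to end, a caller can now pass `feature_sample_layers` straight through `SiglipVisionModel.forward`. The snippet below only demonstrates the call shape with a stub tower and made-up layer indices; it is not Granite's actual configuration:

```python
import torch

# Stub standing in for a constructed SiglipVisionModel; only the call shape
# and the feature_sample_layers keyword mirror the diff above.
def fake_vision_tower(pixel_values, interpolate_pos_encoding=False,
                      feature_sample_layers=None):
    layers = feature_sample_layers or [-1]
    # Pretend each sampled layer yields a (batch, tokens, hidden) feature map
    # and that sampled layers are combined along the hidden dimension.
    return torch.cat([torch.zeros(1, 4, 8) for _ in layers], dim=-1)


image_features = fake_vision_tower(
    torch.zeros(1, 3, 384, 384),
    feature_sample_layers=[-24, -20, -12, -1],  # illustrative indices only
)
print(image_features.shape)  # torch.Size([1, 4, 32])
```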