[Bugfix][Model] Add base class for vision-language models (#4809)

This commit is contained in:
Cyrus Leung
2024-05-19 15:13:33 +08:00
committed by GitHub
parent 2e9a2227ec
commit f68470e803
4 changed files with 53 additions and 29 deletions

View File

@@ -19,6 +19,8 @@ from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from .vlm_base import VisionLanguageModelBase
_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
@@ -40,7 +42,7 @@ class LlavaMultiModalProjector(nn.Module):
text_hidden_size,
bias=True)
def forward(self, image_features):
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
hidden_states = self.linear_1(image_features)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
@@ -50,30 +52,32 @@ class LlavaMultiModalProjector(nn.Module):
def _merge_vision_embeddings(input_ids: torch.Tensor,
inputs_embeds: torch.Tensor,
vision_embeddings: torch.Tensor,
image_token_id: int):
image_token_id: int) -> torch.Tensor:
"""In place merges in vision_embeddings with inputs_embeds."""
mask = (input_ids == image_token_id)
inputs_embeds[mask] = vision_embeddings.view(-1,
image_feature_size = vision_embeddings.shape[0] * vision_embeddings.shape[1]
if mask.sum() != image_feature_size:
raise ValueError(f"image_feature_size should be {image_feature_size}, "
f"but found: {mask.sum()}")
inputs_embeds[mask] = vision_embeddings.view(image_feature_size,
vision_embeddings.shape[-1])
return inputs_embeds
class LlavaForConditionalGeneration(nn.Module):
class LlavaForConditionalGeneration(VisionLanguageModelBase):
def __init__(self,
config: "LlavaConfig",
config: LlavaConfig,
vision_language_config: VisionLanguageConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional["QuantizationConfig"] = None) -> None:
super().__init__()
quant_config: Optional[QuantizationConfig] = None) -> None:
super().__init__(vision_language_config)
self.config = config
self.vision_language_config = vision_language_config
assert self.vision_language_config, (
"Provide `image_input_type` and other vision "
"related configurations through LLM entrypoint "
"or engine arguments.")
if self.vision_language_config.image_input_type == (
VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
self.vision_tower = CLIPVisionModel(config.vision_config)
@@ -98,14 +102,12 @@ class LlavaForConditionalGeneration(nn.Module):
config.vocab_size, logit_scale)
self.sampler = Sampler()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
image_input: Optional[torch.Tensor] = None
) -> SamplerOutput: # noqa: E501
def forward(self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
image_input: Optional[torch.Tensor] = None) -> SamplerOutput:
"""Run forward pass for Llava 1.5.
One key thing to understand is the `input_ids` already accounts for the
@@ -172,7 +174,7 @@ class LlavaForConditionalGeneration(nn.Module):
image_features = image_input
vision_embeddings = self.multi_modal_projector(image_features)
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
_merge_vision_embeddings(
inputs_embeds = _merge_vision_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.vision_language_config.image_token_id)
input_ids = None