[VLM] Enable overriding whether post layernorm is used in vision encoder + fix quant args (#9217)

Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Cyrus Leung
2024-10-23 19:27:37 +08:00
committed by GitHub
parent 3ff57ebfca
commit c18e1a3418
18 changed files with 551 additions and 253 deletions

View File

@@ -1,12 +1,12 @@
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol,
Tuple, TypedDict, Union)
import torch
import torch.nn as nn
from PIL import Image
from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
SiglipVisionConfig)
PretrainedConfig, SiglipVisionConfig)
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, MultiModalConfig
@@ -200,7 +200,17 @@ def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
raise NotImplementedError(msg)
def _init_vision_tower(hf_config: LlavaConfig):
class LlavaLikeConfig(Protocol):
vision_config: PretrainedConfig
vision_feature_layer: int
def init_vision_tower_for_llava(
hf_config: LlavaLikeConfig,
quant_config: Optional[QuantizationConfig],
*,
require_post_norm: Optional[bool] = None,
):
vision_config = hf_config.vision_config
# Initialize the vision tower only up to the required feature layer
@@ -214,16 +224,24 @@ def _init_vision_tower(hf_config: LlavaConfig):
if isinstance(vision_config, CLIPVisionConfig):
return CLIPVisionModel(
vision_config,
quant_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
)
elif isinstance(vision_config, SiglipVisionConfig):
return SiglipVisionModel(
vision_config,
quant_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
)
elif isinstance(vision_config, PixtralVisionConfig):
# TODO: allow layer override?
return PixtralHFVisionModel(vision_config)
return PixtralHFVisionModel(
vision_config,
quant_config,
num_hidden_layers_override=num_hidden_layers,
require_post_norm=require_post_norm,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
@@ -255,7 +273,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
config.projector_hidden_act = "gelu"
# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = _init_vision_tower(config)
self.vision_tower = init_vision_tower_for_llava(config, quant_config)
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,