[VLM] Enable overriding whether post layernorm is used in vision encoder + fix quant args (#9217)
Co-authored-by: Isotr0py <2037008807@qq.com>
@@ -1,12 +1,12 @@
 from functools import cached_property
-from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
-                    TypedDict, Union)
+from typing import (Iterable, List, Literal, Mapping, Optional, Protocol,
+                    Tuple, TypedDict, Union)

 import torch
 import torch.nn as nn
 from PIL import Image
 from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
-                          SiglipVisionConfig)
+                          PretrainedConfig, SiglipVisionConfig)

 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, MultiModalConfig
@@ -200,7 +200,17 @@ def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
     raise NotImplementedError(msg)


-def _init_vision_tower(hf_config: LlavaConfig):
+class LlavaLikeConfig(Protocol):
+    vision_config: PretrainedConfig
+    vision_feature_layer: int
+
+
+def init_vision_tower_for_llava(
+    hf_config: LlavaLikeConfig,
+    quant_config: Optional[QuantizationConfig],
+    *,
+    require_post_norm: Optional[bool] = None,
+):
     vision_config = hf_config.vision_config

     # Initialize the vision tower only up to the required feature layer
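Note on the new `LlavaLikeConfig` protocol above: `typing.Protocol` gives structural typing, so any HF config object that happens to expose `vision_config` and `vision_feature_layer` attributes satisfies the annotation without inheriting from `LlavaConfig`. That is what lets sibling LLaVA-style models reuse the renamed `init_vision_tower_for_llava` helper. A minimal sketch of the idea; the `feature_layer_count` helper is illustrative only, not part of this diff:

    from typing import Protocol

    from transformers import LlavaConfig, PretrainedConfig


    class LlavaLikeConfig(Protocol):
        vision_config: PretrainedConfig
        vision_feature_layer: int


    def feature_layer_count(cfg: LlavaLikeConfig) -> int:
        # A negative feature layer counts back from the end of the
        # encoder, roughly how the code elided below this hunk resolves
        # num_hidden_layers.
        if cfg.vision_feature_layer < 0:
            return (cfg.vision_config.num_hidden_layers +
                    cfg.vision_feature_layer + 1)
        return cfg.vision_feature_layer + 1


    # LlavaConfig never inherits from LlavaLikeConfig, yet it matches.
    print(feature_layer_count(LlavaConfig()))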
@@ -214,16 +224,24 @@ def _init_vision_tower(hf_config: LlavaConfig):
     if isinstance(vision_config, CLIPVisionConfig):
         return CLIPVisionModel(
             vision_config,
+            quant_config,
             num_hidden_layers_override=num_hidden_layers,
+            require_post_norm=require_post_norm,
         )
     elif isinstance(vision_config, SiglipVisionConfig):
         return SiglipVisionModel(
             vision_config,
+            quant_config,
             num_hidden_layers_override=num_hidden_layers,
+            require_post_norm=require_post_norm,
         )
     elif isinstance(vision_config, PixtralVisionConfig):
-        # TODO: allow layer override?
-        return PixtralHFVisionModel(vision_config)
+        return PixtralHFVisionModel(
+            vision_config,
+            quant_config,
+            num_hidden_layers_override=num_hidden_layers,
+            require_post_norm=require_post_norm,
+        )

     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
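The `require_post_norm` keyword threaded through all three towers above is three-valued. A plausible reading, inferred from the commit title rather than from encoder code shown here: when the tower is truncated via `num_hidden_layers_override`, the encoder's final post layernorm would be dead weight, so by default (`None`) it is only built for a full-depth tower, and an explicit bool forces it on or off. A hedged sketch of that decision, not the literal vLLM encoder code:

    from typing import Optional

    import torch.nn as nn


    def build_post_layernorm(
        hidden_size: int,
        layer_norm_eps: float,
        total_layers: int,
        built_layers: int,
        require_post_norm: Optional[bool] = None,
    ) -> Optional[nn.LayerNorm]:
        # Default (None): keep post_layernorm only when every encoder
        # layer was instantiated; a truncated tower returns an
        # intermediate, pre-norm hidden state.
        if require_post_norm is None:
            require_post_norm = built_layers == total_layers
        if require_post_norm:
            return nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        # Omitted entirely, so no unused weights are created or loaded.
        return None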
@@ -255,7 +273,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         config.projector_hidden_act = "gelu"

         # TODO: Optionally initializes this for supporting embeddings.
-        self.vision_tower = _init_vision_tower(config)
+        self.vision_tower = init_vision_tower_for_llava(config, quant_config)
         self.multi_modal_projector = LlavaMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
             text_hidden_size=config.text_config.hidden_size,
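Call-site sketch for the renamed helper. The positional `quant_config` is the "fix quant args" part of the title: it now reaches the CLIP and SigLIP towers, and Pixtral additionally gains the quant args and layer override it was missing. `config` and `quant_config` are assumed to come from the surrounding model constructor, as in the hunk above:

    # Default: require_post_norm=None lets each tower decide based on
    # whether the full stack of layers is built.
    vision_tower = init_vision_tower_for_llava(config, quant_config)

    # Explicit override: keep the post layernorm even if the tower is
    # truncated (pass False to always drop it instead).
    vision_tower = init_vision_tower_for_llava(
        config,
        quant_config,
        require_post_norm=True,
    )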