[VLM] Enable overriding whether post layernorm is used in vision encoder + fix quant args (#9217)

Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Cyrus Leung
2024-10-23 19:27:37 +08:00
committed by GitHub
parent 3ff57ebfca
commit c18e1a3418
18 changed files with 551 additions and 253 deletions

View File

@@ -31,6 +31,7 @@ from .clip import (CLIPVisionModel, dummy_seq_data_for_clip,
dummy_video_for_clip, get_clip_image_feature_size,
get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsMultiModal, SupportsPP
from .llava import init_vision_tower_for_llava
from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip,
dummy_video_for_siglip, get_siglip_image_feature_size,
get_siglip_patch_grid_length, input_processor_for_siglip)
@@ -357,32 +358,6 @@ def input_processor_for_llava_onevision(ctx: InputContext,
raise NotImplementedError(msg)
def _init_vision_tower(hf_config: LlavaOnevisionConfig):
vision_config = hf_config.vision_config
# Initialize the vision tower only up to the required feature layer
vision_feature_layer = hf_config.vision_feature_layer
if vision_feature_layer < 0:
num_hidden_layers = hf_config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
if isinstance(vision_config, CLIPVisionConfig):
return CLIPVisionModel(
vision_config,
num_hidden_layers_override=num_hidden_layers,
)
elif isinstance(vision_config, SiglipVisionConfig):
return SiglipVisionModel(
vision_config,
num_hidden_layers_override=num_hidden_layers,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
class LlavaOnevisionMultiModalProjector(nn.Module):
def __init__(self, config: LlavaOnevisionConfig):
@@ -425,7 +400,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
self.multimodal_config = multimodal_config
# Initialize the vision tower only up to the required feature layer
self.vision_tower = _init_vision_tower(config)
self.vision_tower = init_vision_tower_for_llava(config, quant_config)
self.multi_modal_projector = LlavaOnevisionMultiModalProjector(config)
self.language_model = init_vllm_registered_model(
config.text_config, cache_config, quant_config)