[VLM] Enable overriding whether post layernorm is used in vision encoder + fix quant args (#9217)
Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -70,7 +70,8 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
|
||||
projection_dim=768)
|
||||
|
||||
|
||||
def _init_img_processor(hf_config: PretrainedConfig):
|
||||
def _init_img_processor(hf_config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig]):
|
||||
clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
|
||||
layer_idx = hf_config.img_processor.get('layer_idx', -2)
|
||||
|
||||
@@ -82,7 +83,10 @@ def _init_img_processor(hf_config: PretrainedConfig):
|
||||
num_hidden_layers = layer_idx + 1
|
||||
|
||||
img_processor = CLIPVisionModel(
|
||||
clip_config, num_hidden_layers_override=num_hidden_layers)
|
||||
clip_config,
|
||||
quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers,
|
||||
)
|
||||
|
||||
return img_processor
|
||||
|
||||
@@ -148,14 +152,15 @@ class Phi3ImageEmbeddingBase(nn.Module):
|
||||
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
|
||||
"""Phi3 Image embedding with HD transform."""
|
||||
|
||||
def __init__(self, config: PretrainedConfig) -> None:
|
||||
def __init__(self, config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig]) -> None:
|
||||
super().__init__()
|
||||
|
||||
# n_embd or hidden_size
|
||||
hidden_size = config.n_embd if hasattr(
|
||||
config, 'n_embd') else config.hidden_size
|
||||
|
||||
self.img_processor = _init_img_processor(config)
|
||||
self.img_processor = _init_img_processor(config, quant_config)
|
||||
|
||||
image_dim_out = config.img_processor['image_dim_out']
|
||||
self.num_img_tokens = config.img_processor['num_img_tokens']
|
||||
@@ -535,7 +540,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
)
|
||||
|
||||
# TODO: Optionally initialize this to support input embeddings.
|
||||
self.vision_embed_tokens = Phi3HDImageEmbedding(config)
|
||||
self.vision_embed_tokens = Phi3HDImageEmbedding(config, quant_config)
|
||||
|
||||
self.language_model = LlamaForCausalLM(config, cache_config,
|
||||
quant_config)
|
||||
|
||||
Reference in New Issue
Block a user