[Bugfix] Fix prefix strings for quantized VLMs (#9772)

This commit is contained in:
Michael Goin
2024-10-29 19:02:59 -04:00
committed by GitHub
parent 8d7724104a
commit bc73e9821c
20 changed files with 288 additions and 97 deletions

View File

@@ -71,7 +71,8 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
def _init_img_processor(hf_config: PretrainedConfig,
quant_config: Optional[QuantizationConfig]):
quant_config: Optional[QuantizationConfig],
prefix: str = "") -> CLIPVisionModel:
clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
layer_idx = hf_config.img_processor.get('layer_idx', -2)
@@ -86,6 +87,7 @@ def _init_img_processor(hf_config: PretrainedConfig,
clip_config,
quant_config,
num_hidden_layers_override=num_hidden_layers,
prefix=prefix,
)
return img_processor
@@ -152,15 +154,18 @@ class Phi3ImageEmbeddingBase(nn.Module):
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
"""Phi3 Image embedding with HD transform."""
def __init__(self, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig]) -> None:
def __init__(self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "") -> None:
super().__init__()
# n_embed or hidden_size
hidden_size = config.n_embd if hasattr(
config, 'n_embd') else config.hidden_size
self.img_processor = _init_img_processor(config, quant_config)
self.img_processor = _init_img_processor(
config, quant_config, prefix=f"{prefix}.img_processor")
image_dim_out = config.img_processor['image_dim_out']
self.num_img_tokens = config.img_processor['num_img_tokens']
@@ -537,11 +542,15 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config,
prefix="model.embed_tokens",
)
# TODO: Optionally initializes this for supporting input embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding(config, quant_config)
self.vision_embed_tokens = Phi3HDImageEmbedding(
config, quant_config, prefix="model.vision_embed_tokens")
# The prefix is empty intentionally because default prefix of
# LlamaForCausalLM is "model"
self.language_model = LlamaForCausalLM(config, cache_config,
quant_config)