[Bugfix] Fix prefix strings for quantized VLMs (#9772)
@@ -210,6 +210,7 @@ def init_vision_tower_for_llava(
     quant_config: Optional[QuantizationConfig],
     *,
     require_post_norm: Optional[bool] = None,
+    prefix: str = "",
 ):
     vision_config = hf_config.vision_config
 
@@ -224,23 +225,26 @@ def init_vision_tower_for_llava(
     if isinstance(vision_config, CLIPVisionConfig):
         return CLIPVisionModel(
             vision_config,
-            quant_config,
+            quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers,
             require_post_norm=require_post_norm,
+            prefix=prefix,
         )
     elif isinstance(vision_config, SiglipVisionConfig):
         return SiglipVisionModel(
             vision_config,
-            quant_config,
+            quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers,
             require_post_norm=require_post_norm,
+            prefix=prefix,
         )
     elif isinstance(vision_config, PixtralVisionConfig):
         return PixtralHFVisionModel(
             vision_config,
-            quant_config,
+            quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers,
             require_post_norm=require_post_norm,
+            prefix=prefix,
         )
 
     msg = f"Unsupported vision config: {type(vision_config)}"
@@ -274,14 +278,20 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
 
         # TODO: Optionally initializes this for supporting embeddings.
         self.vision_tower = init_vision_tower_for_llava(
-            config, quant_config, require_post_norm=False)
+            config,
+            quant_config,
+            require_post_norm=False,
+            prefix="vision_tower")
         self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act)
 
         self.language_model = init_vllm_registered_model(
-            config.text_config, cache_config, quant_config)
+            config.text_config,
+            cache_config,
+            quant_config,
+            prefix="language_model")
 
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
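For context on why threading `prefix` through matters: quantization configs typically decide per-layer behavior from each module's fully-qualified name (for example, ignore lists that exempt the vision tower from quantization). If submodules are built with an empty prefix, their layers resolve to the wrong names and can be matched against the wrong entries. The sketch below illustrates that failure mode with a toy config; ToyQuantConfig and its ignore_prefixes field are invented for illustration and are not vLLM's actual API.

# Minimal sketch of prefix-based quantization decisions (illustrative only).
from dataclasses import dataclass, field
from typing import List

@dataclass
class ToyQuantConfig:
    # Layers whose fully-qualified name starts with one of these prefixes
    # are left unquantized (mirrors "ignore"/"modules_to_not_convert" lists
    # found in real quantization configs).
    ignore_prefixes: List[str] = field(default_factory=list)

    def is_quantized(self, prefix: str) -> bool:
        return not any(prefix.startswith(p) for p in self.ignore_prefixes)

cfg = ToyQuantConfig(ignore_prefixes=["vision_tower."])

# With the fix, the vision tower passes its prefix down, so a layer resolves
# to a name like "vision_tower.encoder.layers.0.self_attn.q_proj" and is
# correctly skipped:
print(cfg.is_quantized("vision_tower.encoder.layers.0.self_attn.q_proj"))  # False

# Before the fix, the same layer was built under an empty prefix, so it
# looked like a language-model layer and was wrongly quantized:
print(cfg.is_quantized("encoder.layers.0.self_attn.q_proj"))  # True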