Fix PixtralHF missing spatial_merge_size (#17571)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -19,10 +19,11 @@ _C = TypeVar("_C", bound=PretrainedConfig)
|
||||
|
||||
class VisionEncoderInfo(ABC, Generic[_C]):
|
||||
|
||||
def __init__(self, vision_config: _C) -> None:
|
||||
def __init__(self, hf_config: _C) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.vision_config = vision_config
|
||||
self.hf_config = hf_config
|
||||
self.vision_config = hf_config.vision_config
|
||||
|
||||
@abstractmethod
|
||||
def get_num_image_tokens(
|
||||
@@ -57,18 +58,14 @@ def get_vision_encoder_info(
|
||||
from .pixtral import PixtralHFEncoderInfo, PixtralVisionConfig
|
||||
from .siglip import SiglipEncoderInfo, SiglipVisionConfig
|
||||
|
||||
vision_config = hf_config.vision_config
|
||||
if isinstance(vision_config, CLIPVisionConfig):
|
||||
return CLIPEncoderInfo(vision_config)
|
||||
if isinstance(vision_config, PixtralVisionConfig):
|
||||
# Need to sneak in spatial_merge_size for Mistral3
|
||||
vision_config.spatial_merge_size = getattr(hf_config,
|
||||
"spatial_merge_size", 1)
|
||||
return PixtralHFEncoderInfo(vision_config)
|
||||
if isinstance(vision_config, SiglipVisionConfig):
|
||||
return SiglipEncoderInfo(vision_config)
|
||||
if isinstance(hf_config.vision_config, CLIPVisionConfig):
|
||||
return CLIPEncoderInfo(hf_config)
|
||||
if isinstance(hf_config.vision_config, PixtralVisionConfig):
|
||||
return PixtralHFEncoderInfo(hf_config)
|
||||
if isinstance(hf_config.vision_config, SiglipVisionConfig):
|
||||
return SiglipEncoderInfo(hf_config)
|
||||
|
||||
msg = f"Unsupported vision config: {type(vision_config)}"
|
||||
msg = f"Unsupported vision config: {type(hf_config.vision_config)}"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user