[Model] Use merge_by_field_config for MM models (H-L) (#26230)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -283,6 +283,7 @@ class KimiVLMultiModalProcessor(BaseMultiModalProcessor[KimiVLProcessingInfo]):
                                         dummy_inputs=KimiVLDummyInputsBuilder)
 class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                      SupportsPP):
+    merge_by_field_config = True
 
     supports_encoder_tp_data = True
 
@@ -342,23 +343,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
                                                 config.vocab_size, logit_scale)
         self.media_placeholder: int = self.config.media_placeholder_token_id
 
-    # ref: qwen2_vl.py
-    def _validate_and_reshape_mm_tensor(self, mm_input: object,
-                                        name: str) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            if mm_input.ndim == 2:
-                return mm_input
-            if mm_input.ndim != 3:
-                raise ValueError(f"{name} should be 2D or batched 3D tensor. "
-                                 f"Got ndim: {mm_input.ndim} "
-                                 f"(shape={mm_input.shape})")
-            return mm_input.reshape(-1, mm_input.shape[-1])
-        else:
-            return torch.concat(mm_input)
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[KimiVLImageInputs]:
         # image input type must be pixel values now
@@ -368,21 +352,6 @@ class KimiVLForConditionalGeneration(nn.Module, SupportsMultiModal,
         if pixel_values is None:
             return None
 
-        image_grid_hws = self._validate_and_reshape_mm_tensor(
-            image_grid_hws, "image grid hws")
-        # pixel_values may have complex shapes
-        num_channels = 3
-        patch_size = self.config.vision_config.patch_size
-        if isinstance(pixel_values, list):
-            pixel_values = torch.cat([
-                x.reshape(-1, num_channels, patch_size, patch_size)
-                for x in pixel_values
-            ])
-        else:
-            pixel_values = pixel_values.reshape(-1, num_channels, patch_size,
-                                                patch_size)
-        pixel_values = pixel_values.to(self.vision_tower.dtype)
-
         return KimiVLImagePixelInputs(
             type="pixel_values",
             pixel_values=pixel_values,
Reference in New Issue
Block a user