[Model] Use merge_by_field_config for MM models (Qwen series) (#27546)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-10-27 13:38:05 +08:00
committed by GitHub
parent 63b22e0dbb
commit cbd5e07a51
7 changed files with 36 additions and 305 deletions

View File

@@ -1175,6 +1175,8 @@ class Qwen3LLMForCausalLM(Qwen3ForCausalLM):
class Qwen3VLForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
):
merge_by_field_config = True
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -1298,24 +1300,6 @@ class Qwen3VLForConditionalGeneration(
for idx in range(self.deepstack_num_level):
self.deepstack_input_embeds[idx][:num_tokens].zero_()
def _validate_and_reshape_mm_tensor(
self, mm_input: object, name: str
) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
if isinstance(mm_input, torch.Tensor):
if mm_input.ndim == 2:
return mm_input
if mm_input.ndim != 3:
raise ValueError(
f"{name} should be 2D or batched 3D tensor. "
f"Got ndim: {mm_input.ndim} "
f"(shape={mm_input.shape})"
)
return mm_input.reshape(-1, mm_input.shape[-1])
else:
return torch.concat(mm_input)
def _parse_and_validate_image_input(
self, **kwargs: object
) -> Qwen2_5_VLImageInputs | None:
@@ -1327,19 +1311,6 @@ class Qwen3VLForConditionalGeneration(
return None
if pixel_values is not None:
pixel_values = self._validate_and_reshape_mm_tensor(
pixel_values, "image pixel values"
)
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw"
)
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of image pixel values. "
f"Got type: {type(pixel_values)}"
)
return Qwen2_5_VLImagePixelInputs(
type="pixel_values",
pixel_values=pixel_values,
@@ -1347,18 +1318,6 @@ class Qwen3VLForConditionalGeneration(
)
if image_embeds is not None:
image_embeds = self._validate_and_reshape_mm_tensor(
image_embeds, "image embeds"
)
image_grid_thw = self._validate_and_reshape_mm_tensor(
image_grid_thw, "image grid_thw"
)
if not isinstance(image_embeds, torch.Tensor):
raise ValueError(
"Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}"
)
return Qwen2_5_VLImageEmbeddingInputs(
type="image_embeds",
image_embeds=image_embeds,
@@ -1377,13 +1336,6 @@ class Qwen3VLForConditionalGeneration(
return None
if pixel_values_videos is not None:
pixel_values_videos = self._validate_and_reshape_mm_tensor(
pixel_values_videos, "video pixel values"
)
video_grid_thw = self._validate_and_reshape_mm_tensor(
video_grid_thw, "video grid_thw"
)
return Qwen2_5_VLVideoPixelInputs(
type="pixel_values_videos",
pixel_values_videos=pixel_values_videos,
@@ -1392,18 +1344,6 @@ class Qwen3VLForConditionalGeneration(
)
if video_embeds is not None:
video_embeds = self._validate_and_reshape_mm_tensor(
video_embeds, "video embeds"
)
video_grid_thw = self._validate_and_reshape_mm_tensor(
video_grid_thw, "video grid_thw"
)
if not isinstance(video_embeds, torch.Tensor):
raise ValueError(
"Incorrect type of video embeddings. "
f"Got type: {type(video_embeds)}"
)
return Qwen2_5_VLVideoEmbeddingInputs(
type="video_embeds",
video_embeds=video_embeds,