[MM][Bugfix] Replace PatchEmbed's conv3d to linear layer (#27418)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Isotr0py
2025-10-24 15:32:47 +08:00
committed by GitHub
parent 88d3141ec6
commit 42efe609ba
6 changed files with 97 additions and 42 deletions

View File

@@ -22,6 +22,7 @@
# limitations under the License.
"""Inference-only Qwen3-Omni-Moe model (thinker part)."""
import math
from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import partial
from typing import Any
@@ -53,7 +54,11 @@ from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -98,7 +103,11 @@ from .utils import (
_merge_multimodal_embeddings,
maybe_prefix,
)
from .vision import get_llm_pos_ids_for_vision, get_vit_attn_backend
from .vision import (
conv3d_to_linear_weight,
get_llm_pos_ids_for_vision,
get_vit_attn_backend,
)
try:
import flash_attn
@@ -131,18 +140,16 @@ class Qwen3_VisionPatchEmbed(nn.Module):
self.hidden_size = hidden_size
kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = nn.Conv3d(
in_channels,
self.proj = ReplicatedLinear(
in_channels * math.prod(kernel_size),
hidden_size,
kernel_size=kernel_size,
stride=kernel_size,
bias=True,
return_bias=False,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
L, C = x.shape
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
x = self.proj(x).view(L, self.hidden_size)
x = self.proj(x)
return x
@@ -559,6 +566,9 @@ class Qwen3Omni_VisionTransformer(nn.Module):
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if name.endswith("patch_embed.proj.weight"):
loaded_weight = conv3d_to_linear_weight(loaded_weight)
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue