[Model][MM] Extract conv layer as CustomOp (#28455)

Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Shanshan Shen
2025-11-14 19:16:13 +08:00
committed by GitHub
parent 360bd8762f
commit 41b92f7d38
8 changed files with 277 additions and 66 deletions

View File

@@ -26,7 +26,6 @@
# limitations under the License.
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
import math
from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import lru_cache, partial
from typing import Annotated, Any, Literal, TypeAlias
@@ -56,12 +55,12 @@ from vllm.distributed import utils as dist_utils
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_and_mul_fn
from vllm.model_executor.layers.conv import Conv3dLayer
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -110,7 +109,6 @@ from .utils import (
maybe_prefix,
)
from .vision import (
conv3d_to_linear_weight,
get_vit_attn_backend,
run_dp_sharded_mrope_vision_model,
)
@@ -525,15 +523,18 @@ class Qwen2_5_VisionPatchEmbed(nn.Module):
self.hidden_size = hidden_size
kernel_size = (temporal_patch_size, patch_size, patch_size)
self.proj = ReplicatedLinear(
in_channels * math.prod(kernel_size),
self.proj = Conv3dLayer(
in_channels,
hidden_size,
kernel_size=kernel_size,
stride=kernel_size,
bias=False,
return_bias=False,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
L, C = x.shape
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
x = self.proj(x).view(L, self.hidden_size)
return x
@@ -957,9 +958,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if name.endswith("patch_embed.proj.weight"):
loaded_weight = conv3d_to_linear_weight(loaded_weight)
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue