[MM][Bugfix] Replace PatchEmbed's conv3d to linear layer (#27418)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
@@ -26,6 +26,7 @@
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
|
||||
|
||||
import math
|
||||
from collections.abc import Callable, Iterable, Mapping, Sequence
|
||||
from functools import lru_cache, partial
|
||||
from typing import Annotated, Any, Literal, TypeAlias
|
||||
@@ -56,6 +57,7 @@ from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
@@ -98,7 +100,11 @@ from .utils import (
|
||||
init_vllm_registered_model,
|
||||
maybe_prefix,
|
||||
)
|
||||
from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
|
||||
from .vision import (
|
||||
conv3d_to_linear_weight,
|
||||
get_vit_attn_backend,
|
||||
run_dp_sharded_mrope_vision_model,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -532,18 +538,15 @@ class Qwen2_5_VisionPatchEmbed(nn.Module):
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
kernel_size = (temporal_patch_size, patch_size, patch_size)
|
||||
self.proj = nn.Conv3d(
|
||||
in_channels,
|
||||
self.proj = ReplicatedLinear(
|
||||
in_channels * math.prod(kernel_size),
|
||||
hidden_size,
|
||||
kernel_size=kernel_size,
|
||||
stride=kernel_size,
|
||||
bias=False,
|
||||
return_bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
L, C = x.shape
|
||||
x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
|
||||
x = self.proj(x).view(L, self.hidden_size)
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
@@ -950,6 +953,9 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if name.endswith("patch_embed.proj.weight"):
|
||||
loaded_weight = conv3d_to_linear_weight(loaded_weight)
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user