[MM][Bugfix] Replace PatchEmbed's conv3d with linear layer (#27418)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Roger Wang <hey@rogerw.io>
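Why this is safe: a Conv3d whose stride equals its kernel size computes one
dot product per non-overlapping patch, which is exactly a Linear layer applied
to the flattened patch. A minimal sketch of that equivalence (illustrative
sizes; not part of the patch):

    import math

    import torch
    import torch.nn as nn

    in_channels, embed_dim = 3, 1280
    kernel_size = (2, 14, 14)  # (temporal_patch_size, patch_size, patch_size)

    conv = nn.Conv3d(
        in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False
    )
    linear = nn.Linear(in_channels * math.prod(kernel_size), embed_dim, bias=False)
    # Reuse the conv kernel, flattened to (embed_dim, in_channels * T * H * W).
    with torch.no_grad():
        linear.weight.copy_(conv.weight.view(embed_dim, -1))

    x = torch.randn(16, in_channels * math.prod(kernel_size))  # 16 flattened patches
    conv_out = conv(x.view(16, in_channels, *kernel_size)).flatten(1)
    torch.testing.assert_close(linear(x), conv_out)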
@@ -25,6 +25,7 @@
 # limitations under the License.
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
 
+import math
 from collections.abc import Callable, Iterable, Mapping, Sequence
 from functools import partial
 from typing import Annotated, Any, Literal, TypeAlias
@@ -53,7 +54,11 @@ from vllm.distributed import parallel_state, tensor_model_parallel_all_gather
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import QuickGELU
-from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding.common import (
     dispatch_rotary_emb_function,
@@ -100,7 +105,11 @@ from .utils import (
     init_vllm_registered_model,
     maybe_prefix,
 )
-from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
+from .vision import (
+    conv3d_to_linear_weight,
+    get_vit_attn_backend,
+    run_dp_sharded_mrope_vision_model,
+)
 
 logger = init_logger(__name__)
 
@@ -561,18 +570,15 @@ class Qwen2VisionPatchEmbed(nn.Module):
         self.embed_dim = embed_dim
 
         kernel_size = (temporal_patch_size, patch_size, patch_size)
-        self.proj = nn.Conv3d(
-            in_channels,
+        self.proj = ReplicatedLinear(
+            in_channels * math.prod(kernel_size),
             embed_dim,
-            kernel_size=kernel_size,
-            stride=kernel_size,
             bias=False,
+            return_bias=False,
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x).view(L, self.embed_dim)
+        x = self.proj(x)
         return x
 
 
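A note on the replacement layer (my reading of vLLM's linear layers, not
stated in the diff): ReplicatedLinear keeps the full weight on every
tensor-parallel rank rather than sharding it, and vLLM linear layers return
an (output, bias) tuple by default, so return_bias=False is what lets the
forward pass stay a plain `x = self.proj(x)`. Hypothetical construction with
Qwen2-VL's usual sizes:

    from vllm.model_executor.layers.linear import ReplicatedLinear

    # 3 channels * 2 temporal patches * 14 * 14 spatial patch -> 1280 dims
    proj = ReplicatedLinear(3 * 2 * 14 * 14, 1280, bias=False, return_bias=False)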
@@ -835,6 +841,9 @@ class Qwen2VisionTransformer(nn.Module):
         loaded_params: set[str] = set()
 
         for name, loaded_weight in weights:
+            if name.endswith("patch_embed.proj.weight"):
+                loaded_weight = conv3d_to_linear_weight(loaded_weight)
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
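The conv3d_to_linear_weight helper imported above is what keeps checkpoint
compatibility: HuggingFace checkpoints still store the patch-embed weight in
Conv3d layout (out, in, T, H, W), so it has to be flattened at load time. A
sketch of that conversion under the same assumption (the real helper lives in
.vision and may differ in detail):

    import torch

    def conv3d_to_linear_weight(conv3d_weight: torch.Tensor) -> torch.Tensor:
        # (out, in, T, H, W) -> (out, in * T * H * W), matching ReplicatedLinear
        out_channels = conv3d_weight.shape[0]
        return conv3d_weight.reshape(out_channels, -1)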