diff --git a/vllm/model_executor/models/transformers/base.py b/vllm/model_executor/models/transformers/base.py index 45e746ac2..2e79ace46 100644 --- a/vllm/model_executor/models/transformers/base.py +++ b/vllm/model_executor/models/transformers/base.py @@ -47,6 +47,7 @@ from vllm.model_executor.models.transformers.utils import ( get_feature_request_tip, init_on_device_without_buffers, log_replacement, + replace_conv_class, replace_linear_class, replace_rms_norm_class, ) @@ -314,6 +315,8 @@ class Base( new_module = replace_linear_class( child_module, style, self.quant_config, prefix=qual_name ) + elif isinstance(child_module, (nn.Conv2d, nn.Conv3d)): + new_module = replace_conv_class(child_module) elif child_module.__class__.__name__.endswith("RMSNorm"): new_module = replace_rms_norm_class( child_module, self.text_config.hidden_size diff --git a/vllm/model_executor/models/transformers/utils.py b/vllm/model_executor/models/transformers/utils.py index c7844381e..e47f3bba5 100644 --- a/vllm/model_executor/models/transformers/utils.py +++ b/vllm/model_executor/models/transformers/utils.py @@ -25,6 +25,7 @@ from torch import nn from vllm.config.utils import getattr_iter from vllm.logger import init_logger +from vllm.model_executor.layers.conv import Conv2dLayer, Conv3dLayer from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, @@ -136,6 +137,45 @@ def replace_linear_class( ) +TorchConv = nn.Conv2d | nn.Conv3d +VllmConv = Conv2dLayer | Conv3dLayer + + +def replace_conv_class(conv: TorchConv) -> VllmConv | TorchConv: + """Replace a Transformers Conv2d/Conv3d with vLLM's Conv2d/Conv3d. + + Args: + conv: `nn.Conv2d` or `nn.Conv3d` to be replaced. + Returns: + The new `Conv2dLayer` or `Conv3dLayer`. If the conv module is not supported, + returns the original conv module. + """ + # vLLM does not handle non-zero padding modes + if conv.padding_mode != "zeros": + return conv + + vllm_conv_cls = { + nn.Conv2d: Conv2dLayer, + nn.Conv3d: Conv3dLayer, + }.get(type(conv)) + + if vllm_conv_cls is None: + return conv + + return vllm_conv_cls( + in_channels=conv.in_channels, + out_channels=conv.out_channels, + kernel_size=conv.kernel_size, + stride=conv.stride, + padding=conv.padding, + dilation=conv.dilation, + groups=conv.groups, + bias=conv.bias is not None, + padding_mode=conv.padding_mode, + params_dtype=conv.weight.dtype, + ) + + def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm: """Replace a Transformers RMSNorm with vLLM's RMSNorm.