Replace nn.ConvNd with vLLM's ConvNdLayer for Transformers modeling backend (#31498)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -47,6 +47,7 @@ from vllm.model_executor.models.transformers.utils import (
|
||||
get_feature_request_tip,
|
||||
init_on_device_without_buffers,
|
||||
log_replacement,
|
||||
replace_conv_class,
|
||||
replace_linear_class,
|
||||
replace_rms_norm_class,
|
||||
)
|
||||
@@ -314,6 +315,8 @@ class Base(
|
||||
new_module = replace_linear_class(
|
||||
child_module, style, self.quant_config, prefix=qual_name
|
||||
)
|
||||
elif isinstance(child_module, (nn.Conv2d, nn.Conv3d)):
|
||||
new_module = replace_conv_class(child_module)
|
||||
elif child_module.__class__.__name__.endswith("RMSNorm"):
|
||||
new_module = replace_rms_norm_class(
|
||||
child_module, self.text_config.hidden_size
|
||||
|
||||
@@ -25,6 +25,7 @@ from torch import nn
|
||||
|
||||
from vllm.config.utils import getattr_iter
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.conv import Conv2dLayer, Conv3dLayer
|
||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
@@ -136,6 +137,45 @@ def replace_linear_class(
|
||||
)
|
||||
|
||||
|
||||
TorchConv = nn.Conv2d | nn.Conv3d
|
||||
VllmConv = Conv2dLayer | Conv3dLayer
|
||||
|
||||
|
||||
def replace_conv_class(conv: TorchConv) -> VllmConv | TorchConv:
|
||||
"""Replace a Transformers Conv2d/Conv3d with vLLM's Conv2d/Conv3d.
|
||||
|
||||
Args:
|
||||
conv: `nn.Conv2d` or `nn.Conv3d` to be replaced.
|
||||
Returns:
|
||||
The new `Conv2dLayer` or `Conv3dLayer`. If the conv module is not supported,
|
||||
returns the original conv module.
|
||||
"""
|
||||
# vLLM does not handle non-zero padding modes
|
||||
if conv.padding_mode != "zeros":
|
||||
return conv
|
||||
|
||||
vllm_conv_cls = {
|
||||
nn.Conv2d: Conv2dLayer,
|
||||
nn.Conv3d: Conv3dLayer,
|
||||
}.get(type(conv))
|
||||
|
||||
if vllm_conv_cls is None:
|
||||
return conv
|
||||
|
||||
return vllm_conv_cls(
|
||||
in_channels=conv.in_channels,
|
||||
out_channels=conv.out_channels,
|
||||
kernel_size=conv.kernel_size,
|
||||
stride=conv.stride,
|
||||
padding=conv.padding,
|
||||
dilation=conv.dilation,
|
||||
groups=conv.groups,
|
||||
bias=conv.bias is not None,
|
||||
padding_mode=conv.padding_mode,
|
||||
params_dtype=conv.weight.dtype,
|
||||
)
|
||||
|
||||
|
||||
def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
|
||||
"""Replace a Transformers RMSNorm with vLLM's RMSNorm.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user