[Models] Replace all nn.Conv2d with vLLM's Conv2dLayer (#28842)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
"""Conv Layer Class."""
|
||||
|
||||
import math
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -23,11 +24,11 @@ class ConvLayerBase(CustomOp):
|
||||
out_channels: int,
|
||||
kernel_size: int | tuple[int, ...],
|
||||
stride: int | tuple[int, ...] = 1,
|
||||
padding: int | tuple[int, ...] = 0,
|
||||
padding: int | tuple[int, ...] | Literal["same", "valid"] = 0,
|
||||
dilation: int | tuple[int, ...] = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
padding_mode: str = "zeros",
|
||||
padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros",
|
||||
*,
|
||||
params_dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
@@ -36,6 +37,22 @@ class ConvLayerBase(CustomOp):
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
|
||||
valid_padding_strings = {"same", "valid"}
|
||||
if isinstance(padding, str) and padding not in valid_padding_strings:
|
||||
raise ValueError(
|
||||
f"Invalid padding string '{padding}'. "
|
||||
f"Expected one of {valid_padding_strings}."
|
||||
)
|
||||
|
||||
if padding == "same":
|
||||
padding = (
|
||||
kernel_size // 2
|
||||
if isinstance(kernel_size, int)
|
||||
else tuple(k // 2 for k in kernel_size)
|
||||
)
|
||||
elif padding == "valid":
|
||||
padding = 0
|
||||
|
||||
kernel_size = (
|
||||
(kernel_size,) * self.num_dim
|
||||
if isinstance(kernel_size, int)
|
||||
@@ -45,6 +62,9 @@ class ConvLayerBase(CustomOp):
|
||||
padding = (padding,) * self.num_dim if isinstance(padding, int) else padding
|
||||
dilation = (dilation,) * self.num_dim if isinstance(dilation, int) else dilation
|
||||
|
||||
if padding == "same" and any(s != 1 for s in stride):
|
||||
raise ValueError("padding='same' is not supported for strided convolutions")
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.kernel_size = kernel_size
|
||||
|
||||
Reference in New Issue
Block a user