[Models] Replace all nn.Conv2d with vLLM's Conv2dLayer (#28842)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -22,6 +22,7 @@ from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.conv import Conv2dLayer
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
@@ -549,7 +550,7 @@ class ChameleonVQVAEVectorQuantizer(nn.Module):
|
||||
class ChameleonVQVAEEncoderConvDownsample(nn.Module):
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(
|
||||
self.conv = Conv2dLayer(
|
||||
in_channels, in_channels, kernel_size=3, stride=2, padding=0
|
||||
)
|
||||
|
||||
@@ -577,23 +578,23 @@ class ChameleonVQVAEEncoderResnetBlock(nn.Module):
|
||||
self.norm1 = torch.nn.GroupNorm(
|
||||
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
||||
)
|
||||
self.conv1 = torch.nn.Conv2d(
|
||||
self.conv1 = Conv2dLayer(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
self.norm2 = torch.nn.GroupNorm(
|
||||
num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
|
||||
)
|
||||
self.dropout = torch.nn.Dropout(config.dropout)
|
||||
self.conv2 = torch.nn.Conv2d(
|
||||
self.conv2 = Conv2dLayer(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = torch.nn.Conv2d(
|
||||
self.conv_shortcut = Conv2dLayer(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv2d(
|
||||
self.nin_shortcut = Conv2dLayer(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
@@ -626,16 +627,16 @@ class ChameleonVQVAEEncoderAttnBlock(nn.Module):
|
||||
self.norm = torch.nn.GroupNorm(
|
||||
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
|
||||
)
|
||||
self.q = torch.nn.Conv2d(
|
||||
self.q = Conv2dLayer(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.k = torch.nn.Conv2d(
|
||||
self.k = Conv2dLayer(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.v = torch.nn.Conv2d(
|
||||
self.v = Conv2dLayer(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
self.proj_out = Conv2dLayer(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0
|
||||
)
|
||||
|
||||
@@ -681,7 +682,7 @@ class ChameleonVQVAEEncoder(nn.Module):
|
||||
latent_channels = config.latent_channels
|
||||
channel_multiplier = config.channel_multiplier
|
||||
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
self.conv_in = Conv2dLayer(
|
||||
in_channels, base_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
@@ -738,7 +739,7 @@ class ChameleonVQVAEEncoder(nn.Module):
|
||||
self.norm_out = torch.nn.GroupNorm(
|
||||
num_groups=32, num_channels=block_in, eps=1e-6, affine=True
|
||||
)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
self.conv_out = Conv2dLayer(
|
||||
block_in,
|
||||
2 * latent_channels if double_latent else latent_channels,
|
||||
kernel_size=3,
|
||||
@@ -779,10 +780,8 @@ class ChameleonVQVAE(nn.Module):
|
||||
super().__init__()
|
||||
self.encoder = ChameleonVQVAEEncoder(config)
|
||||
self.quantize = ChameleonVQVAEVectorQuantizer(config)
|
||||
self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(
|
||||
config.embed_dim, config.latent_channels, 1
|
||||
)
|
||||
self.quant_conv = Conv2dLayer(config.latent_channels, config.embed_dim, 1)
|
||||
self.post_quant_conv = Conv2dLayer(config.embed_dim, config.latent_channels, 1)
|
||||
self.eval() # Chameleon's VQ model is frozen
|
||||
|
||||
def encode(
|
||||
|
||||
Reference in New Issue
Block a user