[Misc][Model][Refactor] Pass the prefix into Linear layers (#31669)
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
This commit is contained in:
@@ -109,6 +109,7 @@ class VisualAttention(nn.Module):
|
||||
bias: bool = True,
|
||||
kdim: int | None = None,
|
||||
vdim: int | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
@@ -128,8 +129,12 @@ class VisualAttention(nn.Module):
|
||||
assert self._qkv_same_embed_dim, (
|
||||
"Visual Attention implementation only supports self-attention"
|
||||
)
|
||||
self.in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim)
|
||||
self.out_proj = ReplicatedLinear(embed_dim, embed_dim)
|
||||
self.in_proj = ReplicatedLinear(
|
||||
embed_dim, 3 * embed_dim, prefix=f"{prefix}.in_proj"
|
||||
)
|
||||
self.out_proj = ReplicatedLinear(
|
||||
embed_dim, embed_dim, prefix=f"{prefix}.out_proj"
|
||||
)
|
||||
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
|
||||
|
||||
def forward(
|
||||
@@ -214,10 +219,15 @@ class QwenVLMLP(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.c_fc = ColumnParallelLinear(
|
||||
hidden_size, intermediate_size, bias=True, quant_config=quant_config
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.c_fc",
|
||||
)
|
||||
self.act_fn = get_act_fn("gelu")
|
||||
self.c_proj = RowParallelLinear(
|
||||
@@ -225,6 +235,7 @@ class QwenVLMLP(nn.Module):
|
||||
hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.c_proj",
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@@ -242,17 +253,19 @@ class VisualAttentionBlock(nn.Module):
|
||||
mlp_ratio: float = 4.0,
|
||||
norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.ln_1 = norm_layer(d_model)
|
||||
self.ln_2 = norm_layer(d_model)
|
||||
mlp_width = int(d_model * mlp_ratio)
|
||||
self.attn = VisualAttention(d_model, n_head)
|
||||
self.attn = VisualAttention(d_model, n_head, prefix=f"{prefix}.attn")
|
||||
self.mlp = QwenVLMLP(
|
||||
hidden_size=d_model,
|
||||
intermediate_size=mlp_width,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
|
||||
def attention(
|
||||
@@ -282,6 +295,7 @@ class TransformerBlock(nn.Module):
|
||||
mlp_ratio: float = 4.0,
|
||||
norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
@@ -295,8 +309,9 @@ class TransformerBlock(nn.Module):
|
||||
mlp_ratio,
|
||||
norm_layer=norm_layer,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.resblocks.{i}",
|
||||
)
|
||||
for _ in range(layers)
|
||||
for i in range(layers)
|
||||
]
|
||||
)
|
||||
|
||||
@@ -327,6 +342,7 @@ class VisionTransformer(nn.Module):
|
||||
output_dim: int = 512,
|
||||
image_start_id: int = 151857,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -356,6 +372,7 @@ class VisionTransformer(nn.Module):
|
||||
mlp_ratio,
|
||||
norm_layer=norm_layer,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.transformer",
|
||||
)
|
||||
|
||||
self.attn_pool = Resampler2(
|
||||
@@ -366,6 +383,7 @@ class VisionTransformer(nn.Module):
|
||||
norm_layer=norm_layer,
|
||||
adaptive=False,
|
||||
do_post_projection=False,
|
||||
prefix=f"{prefix}.attn_pool",
|
||||
).to(
|
||||
device=self.positional_embedding.device,
|
||||
dtype=self.positional_embedding.dtype,
|
||||
@@ -413,7 +431,9 @@ class QwenVLModel(QWenModel):
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.visual = VisionTransformer(**config.visual, quant_config=quant_config)
|
||||
self.visual = VisionTransformer(
|
||||
**config.visual, quant_config=quant_config, prefix=f"{prefix}.visual"
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
|
||||
Reference in New Issue
Block a user