[Misc][Model][Refactor] Pass the prefix into Linear layers (#31669)

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Author: Wang Kunpeng
Date: 2026-01-06 04:03:18 +08:00
Committed by: GitHub
Commit: 5708297e4e (parent: 02dbb933cb)
17 changed files with 181 additions and 40 deletions


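This diff threads a prefix argument through the Qwen-VL vision modules so that every ReplicatedLinear, ColumnParallelLinear, and RowParallelLinear is constructed knowing its fully qualified module path. vLLM keys per-layer decisions such as quantization method selection and weight loading on that dotted name, so without the prefix the vision-tower linears would all default to an empty name. A minimal sketch of the naming pattern, using toy classes rather than vLLM's actual layer implementations:

import torch.nn as nn

class ToyLinear(nn.Module):
    """Stand-in for a vLLM parallel Linear layer: it remembers its own name."""

    def __init__(self, in_features: int, out_features: int, prefix: str = ""):
        super().__init__()
        self.prefix = prefix  # e.g. "visual.transformer.resblocks.0.attn.in_proj"
        self.linear = nn.Linear(in_features, out_features)

class ToyAttention(nn.Module):
    """Each parent appends the child's attribute name to the prefix it received."""

    def __init__(self, dim: int, prefix: str = ""):
        super().__init__()
        self.in_proj = ToyLinear(dim, 3 * dim, prefix=f"{prefix}.in_proj")
        self.out_proj = ToyLinear(dim, dim, prefix=f"{prefix}.out_proj")

attn = ToyAttention(64, prefix="visual.transformer.resblocks.0.attn")
print(attn.in_proj.prefix)  # visual.transformer.resblocks.0.attn.in_proj
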
@@ -109,6 +109,7 @@ class VisualAttention(nn.Module):
bias: bool = True,
kdim: int | None = None,
vdim: int | None = None,
prefix: str = "",
):
super().__init__()
self.embed_dim = embed_dim
@@ -128,8 +129,12 @@ class VisualAttention(nn.Module):
assert self._qkv_same_embed_dim, (
"Visual Attention implementation only supports self-attention"
)
self.in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim)
self.out_proj = ReplicatedLinear(embed_dim, embed_dim)
self.in_proj = ReplicatedLinear(
embed_dim, 3 * embed_dim, prefix=f"{prefix}.in_proj"
)
self.out_proj = ReplicatedLinear(
embed_dim, embed_dim, prefix=f"{prefix}.out_proj"
)
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
def forward(
@@ -214,10 +219,15 @@ class QwenVLMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.c_fc = ColumnParallelLinear(
hidden_size, intermediate_size, bias=True, quant_config=quant_config
hidden_size,
intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.c_fc",
)
self.act_fn = get_act_fn("gelu")
self.c_proj = RowParallelLinear(
@@ -225,6 +235,7 @@ class QwenVLMLP(nn.Module):
hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.c_proj",
)
def forward(self, x):
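
The reason prefix travels alongside quant_config is that per-layer quantization decisions are made by name. The sketch below is a hypothetical stand-in, not vLLM's actual QuantizationConfig API, but it shows why a layer must be told its own prefix at construction time:

class ToyQuantConfig:
    """Hypothetical config that leaves some layers unquantized, matched by name."""

    def __init__(self, ignored_prefixes: list[str]):
        self.ignored_prefixes = ignored_prefixes

    def should_quantize(self, prefix: str) -> bool:
        # Only a layer that knows its own dotted name can be matched
        # against an ignore list such as ["visual."].
        return not any(prefix.startswith(p) for p in self.ignored_prefixes)

cfg = ToyQuantConfig(ignored_prefixes=["visual."])
print(cfg.should_quantize("visual.transformer.resblocks.0.mlp.c_fc"))  # False
print(cfg.should_quantize("transformer.h.0.mlp.c_fc"))                 # True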
@@ -242,17 +253,19 @@ class VisualAttentionBlock(nn.Module):
mlp_ratio: float = 4.0,
norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.ln_1 = norm_layer(d_model)
self.ln_2 = norm_layer(d_model)
mlp_width = int(d_model * mlp_ratio)
self.attn = VisualAttention(d_model, n_head)
self.attn = VisualAttention(d_model, n_head, prefix=f"{prefix}.attn")
self.mlp = QwenVLMLP(
hidden_size=d_model,
intermediate_size=mlp_width,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
def attention(
@@ -282,6 +295,7 @@ class TransformerBlock(nn.Module):
mlp_ratio: float = 4.0,
norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.width = width
@@ -295,8 +309,9 @@ class TransformerBlock(nn.Module):
mlp_ratio,
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.resblocks.{i}",
)
for _ in range(layers)
for i in range(layers)
]
)
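
Note the accompanying loop change from "for _ in range(layers)" to "for i in range(layers)": each VisualAttentionBlock now gets a distinct resblocks.{i} prefix, matching the index under which the block is registered in the nn.ModuleList and hence the names its weights carry in the checkpoint. For example (assuming a parent prefix of "visual.transformer" and two layers):

prefix, layers = "visual.transformer", 2
print([f"{prefix}.resblocks.{i}" for i in range(layers)])
# ['visual.transformer.resblocks.0', 'visual.transformer.resblocks.1']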
@@ -327,6 +342,7 @@ class VisionTransformer(nn.Module):
output_dim: int = 512,
image_start_id: int = 151857,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
**kwargs,
):
super().__init__()
@@ -356,6 +372,7 @@ class VisionTransformer(nn.Module):
mlp_ratio,
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.transformer",
)
self.attn_pool = Resampler2(
@@ -366,6 +383,7 @@ class VisionTransformer(nn.Module):
norm_layer=norm_layer,
adaptive=False,
do_post_projection=False,
prefix=f"{prefix}.attn_pool",
).to(
device=self.positional_embedding.device,
dtype=self.positional_embedding.dtype,
@@ -413,7 +431,9 @@ class QwenVLModel(QWenModel):
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.visual = VisionTransformer(**config.visual, quant_config=quant_config)
self.visual = VisionTransformer(
**config.visual, quant_config=quant_config, prefix=f"{prefix}.visual"
)
@lru_cache(maxsize=1)
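
With the prefixes threaded from QwenVLModel down through VisionTransformer, TransformerBlock, and the attention/MLP submodules, each vision-tower linear is constructed with the same dotted path under which PyTorch registers it. A generic way to list those paths for any module tree (plain PyTorch, not a vLLM utility; vLLM's parallel Linear classes would replace nn.Linear in the filter):

import torch.nn as nn

def linear_prefixes(model: nn.Module, root: str = "") -> list[str]:
    # Enumerate the dotted paths of all Linear submodules, i.e. the same
    # strings this commit now passes to each layer's constructor as `prefix`.
    return [
        f"{root}.{name}" if root else name
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
    ]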