[Misc][Model][Refactor] Pass the prefix into Linear layers (#31669)

Signed-off-by: Wang Kunpeng <1289706727@qq.com>
This commit is contained in:
Wang Kunpeng
2026-01-06 04:03:18 +08:00
committed by GitHub
parent 02dbb933cb
commit 5708297e4e
17 changed files with 181 additions and 40 deletions

View File

@@ -82,6 +82,7 @@ class MambaMixer(MambaBase, CustomOp):
input_size=conv_kernel_size,
output_size=intermediate_size,
bias=use_conv_bias,
prefix=f"{prefix}.conv1d",
)
# unsqueeze to fit conv1d weights shape into the linear weights shape.
# Can't do this in `weight_loader` since it already exists in
@@ -90,7 +91,10 @@ class MambaMixer(MambaBase, CustomOp):
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
self.in_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, bias=use_bias
hidden_size,
[intermediate_size] * 2,
bias=use_bias,
prefix=f"{prefix}.in_proj",
)
# selective projection used to make dt, B and C input dependent
@@ -98,12 +102,17 @@ class MambaMixer(MambaBase, CustomOp):
intermediate_size,
time_step_rank + ssm_state_size * 2,
bias=False,
prefix=f"{prefix}.x_proj",
)
# time step projection (discretization) -
# In the forward we need to apply dt_proj without the bias,
# as the bias is added in the selective scan kernel.
self.dt_proj = ColumnParallelLinear(
time_step_rank, intermediate_size, bias=True, skip_bias_add=True
time_step_rank,
intermediate_size,
bias=True,
skip_bias_add=True,
prefix=f"{prefix}.dt_proj",
)
def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
@@ -136,6 +145,7 @@ class MambaMixer(MambaBase, CustomOp):
hidden_size,
bias=use_bias,
input_is_parallel=True,
prefix=f"{prefix}.out_proj",
)
self.dt_layernorm = (

View File

@@ -127,11 +127,16 @@ class AriaProjectorMLP(nn.Module):
in_features: int,
hidden_features: int,
output_dim: int,
prefix: str = "",
) -> None:
super().__init__()
self.linear_in = ColumnParallelLinear(in_features, hidden_features, bias=False)
self.linear_out = RowParallelLinear(hidden_features, output_dim, bias=False)
self.linear_in = ColumnParallelLinear(
in_features, hidden_features, bias=False, prefix=f"{prefix}.linear_in"
)
self.linear_out = RowParallelLinear(
hidden_features, output_dim, bias=False, prefix=f"{prefix}.linear_out"
)
self.act = get_act_fn("gelu_new")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -154,7 +159,7 @@ class AriaProjector(nn.Module):
A tensor with the shape of (batch_size, query_number, output_dim)
"""
def __init__(self, config: AriaConfig) -> None:
def __init__(self, config: AriaConfig, prefix: str = "") -> None:
super().__init__()
self.patch_to_query_dict = config.projector_patch_to_query_dict
@@ -174,7 +179,10 @@ class AriaProjector(nn.Module):
self.layer_norm = nn.LayerNorm(self.in_features)
self.feed_forward = AriaProjectorMLP(
self.in_features, self.hidden_features, self.output_dim
self.in_features,
self.hidden_features,
self.output_dim,
prefix=f"{prefix}.feed_forward",
)
def forward(
@@ -536,7 +544,9 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
quant_config=quant_config,
prefix=f"{prefix}.vision_tower",
)
self.multi_modal_projector = AriaProjector(config)
self.multi_modal_projector = AriaProjector(
config, prefix=maybe_prefix(prefix, "multi_modal_projector")
)
self.vocab_size = config.text_config.vocab_size
self.language_model = AriaTextModel(
vllm_config=vllm_config.with_hf_config(config.text_config),

View File

@@ -166,7 +166,7 @@ class GPTNeoXLayer(nn.Module):
self.attention = GPTNeoXAttention(
config, cache_config, quant_config, prefix=f"{prefix}.attention"
)
self.mlp = GPTNeoXMLP(config, quant_config)
self.mlp = GPTNeoXMLP(config, quant_config, prefix=f"{prefix}.mlp")
def forward(
self,

View File

@@ -427,6 +427,7 @@ class HunYuanSparseMoeBlock(nn.Module):
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=False,
prefix=f"{prefix}.shared_mlp",
)
else:
self.shared_mlp = None

View File

@@ -78,6 +78,7 @@ class JambaMoE(nn.Module):
bias=False,
quant_config=None,
params_dtype=params_dtype,
prefix=f"{prefix}.router",
)
self.experts = FusedMoE(

View File

@@ -27,15 +27,23 @@ logger = init_logger(__name__)
class JinaVLScorer(nn.Module):
def __init__(self, model_config: "ModelConfig"):
def __init__(self, model_config: "ModelConfig", prefix: str = ""):
super().__init__()
config = model_config.hf_config.get_text_config()
head_dtype = model_config.head_dtype
self.dense = ColumnParallelLinear(
config.hidden_size, config.hidden_size, params_dtype=head_dtype, bias=True
config.hidden_size,
config.hidden_size,
params_dtype=head_dtype,
bias=True,
prefix=f"{prefix}.dense",
)
self.out_proj = RowParallelLinear(
config.hidden_size, config.num_labels, params_dtype=head_dtype, bias=True
config.hidden_size,
config.num_labels,
params_dtype=head_dtype,
bias=True,
prefix=f"{prefix}.out_proj",
)
def forward(self, x, **kwargs):
@@ -94,7 +102,9 @@ class JinaVLForSequenceClassification(
pooler_config = vllm_config.model_config.pooler_config
assert pooler_config is not None
self.score = JinaVLScorer(vllm_config.model_config)
self.score = JinaVLScorer(
vllm_config.model_config, prefix=maybe_prefix(prefix, "score")
)
self.pooler = DispatchPooler(
{
"token_classify": Pooler.for_token_classify(

View File

@@ -90,6 +90,7 @@ class MiniCPMMoE(nn.Module):
intermediate_size: int,
params_dtype: torch.dtype | None = None,
tp_size: int | None = None,
prefix: str = "",
):
super().__init__()
self.tp_size = tp_size or get_tensor_model_parallel_world_size()
@@ -108,6 +109,7 @@ class MiniCPMMoE(nn.Module):
bias=False,
params_dtype=self.params_dtype,
quant_config=None,
prefix=f"{prefix}.gate",
)
self.ws = nn.Parameter(
@@ -352,6 +354,7 @@ class MiniCPMDecoderLayer(nn.Module):
hidden_act=self.config.hidden_act,
hidden_act_param=getattr(self.config, "hidden_act_param", 0.0),
quant_config=self.quant_config,
prefix=f"{self.prefix}.mlp",
)
else:
self.mlp = MiniCPMMoE(
@@ -359,6 +362,7 @@ class MiniCPMDecoderLayer(nn.Module):
top_k=self.config.num_experts_per_tok,
hidden_size=self.config.hidden_size,
intermediate_size=self.config.intermediate_size,
prefix=f"{self.prefix}.mlp",
)
def forward(

View File

@@ -108,6 +108,7 @@ class EagleMiniCPMDecoderLayer(nn.Module):
top_k=self.config.num_experts_per_tok,
hidden_size=self.config.hidden_size,
intermediate_size=self.config.intermediate_size,
prefix=f"{self.prefix}.mlp",
)
def forward(

View File

@@ -67,6 +67,7 @@ class EagleMistralLarge3Model(DeepseekV2Model):
input_is_parallel=False,
quant_config=quant_config,
return_bias=False,
prefix=maybe_prefix(prefix, "fc"),
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(

View File

@@ -63,7 +63,9 @@ class ModernBertEmbeddings(nn.Module):
class ModernBertAttention(nn.Module):
def __init__(self, config: ModernBertConfig, layer_id: int | None = None):
def __init__(
self, config: ModernBertConfig, layer_id: int | None = None, prefix: str = ""
):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
@@ -80,6 +82,7 @@ class ModernBertAttention(nn.Module):
self.head_dim,
self.num_heads,
bias=config.attention_bias,
prefix=f"{prefix}.Wqkv",
)
if layer_types := getattr(config, "layer_types", None):
@@ -117,7 +120,10 @@ class ModernBertAttention(nn.Module):
per_layer_sliding_window=sliding_window,
)
self.Wo = RowParallelLinear(
config.hidden_size, config.hidden_size, bias=config.attention_bias
config.hidden_size,
config.hidden_size,
bias=config.attention_bias,
prefix=f"{prefix}.Wo",
)
def forward(
@@ -135,7 +141,7 @@ class ModernBertAttention(nn.Module):
class ModernBertMLP(nn.Module):
def __init__(self, config: ModernBertConfig):
def __init__(self, config: ModernBertConfig, prefix: str = ""):
super().__init__()
self.config = config
self.Wi = nn.Linear(
@@ -143,7 +149,10 @@ class ModernBertMLP(nn.Module):
)
self.act = nn.GELU()
self.Wo = RowParallelLinear(
config.intermediate_size, config.hidden_size, bias=config.mlp_bias
config.intermediate_size,
config.hidden_size,
bias=config.mlp_bias,
prefix=f"{prefix}.Wo",
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -163,11 +172,13 @@ class ModernBertLayer(nn.Module):
self.attn_norm = nn.LayerNorm(
config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
)
self.attn = ModernBertAttention(config=config, layer_id=layer_id)
self.attn = ModernBertAttention(
config=config, layer_id=layer_id, prefix=f"{prefix}.attn"
)
self.mlp_norm = nn.LayerNorm(
config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
)
self.mlp = ModernBertMLP(config)
self.mlp = ModernBertMLP(config, prefix=f"{prefix}.mlp")
def forward(
self,
@@ -189,7 +200,11 @@ class ModernBertEncoderLayer(nn.Module):
config = vllm_config.model_config.hf_config
self.layers = nn.ModuleList(
[
ModernBertLayer(config=config, layer_id=layer_id)
ModernBertLayer(
config=config,
layer_id=layer_id,
prefix=f"{prefix}.layers.{layer_id}",
)
for layer_id in range(config.num_hidden_layers)
]
)
@@ -220,7 +235,9 @@ class ModernBertModel(nn.Module):
config = vllm_config.model_config.hf_config
self.config = config
self.embeddings = ModernBertEmbeddings(config)
self.encoder_layer = ModernBertEncoderLayer(vllm_config)
self.encoder_layer = ModernBertEncoderLayer(
vllm_config, prefix=f"{prefix}.encoder_layer"
)
self.final_norm = nn.LayerNorm(
config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
)

View File

@@ -142,6 +142,7 @@ class ViTMLP(nn.Module):
self,
config: VisionBackboneConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.w1 = ColumnParallelLinear(
@@ -149,6 +150,7 @@ class ViTMLP(nn.Module):
config.image_mlp_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.w1",
)
# Activation function.
assert config.image_mlp_activations == "quick_gelu"
@@ -158,6 +160,7 @@ class ViTMLP(nn.Module):
config.image_emb_dim,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.w2",
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -176,6 +179,7 @@ class MultiHeadDotProductAttention(nn.Module):
use_bias: bool = True,
nlayers: int = 1,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
@@ -202,24 +206,28 @@ class MultiHeadDotProductAttention(nn.Module):
self.total_num_heads * self.head_dim,
bias=use_bias,
quant_config=quant_config,
prefix=f"{prefix}.wq",
)
self.wk = ColumnParallelLinear(
nlayers * self.hidden_size,
self.total_num_kv_heads * self.head_dim,
bias=use_bias,
quant_config=quant_config,
prefix=f"{prefix}.wk",
)
self.wv = ColumnParallelLinear(
nlayers * self.hidden_size,
self.total_num_kv_heads * self.head_dim,
bias=use_bias,
quant_config=quant_config,
prefix=f"{prefix}.wv",
)
self.wo = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=use_bias,
quant_config=quant_config,
prefix=f"{prefix}.wo",
)
self.scale = self.head_dim**-0.5
@@ -254,10 +262,15 @@ class ResidualAttentionBlock(nn.Module):
self,
config: VisionBackboneConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.attention = MultiHeadDotProductAttention(config, quant_config=quant_config)
self.feed_forward = ViTMLP(config, quant_config)
self.attention = MultiHeadDotProductAttention(
config, quant_config=quant_config, prefix=f"{prefix}.attention"
)
self.feed_forward = ViTMLP(
config, quant_config, prefix=f"{prefix}.feed_forward"
)
self.attention_norm = nn.LayerNorm(
config.image_emb_dim,
eps=config.image_norm_eps,
@@ -280,12 +293,15 @@ class BlockCollection(nn.Module):
self,
config: VisionBackboneConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.resblocks = nn.ModuleList(
[
ResidualAttentionBlock(config, quant_config)
for _ in range(config.image_num_layers)
ResidualAttentionBlock(
config, quant_config, prefix=f"{prefix}.resblocks.{i}"
)
for i in range(config.image_num_layers)
]
)
@@ -308,6 +324,7 @@ class VisionTransformer(nn.Module):
self,
config: VisionBackboneConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
scale = config.image_emb_dim**-0.5
@@ -324,7 +341,9 @@ class VisionTransformer(nn.Module):
bias=False,
)
self.pre_ln = nn.LayerNorm(config.image_emb_dim, eps=config.image_norm_eps)
self.transformer = BlockCollection(config, quant_config)
self.transformer = BlockCollection(
config, quant_config, prefix=f"{prefix}.transformer"
)
def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor:
cls_emb = self.positional_embedding[0:1]
@@ -419,6 +438,7 @@ class MolmoAttention(nn.Module):
self.total_num_kv_heads,
bias=config.qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.tp_rank: int | None = None
@@ -454,6 +474,7 @@ class MolmoAttention(nn.Module):
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
def _apply_qk_norm(
@@ -493,6 +514,7 @@ class LanguageModelMLP(nn.Module):
config: PretrainedConfig,
input_dim: int | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -503,6 +525,7 @@ class LanguageModelMLP(nn.Module):
[self.intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
# Activation function.
self.act_fn = MulAndSilu()
@@ -512,6 +535,7 @@ class LanguageModelMLP(nn.Module):
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
def forward(
@@ -532,6 +556,7 @@ class ImageProjectorMLP(nn.Module):
config: PretrainedConfig,
input_dim: int | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -542,6 +567,7 @@ class ImageProjectorMLP(nn.Module):
[self.intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.merged_linear",
)
# Activation function.
self.act_fn = SiluAndMul()
@@ -552,6 +578,7 @@ class ImageProjectorMLP(nn.Module):
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
def forward(
@@ -579,7 +606,9 @@ class MolmoDecoderLayer(nn.Module):
)
# MLP block.
self.mlp = LanguageModelMLP(config, quant_config=quant_config)
self.mlp = LanguageModelMLP(
config, quant_config=quant_config, prefix=f"{prefix}.mlp"
)
# LayerNorm
assert config.layer_norm_type == "rms"
@@ -643,6 +672,7 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
config: PretrainedConfig,
vision_config: VisionBackboneConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
self.vit_layers = VIT_LAYERS
@@ -651,18 +681,24 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
(self.image_num_patch[0] + 1) // POOLING_SIZE,
(self.image_num_patch[1] + 1) // POOLING_SIZE,
)
self.image_vit = VisionTransformer(vision_config, quant_config=quant_config)
self.image_vit = VisionTransformer(
vision_config, quant_config=quant_config, prefix=f"{prefix}.image_vit"
)
self.num_prefix_tokens = self.image_vit.num_prefix_tokens
assert self.num_prefix_tokens in {0, 1}, (
"Only 0 or 1 prefix tokens are supported"
)
self.image_pooling_2d = MultiHeadDotProductAttention(
vision_config, nlayers=len(self.vit_layers), quant_config=quant_config
vision_config,
nlayers=len(self.vit_layers),
quant_config=quant_config,
prefix=f"{prefix}.image_pooling_2d",
)
self.image_projector = ImageProjectorMLP(
config,
input_dim=vision_config.image_emb_dim,
quant_config=quant_config,
prefix=f"{prefix}.image_projector",
)
image_dim = vision_config.image_emb_dim * len(self.vit_layers)
@@ -1405,7 +1441,12 @@ class MolmoForCausalLM(
self.multimodal_config = multimodal_config
vision_config = VisionBackboneConfig()
self.vision_backbone = MolmoVisionBackbone(config, vision_config, quant_config)
self.vision_backbone = MolmoVisionBackbone(
config,
vision_config,
quant_config,
prefix=maybe_prefix(prefix, "vision_backbone"),
)
self.model = MolmoModel(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)

View File

@@ -86,7 +86,11 @@ class OlmoeMoE(nn.Module):
# Gate always runs at half / full precision for now.
self.gate = ReplicatedLinear(
hidden_size, num_experts, bias=False, quant_config=None
hidden_size,
num_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.gate",
)
self.experts = FusedMoE(

View File

@@ -272,6 +272,7 @@ class PhiMoE(nn.Module):
bias=False,
params_dtype=params_dtype,
quant_config=None,
prefix=f"{prefix}.gate",
)
self.experts = FusedMoE(

View File

@@ -56,13 +56,22 @@ class QWenMLP(nn.Module):
intermediate_size: int,
hidden_act: str = "silu",
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.c_proj = RowParallelLinear(
intermediate_size, hidden_size, bias=False, quant_config=quant_config
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.c_proj",
)
if hidden_act != "silu":
raise ValueError(
@@ -163,7 +172,10 @@ class QWenBlock(nn.Module):
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.mlp = QWenMLP(
config.hidden_size, config.intermediate_size // 2, quant_config=quant_config
config.hidden_size,
config.intermediate_size // 2,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
def forward(

View File

@@ -864,6 +864,7 @@ class Qwen3NextDecoderLayer(nn.Module):
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = Qwen3NextRMSNorm(

View File

@@ -109,6 +109,7 @@ class VisualAttention(nn.Module):
bias: bool = True,
kdim: int | None = None,
vdim: int | None = None,
prefix: str = "",
):
super().__init__()
self.embed_dim = embed_dim
@@ -128,8 +129,12 @@ class VisualAttention(nn.Module):
assert self._qkv_same_embed_dim, (
"Visual Attention implementation only supports self-attention"
)
self.in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim)
self.out_proj = ReplicatedLinear(embed_dim, embed_dim)
self.in_proj = ReplicatedLinear(
embed_dim, 3 * embed_dim, prefix=f"{prefix}.in_proj"
)
self.out_proj = ReplicatedLinear(
embed_dim, embed_dim, prefix=f"{prefix}.out_proj"
)
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
def forward(
@@ -214,10 +219,15 @@ class QwenVLMLP(nn.Module):
hidden_size: int,
intermediate_size: int,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.c_fc = ColumnParallelLinear(
hidden_size, intermediate_size, bias=True, quant_config=quant_config
hidden_size,
intermediate_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.c_fc",
)
self.act_fn = get_act_fn("gelu")
self.c_proj = RowParallelLinear(
@@ -225,6 +235,7 @@ class QwenVLMLP(nn.Module):
hidden_size,
bias=True,
quant_config=quant_config,
prefix=f"{prefix}.c_proj",
)
def forward(self, x):
@@ -242,17 +253,19 @@ class VisualAttentionBlock(nn.Module):
mlp_ratio: float = 4.0,
norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.ln_1 = norm_layer(d_model)
self.ln_2 = norm_layer(d_model)
mlp_width = int(d_model * mlp_ratio)
self.attn = VisualAttention(d_model, n_head)
self.attn = VisualAttention(d_model, n_head, prefix=f"{prefix}.attn")
self.mlp = QwenVLMLP(
hidden_size=d_model,
intermediate_size=mlp_width,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
def attention(
@@ -282,6 +295,7 @@ class TransformerBlock(nn.Module):
mlp_ratio: float = 4.0,
norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.width = width
@@ -295,8 +309,9 @@ class TransformerBlock(nn.Module):
mlp_ratio,
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.resblocks.{i}",
)
for _ in range(layers)
for i in range(layers)
]
)
@@ -327,6 +342,7 @@ class VisionTransformer(nn.Module):
output_dim: int = 512,
image_start_id: int = 151857,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
**kwargs,
):
super().__init__()
@@ -356,6 +372,7 @@ class VisionTransformer(nn.Module):
mlp_ratio,
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.transformer",
)
self.attn_pool = Resampler2(
@@ -366,6 +383,7 @@ class VisionTransformer(nn.Module):
norm_layer=norm_layer,
adaptive=False,
do_post_projection=False,
prefix=f"{prefix}.attn_pool",
).to(
device=self.positional_embedding.device,
dtype=self.positional_embedding.dtype,
@@ -413,7 +431,9 @@ class QwenVLModel(QWenModel):
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.visual = VisionTransformer(**config.visual, quant_config=quant_config)
self.visual = VisionTransformer(
**config.visual, quant_config=quant_config, prefix=f"{prefix}.visual"
)
@lru_cache(maxsize=1)

View File

@@ -86,7 +86,13 @@ class Zamba2LoRA(nn.Module):
B_class = MergedColumnParallelLinear
else:
B_class = ColumnParallelLinear
self.B = B_class(rank, output_dim, bias=False, quant_config=quant_config)
self.B = B_class(
rank,
output_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.B",
)
def forward(
self,
@@ -346,6 +352,7 @@ class Zamba2MLP(nn.Module):
config.adapter_rank,
2 * [self.intermediate_size],
quant_config,
prefix=f"{prefix}.gate_up_proj_adapter_list.{block_idx}",
)
else:
gate_up_proj_adapter = nn.Identity()