[Misc][Model][Refactor] Pass the prefix into Linear layers (#31669)
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
This commit is contained in:
@@ -82,6 +82,7 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
input_size=conv_kernel_size,
|
||||
output_size=intermediate_size,
|
||||
bias=use_conv_bias,
|
||||
prefix=f"{prefix}.conv1d",
|
||||
)
|
||||
# unsqueeze to fit conv1d weights shape into the linear weights shape.
|
||||
# Can't do this in `weight_loader` since it already exists in
|
||||
@@ -90,7 +91,10 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
|
||||
|
||||
self.in_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2, bias=use_bias
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=use_bias,
|
||||
prefix=f"{prefix}.in_proj",
|
||||
)
|
||||
|
||||
# selective projection used to make dt, B and C input dependent
|
||||
@@ -98,12 +102,17 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
intermediate_size,
|
||||
time_step_rank + ssm_state_size * 2,
|
||||
bias=False,
|
||||
prefix=f"{prefix}.x_proj",
|
||||
)
|
||||
# time step projection (discretization) -
|
||||
# In the forward we need to apply dt_proj without the bias,
|
||||
# as the bias is added in the selective scan kernel.
|
||||
self.dt_proj = ColumnParallelLinear(
|
||||
time_step_rank, intermediate_size, bias=True, skip_bias_add=True
|
||||
time_step_rank,
|
||||
intermediate_size,
|
||||
bias=True,
|
||||
skip_bias_add=True,
|
||||
prefix=f"{prefix}.dt_proj",
|
||||
)
|
||||
|
||||
def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
|
||||
@@ -136,6 +145,7 @@ class MambaMixer(MambaBase, CustomOp):
|
||||
hidden_size,
|
||||
bias=use_bias,
|
||||
input_is_parallel=True,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
)
|
||||
|
||||
self.dt_layernorm = (
|
||||
|
||||
@@ -127,11 +127,16 @@ class AriaProjectorMLP(nn.Module):
|
||||
in_features: int,
|
||||
hidden_features: int,
|
||||
output_dim: int,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.linear_in = ColumnParallelLinear(in_features, hidden_features, bias=False)
|
||||
self.linear_out = RowParallelLinear(hidden_features, output_dim, bias=False)
|
||||
self.linear_in = ColumnParallelLinear(
|
||||
in_features, hidden_features, bias=False, prefix=f"{prefix}.linear_in"
|
||||
)
|
||||
self.linear_out = RowParallelLinear(
|
||||
hidden_features, output_dim, bias=False, prefix=f"{prefix}.linear_out"
|
||||
)
|
||||
self.act = get_act_fn("gelu_new")
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@@ -154,7 +159,7 @@ class AriaProjector(nn.Module):
|
||||
A tensor with the shape of (batch_size, query_number, output_dim)
|
||||
"""
|
||||
|
||||
def __init__(self, config: AriaConfig) -> None:
|
||||
def __init__(self, config: AriaConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
|
||||
self.patch_to_query_dict = config.projector_patch_to_query_dict
|
||||
@@ -174,7 +179,10 @@ class AriaProjector(nn.Module):
|
||||
|
||||
self.layer_norm = nn.LayerNorm(self.in_features)
|
||||
self.feed_forward = AriaProjectorMLP(
|
||||
self.in_features, self.hidden_features, self.output_dim
|
||||
self.in_features,
|
||||
self.hidden_features,
|
||||
self.output_dim,
|
||||
prefix=f"{prefix}.feed_forward",
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -536,7 +544,9 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.vision_tower",
|
||||
)
|
||||
self.multi_modal_projector = AriaProjector(config)
|
||||
self.multi_modal_projector = AriaProjector(
|
||||
config, prefix=maybe_prefix(prefix, "multi_modal_projector")
|
||||
)
|
||||
self.vocab_size = config.text_config.vocab_size
|
||||
self.language_model = AriaTextModel(
|
||||
vllm_config=vllm_config.with_hf_config(config.text_config),
|
||||
|
||||
@@ -166,7 +166,7 @@ class GPTNeoXLayer(nn.Module):
|
||||
self.attention = GPTNeoXAttention(
|
||||
config, cache_config, quant_config, prefix=f"{prefix}.attention"
|
||||
)
|
||||
self.mlp = GPTNeoXMLP(config, quant_config)
|
||||
self.mlp = GPTNeoXMLP(config, quant_config, prefix=f"{prefix}.mlp")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -427,6 +427,7 @@ class HunYuanSparseMoeBlock(nn.Module):
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
reduce_results=False,
|
||||
prefix=f"{prefix}.shared_mlp",
|
||||
)
|
||||
else:
|
||||
self.shared_mlp = None
|
||||
|
||||
@@ -78,6 +78,7 @@ class JambaMoE(nn.Module):
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
params_dtype=params_dtype,
|
||||
prefix=f"{prefix}.router",
|
||||
)
|
||||
|
||||
self.experts = FusedMoE(
|
||||
|
||||
@@ -27,15 +27,23 @@ logger = init_logger(__name__)
|
||||
|
||||
|
||||
class JinaVLScorer(nn.Module):
|
||||
def __init__(self, model_config: "ModelConfig"):
|
||||
def __init__(self, model_config: "ModelConfig", prefix: str = ""):
|
||||
super().__init__()
|
||||
config = model_config.hf_config.get_text_config()
|
||||
head_dtype = model_config.head_dtype
|
||||
self.dense = ColumnParallelLinear(
|
||||
config.hidden_size, config.hidden_size, params_dtype=head_dtype, bias=True
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
params_dtype=head_dtype,
|
||||
bias=True,
|
||||
prefix=f"{prefix}.dense",
|
||||
)
|
||||
self.out_proj = RowParallelLinear(
|
||||
config.hidden_size, config.num_labels, params_dtype=head_dtype, bias=True
|
||||
config.hidden_size,
|
||||
config.num_labels,
|
||||
params_dtype=head_dtype,
|
||||
bias=True,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
)
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
@@ -94,7 +102,9 @@ class JinaVLForSequenceClassification(
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
assert pooler_config is not None
|
||||
|
||||
self.score = JinaVLScorer(vllm_config.model_config)
|
||||
self.score = JinaVLScorer(
|
||||
vllm_config.model_config, prefix=maybe_prefix(prefix, "score")
|
||||
)
|
||||
self.pooler = DispatchPooler(
|
||||
{
|
||||
"token_classify": Pooler.for_token_classify(
|
||||
|
||||
@@ -90,6 +90,7 @@ class MiniCPMMoE(nn.Module):
|
||||
intermediate_size: int,
|
||||
params_dtype: torch.dtype | None = None,
|
||||
tp_size: int | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_size = tp_size or get_tensor_model_parallel_world_size()
|
||||
@@ -108,6 +109,7 @@ class MiniCPMMoE(nn.Module):
|
||||
bias=False,
|
||||
params_dtype=self.params_dtype,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.gate",
|
||||
)
|
||||
|
||||
self.ws = nn.Parameter(
|
||||
@@ -352,6 +354,7 @@ class MiniCPMDecoderLayer(nn.Module):
|
||||
hidden_act=self.config.hidden_act,
|
||||
hidden_act_param=getattr(self.config, "hidden_act_param", 0.0),
|
||||
quant_config=self.quant_config,
|
||||
prefix=f"{self.prefix}.mlp",
|
||||
)
|
||||
else:
|
||||
self.mlp = MiniCPMMoE(
|
||||
@@ -359,6 +362,7 @@ class MiniCPMDecoderLayer(nn.Module):
|
||||
top_k=self.config.num_experts_per_tok,
|
||||
hidden_size=self.config.hidden_size,
|
||||
intermediate_size=self.config.intermediate_size,
|
||||
prefix=f"{self.prefix}.mlp",
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -108,6 +108,7 @@ class EagleMiniCPMDecoderLayer(nn.Module):
|
||||
top_k=self.config.num_experts_per_tok,
|
||||
hidden_size=self.config.hidden_size,
|
||||
intermediate_size=self.config.intermediate_size,
|
||||
prefix=f"{self.prefix}.mlp",
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -67,6 +67,7 @@ class EagleMistralLarge3Model(DeepseekV2Model):
|
||||
input_is_parallel=False,
|
||||
quant_config=quant_config,
|
||||
return_bias=False,
|
||||
prefix=maybe_prefix(prefix, "fc"),
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||
|
||||
@@ -63,7 +63,9 @@ class ModernBertEmbeddings(nn.Module):
|
||||
|
||||
|
||||
class ModernBertAttention(nn.Module):
|
||||
def __init__(self, config: ModernBertConfig, layer_id: int | None = None):
|
||||
def __init__(
|
||||
self, config: ModernBertConfig, layer_id: int | None = None, prefix: str = ""
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -80,6 +82,7 @@ class ModernBertAttention(nn.Module):
|
||||
self.head_dim,
|
||||
self.num_heads,
|
||||
bias=config.attention_bias,
|
||||
prefix=f"{prefix}.Wqkv",
|
||||
)
|
||||
|
||||
if layer_types := getattr(config, "layer_types", None):
|
||||
@@ -117,7 +120,10 @@ class ModernBertAttention(nn.Module):
|
||||
per_layer_sliding_window=sliding_window,
|
||||
)
|
||||
self.Wo = RowParallelLinear(
|
||||
config.hidden_size, config.hidden_size, bias=config.attention_bias
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
bias=config.attention_bias,
|
||||
prefix=f"{prefix}.Wo",
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -135,7 +141,7 @@ class ModernBertAttention(nn.Module):
|
||||
|
||||
|
||||
class ModernBertMLP(nn.Module):
|
||||
def __init__(self, config: ModernBertConfig):
|
||||
def __init__(self, config: ModernBertConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.Wi = nn.Linear(
|
||||
@@ -143,7 +149,10 @@ class ModernBertMLP(nn.Module):
|
||||
)
|
||||
self.act = nn.GELU()
|
||||
self.Wo = RowParallelLinear(
|
||||
config.intermediate_size, config.hidden_size, bias=config.mlp_bias
|
||||
config.intermediate_size,
|
||||
config.hidden_size,
|
||||
bias=config.mlp_bias,
|
||||
prefix=f"{prefix}.Wo",
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@@ -163,11 +172,13 @@ class ModernBertLayer(nn.Module):
|
||||
self.attn_norm = nn.LayerNorm(
|
||||
config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
|
||||
)
|
||||
self.attn = ModernBertAttention(config=config, layer_id=layer_id)
|
||||
self.attn = ModernBertAttention(
|
||||
config=config, layer_id=layer_id, prefix=f"{prefix}.attn"
|
||||
)
|
||||
self.mlp_norm = nn.LayerNorm(
|
||||
config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
|
||||
)
|
||||
self.mlp = ModernBertMLP(config)
|
||||
self.mlp = ModernBertMLP(config, prefix=f"{prefix}.mlp")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -189,7 +200,11 @@ class ModernBertEncoderLayer(nn.Module):
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ModernBertLayer(config=config, layer_id=layer_id)
|
||||
ModernBertLayer(
|
||||
config=config,
|
||||
layer_id=layer_id,
|
||||
prefix=f"{prefix}.layers.{layer_id}",
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
@@ -220,7 +235,9 @@ class ModernBertModel(nn.Module):
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.config = config
|
||||
self.embeddings = ModernBertEmbeddings(config)
|
||||
self.encoder_layer = ModernBertEncoderLayer(vllm_config)
|
||||
self.encoder_layer = ModernBertEncoderLayer(
|
||||
vllm_config, prefix=f"{prefix}.encoder_layer"
|
||||
)
|
||||
self.final_norm = nn.LayerNorm(
|
||||
config.hidden_size, eps=config.norm_eps, bias=config.norm_bias
|
||||
)
|
||||
|
||||
@@ -142,6 +142,7 @@ class ViTMLP(nn.Module):
|
||||
self,
|
||||
config: VisionBackboneConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.w1 = ColumnParallelLinear(
|
||||
@@ -149,6 +150,7 @@ class ViTMLP(nn.Module):
|
||||
config.image_mlp_dim,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.w1",
|
||||
)
|
||||
# Activation function.
|
||||
assert config.image_mlp_activations == "quick_gelu"
|
||||
@@ -158,6 +160,7 @@ class ViTMLP(nn.Module):
|
||||
config.image_emb_dim,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.w2",
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -176,6 +179,7 @@ class MultiHeadDotProductAttention(nn.Module):
|
||||
use_bias: bool = True,
|
||||
nlayers: int = 1,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -202,24 +206,28 @@ class MultiHeadDotProductAttention(nn.Module):
|
||||
self.total_num_heads * self.head_dim,
|
||||
bias=use_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.wq",
|
||||
)
|
||||
self.wk = ColumnParallelLinear(
|
||||
nlayers * self.hidden_size,
|
||||
self.total_num_kv_heads * self.head_dim,
|
||||
bias=use_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.wk",
|
||||
)
|
||||
self.wv = ColumnParallelLinear(
|
||||
nlayers * self.hidden_size,
|
||||
self.total_num_kv_heads * self.head_dim,
|
||||
bias=use_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.wv",
|
||||
)
|
||||
self.wo = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
self.hidden_size,
|
||||
bias=use_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.wo",
|
||||
)
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
@@ -254,10 +262,15 @@ class ResidualAttentionBlock(nn.Module):
|
||||
self,
|
||||
config: VisionBackboneConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.attention = MultiHeadDotProductAttention(config, quant_config=quant_config)
|
||||
self.feed_forward = ViTMLP(config, quant_config)
|
||||
self.attention = MultiHeadDotProductAttention(
|
||||
config, quant_config=quant_config, prefix=f"{prefix}.attention"
|
||||
)
|
||||
self.feed_forward = ViTMLP(
|
||||
config, quant_config, prefix=f"{prefix}.feed_forward"
|
||||
)
|
||||
self.attention_norm = nn.LayerNorm(
|
||||
config.image_emb_dim,
|
||||
eps=config.image_norm_eps,
|
||||
@@ -280,12 +293,15 @@ class BlockCollection(nn.Module):
|
||||
self,
|
||||
config: VisionBackboneConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.resblocks = nn.ModuleList(
|
||||
[
|
||||
ResidualAttentionBlock(config, quant_config)
|
||||
for _ in range(config.image_num_layers)
|
||||
ResidualAttentionBlock(
|
||||
config, quant_config, prefix=f"{prefix}.resblocks.{i}"
|
||||
)
|
||||
for i in range(config.image_num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
@@ -308,6 +324,7 @@ class VisionTransformer(nn.Module):
|
||||
self,
|
||||
config: VisionBackboneConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
scale = config.image_emb_dim**-0.5
|
||||
@@ -324,7 +341,9 @@ class VisionTransformer(nn.Module):
|
||||
bias=False,
|
||||
)
|
||||
self.pre_ln = nn.LayerNorm(config.image_emb_dim, eps=config.image_norm_eps)
|
||||
self.transformer = BlockCollection(config, quant_config)
|
||||
self.transformer = BlockCollection(
|
||||
config, quant_config, prefix=f"{prefix}.transformer"
|
||||
)
|
||||
|
||||
def add_pos_emb(self, x: torch.Tensor, patch_num: int) -> torch.Tensor:
|
||||
cls_emb = self.positional_embedding[0:1]
|
||||
@@ -419,6 +438,7 @@ class MolmoAttention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=config.qkv_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
|
||||
self.tp_rank: int | None = None
|
||||
@@ -454,6 +474,7 @@ class MolmoAttention(nn.Module):
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
def _apply_qk_norm(
|
||||
@@ -493,6 +514,7 @@ class LanguageModelMLP(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
input_dim: int | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -503,6 +525,7 @@ class LanguageModelMLP(nn.Module):
|
||||
[self.intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
)
|
||||
# Activation function.
|
||||
self.act_fn = MulAndSilu()
|
||||
@@ -512,6 +535,7 @@ class LanguageModelMLP(nn.Module):
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -532,6 +556,7 @@ class ImageProjectorMLP(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
input_dim: int | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -542,6 +567,7 @@ class ImageProjectorMLP(nn.Module):
|
||||
[self.intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.merged_linear",
|
||||
)
|
||||
# Activation function.
|
||||
self.act_fn = SiluAndMul()
|
||||
@@ -552,6 +578,7 @@ class ImageProjectorMLP(nn.Module):
|
||||
self.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
|
||||
def forward(
|
||||
@@ -579,7 +606,9 @@ class MolmoDecoderLayer(nn.Module):
|
||||
)
|
||||
|
||||
# MLP block.
|
||||
self.mlp = LanguageModelMLP(config, quant_config=quant_config)
|
||||
self.mlp = LanguageModelMLP(
|
||||
config, quant_config=quant_config, prefix=f"{prefix}.mlp"
|
||||
)
|
||||
|
||||
# LayerNorm
|
||||
assert config.layer_norm_type == "rms"
|
||||
@@ -643,6 +672,7 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
|
||||
config: PretrainedConfig,
|
||||
vision_config: VisionBackboneConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.vit_layers = VIT_LAYERS
|
||||
@@ -651,18 +681,24 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
|
||||
(self.image_num_patch[0] + 1) // POOLING_SIZE,
|
||||
(self.image_num_patch[1] + 1) // POOLING_SIZE,
|
||||
)
|
||||
self.image_vit = VisionTransformer(vision_config, quant_config=quant_config)
|
||||
self.image_vit = VisionTransformer(
|
||||
vision_config, quant_config=quant_config, prefix=f"{prefix}.image_vit"
|
||||
)
|
||||
self.num_prefix_tokens = self.image_vit.num_prefix_tokens
|
||||
assert self.num_prefix_tokens in {0, 1}, (
|
||||
"Only 0 or 1 prefix tokens are supported"
|
||||
)
|
||||
self.image_pooling_2d = MultiHeadDotProductAttention(
|
||||
vision_config, nlayers=len(self.vit_layers), quant_config=quant_config
|
||||
vision_config,
|
||||
nlayers=len(self.vit_layers),
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.image_pooling_2d",
|
||||
)
|
||||
self.image_projector = ImageProjectorMLP(
|
||||
config,
|
||||
input_dim=vision_config.image_emb_dim,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.image_projector",
|
||||
)
|
||||
|
||||
image_dim = vision_config.image_emb_dim * len(self.vit_layers)
|
||||
@@ -1405,7 +1441,12 @@ class MolmoForCausalLM(
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
vision_config = VisionBackboneConfig()
|
||||
self.vision_backbone = MolmoVisionBackbone(config, vision_config, quant_config)
|
||||
self.vision_backbone = MolmoVisionBackbone(
|
||||
config,
|
||||
vision_config,
|
||||
quant_config,
|
||||
prefix=maybe_prefix(prefix, "vision_backbone"),
|
||||
)
|
||||
self.model = MolmoModel(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
|
||||
@@ -86,7 +86,11 @@ class OlmoeMoE(nn.Module):
|
||||
|
||||
# Gate always runs at half / full precision for now.
|
||||
self.gate = ReplicatedLinear(
|
||||
hidden_size, num_experts, bias=False, quant_config=None
|
||||
hidden_size,
|
||||
num_experts,
|
||||
bias=False,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.gate",
|
||||
)
|
||||
|
||||
self.experts = FusedMoE(
|
||||
|
||||
@@ -272,6 +272,7 @@ class PhiMoE(nn.Module):
|
||||
bias=False,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.gate",
|
||||
)
|
||||
|
||||
self.experts = FusedMoE(
|
||||
|
||||
@@ -56,13 +56,22 @@ class QWenMLP(nn.Module):
|
||||
intermediate_size: int,
|
||||
hidden_act: str = "silu",
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj",
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
intermediate_size, hidden_size, bias=False, quant_config=quant_config
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.c_proj",
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(
|
||||
@@ -163,7 +172,10 @@ class QWenBlock(nn.Module):
|
||||
self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
|
||||
self.mlp = QWenMLP(
|
||||
config.hidden_size, config.intermediate_size // 2, quant_config=quant_config
|
||||
config.hidden_size,
|
||||
config.intermediate_size // 2,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -864,6 +864,7 @@ class Qwen3NextDecoderLayer(nn.Module):
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
|
||||
self.input_layernorm = Qwen3NextRMSNorm(
|
||||
|
||||
@@ -109,6 +109,7 @@ class VisualAttention(nn.Module):
|
||||
bias: bool = True,
|
||||
kdim: int | None = None,
|
||||
vdim: int | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
@@ -128,8 +129,12 @@ class VisualAttention(nn.Module):
|
||||
assert self._qkv_same_embed_dim, (
|
||||
"Visual Attention implementation only supports self-attention"
|
||||
)
|
||||
self.in_proj = ReplicatedLinear(embed_dim, 3 * embed_dim)
|
||||
self.out_proj = ReplicatedLinear(embed_dim, embed_dim)
|
||||
self.in_proj = ReplicatedLinear(
|
||||
embed_dim, 3 * embed_dim, prefix=f"{prefix}.in_proj"
|
||||
)
|
||||
self.out_proj = ReplicatedLinear(
|
||||
embed_dim, embed_dim, prefix=f"{prefix}.out_proj"
|
||||
)
|
||||
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
|
||||
|
||||
def forward(
|
||||
@@ -214,10 +219,15 @@ class QwenVLMLP(nn.Module):
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.c_fc = ColumnParallelLinear(
|
||||
hidden_size, intermediate_size, bias=True, quant_config=quant_config
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.c_fc",
|
||||
)
|
||||
self.act_fn = get_act_fn("gelu")
|
||||
self.c_proj = RowParallelLinear(
|
||||
@@ -225,6 +235,7 @@ class QwenVLMLP(nn.Module):
|
||||
hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.c_proj",
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@@ -242,17 +253,19 @@ class VisualAttentionBlock(nn.Module):
|
||||
mlp_ratio: float = 4.0,
|
||||
norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.ln_1 = norm_layer(d_model)
|
||||
self.ln_2 = norm_layer(d_model)
|
||||
mlp_width = int(d_model * mlp_ratio)
|
||||
self.attn = VisualAttention(d_model, n_head)
|
||||
self.attn = VisualAttention(d_model, n_head, prefix=f"{prefix}.attn")
|
||||
self.mlp = QwenVLMLP(
|
||||
hidden_size=d_model,
|
||||
intermediate_size=mlp_width,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
|
||||
def attention(
|
||||
@@ -282,6 +295,7 @@ class TransformerBlock(nn.Module):
|
||||
mlp_ratio: float = 4.0,
|
||||
norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
@@ -295,8 +309,9 @@ class TransformerBlock(nn.Module):
|
||||
mlp_ratio,
|
||||
norm_layer=norm_layer,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.resblocks.{i}",
|
||||
)
|
||||
for _ in range(layers)
|
||||
for i in range(layers)
|
||||
]
|
||||
)
|
||||
|
||||
@@ -327,6 +342,7 @@ class VisionTransformer(nn.Module):
|
||||
output_dim: int = 512,
|
||||
image_start_id: int = 151857,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -356,6 +372,7 @@ class VisionTransformer(nn.Module):
|
||||
mlp_ratio,
|
||||
norm_layer=norm_layer,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.transformer",
|
||||
)
|
||||
|
||||
self.attn_pool = Resampler2(
|
||||
@@ -366,6 +383,7 @@ class VisionTransformer(nn.Module):
|
||||
norm_layer=norm_layer,
|
||||
adaptive=False,
|
||||
do_post_projection=False,
|
||||
prefix=f"{prefix}.attn_pool",
|
||||
).to(
|
||||
device=self.positional_embedding.device,
|
||||
dtype=self.positional_embedding.dtype,
|
||||
@@ -413,7 +431,9 @@ class QwenVLModel(QWenModel):
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.visual = VisionTransformer(**config.visual, quant_config=quant_config)
|
||||
self.visual = VisionTransformer(
|
||||
**config.visual, quant_config=quant_config, prefix=f"{prefix}.visual"
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
|
||||
@@ -86,7 +86,13 @@ class Zamba2LoRA(nn.Module):
|
||||
B_class = MergedColumnParallelLinear
|
||||
else:
|
||||
B_class = ColumnParallelLinear
|
||||
self.B = B_class(rank, output_dim, bias=False, quant_config=quant_config)
|
||||
self.B = B_class(
|
||||
rank,
|
||||
output_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.B",
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -346,6 +352,7 @@ class Zamba2MLP(nn.Module):
|
||||
config.adapter_rank,
|
||||
2 * [self.intermediate_size],
|
||||
quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj_adapter_list.{block_idx}",
|
||||
)
|
||||
else:
|
||||
gate_up_proj_adapter = nn.Identity()
|
||||
|
||||
Reference in New Issue
Block a user