[Model] Optimize nemotron_h implementation (#19249)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 # Adapted from https://github.com/vllm-project/vllm/blob/94d8ec8d2bcb4ec55e33022b313c7e978edf05e1/vllm/model_executor/models/bamba.py
 # Copyright 2024 HuggingFace Inc. team. All rights reserved.
@@ -29,7 +30,7 @@ from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.activation import ReLUSquaredActivation
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -63,19 +64,22 @@ class NemotronHMLP(nn.Module):
         config: NemotronHConfig,
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
-        self.up_proj = MergedColumnParallelLinear(
+        self.up_proj = ColumnParallelLinear(
             input_size=config.hidden_size,
-            output_sizes=[config.intermediate_size],
+            output_size=config.intermediate_size,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.up_proj",
         )
         self.down_proj = RowParallelLinear(
             input_size=config.intermediate_size,
             output_size=config.hidden_size,
             bias=bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
         )
         self.act_fn = ReLUSquaredActivation()
 
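Note on the change above: NemotronH's MLP has a single up-projection, so there is nothing to fuse and a plain `ColumnParallelLinear` replaces the merged variant. Below is a rough sketch in plain torch, outside vLLM, of the difference between a single projection and a fused one; the sizes and tensor names are made up for illustration.

```python
import torch
import torch.nn as nn

hidden_size, intermediate_size = 32, 64
x = torch.randn(2, hidden_size)

# Single projection: one weight matrix, one output -> a plain
# (column-parallel) linear layer is enough.
up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
up = up_proj(x)

# Fused projections: one big matmul whose output is split back into
# parts; this is the case a merged column-parallel layer is built for
# (e.g. gate_proj + up_proj in gated MLPs).
gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
gate, up2 = gate_up_proj(x).split(intermediate_size, dim=-1)

print(up.shape, gate.shape, up2.shape)  # all (2, 64)
```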
@@ -99,9 +103,12 @@ class NemotronHMLPDecoderLayer(nn.Module):
         super().__init__()
         self.config = config
 
-        self.mixer = NemotronHMLP(config,
-                                  quant_config=quant_config,
-                                  bias=config.mlp_bias)
+        self.mixer = NemotronHMLP(
+            config,
+            quant_config=quant_config,
+            bias=config.mlp_bias,
+            prefix=f"{prefix}.mixer",
+        )
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
@@ -207,12 +214,14 @@ class NemotronHAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             config.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
 
         self.attn = Attention(
@@ -253,7 +262,7 @@ class NemotronHAttentionDecoderLayer(nn.Module):
             layer_idx,
             cache_config,
             quant_config,
-            prefix,
+            prefix=f"{prefix}.mixer",
         )
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
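The prefix plumbing in the hunks above serves one purpose: each parallel linear layer ends up with a unique dotted name that mirrors its position in the module tree, which quantization configs and weight loading can key off. A minimal sketch of how those names compose, assuming an example path like `model.layers.3`; the helper function and the ignored-layers list are hypothetical, not vLLM's API.

```python
def child_prefix(parent: str, child: str) -> str:
    # Compose a dotted module path, as the f"{prefix}.xxx" strings above do.
    return f"{parent}.{child}"

layer_prefix = "model.layers.3"                      # example decoder layer
mixer_prefix = child_prefix(layer_prefix, "mixer")   # NemotronHAttention
qkv_prefix = child_prefix(mixer_prefix, "qkv_proj")  # QKVParallelLinear

# A hypothetical per-layer exclusion list of the kind quantization
# configs often carry.
ignored_layers = ["model.layers.0.mixer.qkv_proj"]
print(qkv_prefix, qkv_prefix not in ignored_layers)
```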
@@ -435,7 +444,6 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
             "k_proj",
             "v_proj",
         ],
-        "gate_up_proj": ["up_proj", "down_proj"]
     }
 
     # LoRA specific attributes
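`packed_modules_mapping` tells vLLM's weight loader (and LoRA) which per-shard checkpoint names are stacked into a fused module. With `up_proj` now a plain `ColumnParallelLinear` there is no fused gate/up module to assemble, so the `gate_up_proj` entry is dropped. A rough sketch of the idea; the `shards_for` helper is illustrative only.

```python
# Fused module name -> checkpoint shard names stacked into it.
packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
}

def shards_for(module_name: str) -> list[str]:
    # Fall back to the module's own name when it is not a packed module.
    return packed_modules_mapping.get(module_name, [module_name])

print(shards_for("qkv_proj"))  # ['q_proj', 'k_proj', 'v_proj']
print(shards_for("up_proj"))   # ['up_proj'] -- loaded as-is, no stacking
```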