[Model] NemotronH Support (#22349)

Signed-off-by: Daniel Afrimi <danielafrimi8@gmail.com>
This commit is contained in:
danielafrimi
2025-08-11 14:09:24 +03:00
committed by GitHub
parent 951b038298
commit 14a5d903ab
2 changed files with 23 additions and 7 deletions

View File

@@ -64,20 +64,32 @@ class NemotronHMLP(nn.Module):
def __init__(
self,
config: NemotronHConfig,
layer_idx: int,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
prefix: str = "",
) -> None:
super().__init__()
hybrid_override_pattern = config.hybrid_override_pattern
mlp_index = hybrid_override_pattern[:layer_idx + 1].count("-") - 1
if isinstance(config.intermediate_size, list):
if len(config.intermediate_size) == 1:
intermediate_size = config.intermediate_size[0]
else:
intermediate_size = config.intermediate_size[mlp_index]
else:
intermediate_size = config.intermediate_size
self.up_proj = ColumnParallelLinear(
input_size=config.hidden_size,
output_size=config.intermediate_size,
output_size=intermediate_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.up_proj",
)
self.down_proj = RowParallelLinear(
input_size=config.intermediate_size,
input_size=intermediate_size,
output_size=config.hidden_size,
bias=bias,
quant_config=quant_config,
@@ -110,6 +122,7 @@ class NemotronHMLPDecoderLayer(nn.Module):
quant_config=quant_config,
bias=config.mlp_bias,
prefix=f"{prefix}.mixer",
layer_idx=layer_idx,
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -146,7 +159,7 @@ class NemotronHMambaDecoderLayer(nn.Module):
hidden_size=config.hidden_size,
ssm_state_size=config.ssm_state_size,
conv_kernel_size=config.conv_kernel,
intermediate_size=config.expand * config.hidden_size,
intermediate_size=config.mamba_num_heads * config.mamba_head_dim,
use_conv_bias=config.use_conv_bias,
use_bias=config.use_bias,
n_groups=config.n_groups,
@@ -205,7 +218,10 @@ class NemotronHAttention(nn.Module):
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = config.hidden_size // self.total_num_heads
if hasattr(config, "head_dim") and config.head_dim is not None:
self.head_dim = config.head_dim
else:
self.head_dim = config.hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
@@ -481,7 +497,7 @@ class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
"""
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config
intermediate_size = hf_config.expand * hf_config.hidden_size
intermediate_size = hf_config.mamba_num_heads * hf_config.mamba_head_dim
return MambaStateShapeCalculator.mamba2_state_shape(
intermediate_size=intermediate_size,