[gpt-oss] fix model config with hf_config (#22401)

Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
This commit is contained in:
Yongye Zhu
2025-08-06 18:04:04 -07:00
committed by GitHub
parent 19c9365aa4
commit 5c7cc33f4d

View File

@@ -61,9 +61,9 @@ class OAIAttention(nn.Module):
                 "original_max_position_embeddings":
                     config.rope_scaling["original_max_position_embeddings"],
                 "beta_fast":
-                    config.rope_ntk_beta,
+                    config.rope_scaling["beta_fast"],
                 "beta_slow":
-                    config.rope_ntk_alpha,
+                    config.rope_scaling["beta_slow"],
             },
             is_neox_style=True,
         )
@@ -154,7 +154,7 @@ class MLPBlock(torch.nn.Module):
                                 dtype=torch.bfloat16)
         assert config.intermediate_size % self.world_size == 0
         self.experts = FusedMoE(num_experts=config.num_local_experts,
-                                top_k=config.num_experts_per_token,
+                                top_k=config.num_experts_per_tok,
                                 hidden_size=config.hidden_size,
                                 intermediate_size=config.intermediate_size,
                                 reduce_results=True,