[V1] EP/TP MoE + DP Attention (#13931)

Author: Tyler Michael Smith
Date: 2025-03-05 00:27:26 -05:00
Committed by: GitHub
Parent: 0a995d5434
Commit: 72c62eae5f
17 changed files with 250 additions and 75 deletions

vllm/model_executor/models/jamba.py

@@ -47,7 +47,8 @@ class JambaMoE(nn.Module):
                  top_k: Optional[int] = None,
                  params_dtype: Optional[torch.dtype] = None,
                  tp_size: Optional[int] = None,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__()
         self.num_total_experts = num_experts or config.num_experts
         self.top_k = top_k or config.num_experts_per_tok
@@ -70,7 +71,8 @@ class JambaMoE(nn.Module):
                                 reduce_results=True,
                                 renormalize=False,
                                 use_grouped_topk=False,
-                                quant_config=quant_config)
+                                quant_config=quant_config,
+                                prefix=f"{prefix}.experts")
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         orig_shape = hidden_states.shape
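The pattern in these hunks threads each submodule's dotted path down through the constructors: every module appends its own attribute name to the prefix it receives, so a leaf layer such as FusedMoE ends up knowing its fully qualified name (vLLM uses such prefixes, for example, to match per-layer quantization settings during weight loading). A minimal sketch of the idea, with toy classes standing in for the real vLLM layers:

# Toy sketch of prefix threading (hypothetical classes, not vLLM's):
# each module appends its attribute name to the prefix it receives.
class Experts:
    def __init__(self, prefix: str = ""):
        self.prefix = prefix  # the leaf's full dotted path

class MoE:
    def __init__(self, prefix: str = ""):
        self.experts = Experts(prefix=f"{prefix}.experts")

class DecoderLayer:
    def __init__(self, prefix: str = ""):
        self.feed_forward = MoE(prefix=f"{prefix}.feed_forward")

layer = DecoderLayer(prefix="model.layers.0")
print(layer.feed_forward.experts.prefix)
# -> model.layers.0.feed_forward.experts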
@@ -92,13 +94,15 @@ class JambaMLP(JambaMoE):
                  config: JambaConfig,
                  params_dtype: Optional[torch.dtype] = None,
                  tp_size: Optional[int] = None,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__(config,
                          num_experts=1,
                          top_k=1,
                          params_dtype=params_dtype,
                          tp_size=tp_size,
-                         quant_config=quant_config)
+                         quant_config=quant_config,
+                         prefix=prefix)
 
 
 class JambaMambaDecoderLayer(nn.Module):
@@ -109,6 +113,7 @@ class JambaMambaDecoderLayer(nn.Module):
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  is_lora_enabled: Optional[bool] = False,
+                 prefix: str = "",
                  **kwargs) -> None:
         super().__init__()
         self.config = config
@@ -129,7 +134,9 @@ class JambaMambaDecoderLayer(nn.Module):
         num_experts = config.layers_num_experts[layer_idx]
         ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
-        self.feed_forward = ffn_layer_class(config, quant_config=quant_config)
+        self.feed_forward = ffn_layer_class(config,
+                                            quant_config=quant_config,
+                                            prefix=f"{prefix}.feed_forward")
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
@@ -211,7 +218,9 @@ class JambaAttentionDecoderLayer(nn.Module):
         num_experts = config.layers_num_experts[layer_idx]
         ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
-        self.feed_forward = ffn_layer_class(config, quant_config=quant_config)
+        self.feed_forward = ffn_layer_class(config,
+                                            quant_config=quant_config,
+                                            prefix=f"{prefix}.feed_forward")
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
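One reason the leaf layers need their full dotted path: quantization configs commonly carry ignore lists of module paths, and a layer's prefix is what gets matched against them. The check below is a hypothetical stand-in for that kind of matching, for illustration only, not vLLM's actual API:

from fnmatch import fnmatch

# Assumed ignore patterns for this sketch only.
IGNORED_LAYERS = ["model.layers.*.feed_forward.experts"]

def is_quantized(prefix: str) -> bool:
    # A layer whose dotted path matches an ignore pattern stays unquantized.
    return not any(fnmatch(prefix, pat) for pat in IGNORED_LAYERS)

print(is_quantized("model.layers.0.feed_forward.experts"))  # False
print(is_quantized("model.layers.0.self_attn.qkv_proj"))    # True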