[V1] EP/TP MoE + DP Attention (#13931)
Commit 72c62eae5f (parent 0a995d5434), committed via GitHub.
@@ -47,7 +47,8 @@ class JambaMoE(nn.Module):
                  top_k: Optional[int] = None,
                  params_dtype: Optional[torch.dtype] = None,
                  tp_size: Optional[int] = None,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__()
         self.num_total_experts = num_experts or config.num_experts
         self.top_k = top_k or config.num_experts_per_tok
@@ -70,7 +71,8 @@ class JambaMoE(nn.Module):
                                 reduce_results=True,
                                 renormalize=False,
                                 use_grouped_topk=False,
-                                quant_config=quant_config)
+                                quant_config=quant_config,
+                                prefix=f"{prefix}.experts")
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         orig_shape = hidden_states.shape
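The `prefix` threaded through these constructors composes a dotted module path (e.g. "model.layers.0.feed_forward.experts") that matches the parameter names in the checkpoint. A minimal sketch of the composition pattern, where the Toy* classes are hypothetical stand-ins rather than the code in this diff:

# Illustrative sketch of dotted-prefix threading; the Toy* classes
# are hypothetical, not vLLM's actual modules.
class ToyExperts:
    def __init__(self, prefix: str = ""):
        # A leaf module keeps its fully qualified name.
        self.prefix = prefix

class ToyMoE:
    def __init__(self, prefix: str = ""):
        # Each parent appends the child's attribute name, mirroring
        # prefix=f"{prefix}.experts" in the hunk above.
        self.experts = ToyExperts(prefix=f"{prefix}.experts")

class ToyDecoderLayer:
    def __init__(self, prefix: str = ""):
        self.feed_forward = ToyMoE(prefix=f"{prefix}.feed_forward")

layer = ToyDecoderLayer(prefix="model.layers.0")
assert layer.feed_forward.experts.prefix == "model.layers.0.feed_forward.experts"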
@@ -92,13 +94,15 @@ class JambaMLP(JambaMoE):
                  config: JambaConfig,
                  params_dtype: Optional[torch.dtype] = None,
                  tp_size: Optional[int] = None,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__(config,
                          num_experts=1,
                          top_k=1,
                          params_dtype=params_dtype,
                          tp_size=tp_size,
-                         quant_config=quant_config)
+                         quant_config=quant_config,
+                         prefix=prefix)
 
 
 class JambaMambaDecoderLayer(nn.Module):
@@ -109,6 +113,7 @@ class JambaMambaDecoderLayer(nn.Module):
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
                  is_lora_enabled: Optional[bool] = False,
+                 prefix: str = "",
                  **kwargs) -> None:
         super().__init__()
         self.config = config
@@ -129,7 +134,9 @@ class JambaMambaDecoderLayer(nn.Module):
 
         num_experts = config.layers_num_experts[layer_idx]
         ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
-        self.feed_forward = ffn_layer_class(config, quant_config=quant_config)
+        self.feed_forward = ffn_layer_class(config,
+                                            quant_config=quant_config,
+                                            prefix=f"{prefix}.feed_forward")
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size,
@@ -211,7 +218,9 @@ class JambaAttentionDecoderLayer(nn.Module):
 
         num_experts = config.layers_num_experts[layer_idx]
         ffn_layer_class = JambaMoE if num_experts > 1 else JambaMLP
-        self.feed_forward = ffn_layer_class(config, quant_config=quant_config)
+        self.feed_forward = ffn_layer_class(config,
+                                            quant_config=quant_config,
+                                            prefix=f"{prefix}.feed_forward")
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
         self.pre_ff_layernorm = RMSNorm(config.hidden_size,
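One reason to thread a prefix into every FusedMoE is per-layer handling during weight loading and quantization: configs commonly exclude or override specific layers by matching their dotted names. A hedged sketch of such matching logic; `ignored_layers` and `should_quantize` are illustrative names, not vLLM's actual quantization API:

# Illustrative only: shows why modules carry a dotted prefix.
# `ignored_layers` and `should_quantize` are hypothetical names,
# not vLLM's actual quantization API.
from typing import Sequence

def should_quantize(prefix: str, ignored_layers: Sequence[str]) -> bool:
    """Quantize a module unless its fully qualified name is excluded."""
    return not any(
        prefix == pat or prefix.startswith(pat + ".")
        for pat in ignored_layers
    )

# With prefixes threaded through the constructors, each expert block
# can be looked up by the same name it has in the checkpoint.
assert should_quantize("model.layers.0.feed_forward.experts",
                       ignored_layers=["model.layers.0.feed_forward"]) is False
assert should_quantize("model.layers.1.feed_forward.experts",
                       ignored_layers=["model.layers.0.feed_forward"]) is True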