[V1] EP/TP MoE + DP Attention (#13931)
commit 72c62eae5f (parent 0a995d5434), committed by GitHub
@@ -65,6 +65,7 @@ class DbrxExperts(FusedMoE):
         config: DbrxConfig,
         quant_config: Optional[QuantizationConfig] = None,
         params_dtype: Optional[torch.dtype] = None,
+        prefix: str = "",
     ):
         super().__init__(
             num_experts=config.ffn_config.moe_num_experts,
@@ -76,6 +77,7 @@ class DbrxExperts(FusedMoE):
             renormalize=True,
             quant_config=quant_config,
             tp_size=get_tensor_model_parallel_world_size(),
+            prefix=prefix,
         )
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -139,6 +141,7 @@ class DbrxMoE(nn.Module):
         config: DbrxConfig,
         quant_config: Optional[QuantizationConfig] = None,
         params_dtype: Optional[torch.dtype] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.d_model = config.d_model
@@ -150,7 +153,8 @@ class DbrxMoE(nn.Module):
 
         self.experts = DbrxExperts(config=config,
                                    quant_config=quant_config,
-                                   params_dtype=self.params_dtype)
+                                   params_dtype=self.params_dtype,
+                                   prefix=f"{prefix}.experts")
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         orig_shape = hidden_states.shape
@@ -291,7 +295,7 @@ class DbrxBlock(nn.Module):
             cache_config,
             quant_config,
             prefix=f"{prefix}.norm_attn_norm")
-        self.ffn = DbrxMoE(config, quant_config)
+        self.ffn = DbrxMoE(config, quant_config, prefix=f"{prefix}.ffn")
 
     def forward(
         self,
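For context, the change threads a `prefix` string down the module tree (DbrxBlock -> DbrxMoE -> DbrxExperts) so that each FusedMoE layer knows its fully qualified name, which per-layer concerns such as weight loading and quantization-config matching can presumably key off. Below is a minimal sketch of that pattern using hypothetical stand-in classes (Block/MoE/Experts are illustrative only, not the vLLM implementation):

import torch.nn as nn

class Experts(nn.Module):
    def __init__(self, prefix: str = ""):
        super().__init__()
        # The leaf layer records its fully qualified name,
        # e.g. "model.layers.0.ffn.experts".
        self.prefix = prefix

class MoE(nn.Module):
    def __init__(self, prefix: str = ""):
        super().__init__()
        # Extend the prefix one level per submodule, mirroring
        # DbrxMoE passing prefix=f"{prefix}.experts" in the diff.
        self.experts = Experts(prefix=f"{prefix}.experts")

class Block(nn.Module):
    def __init__(self, prefix: str = ""):
        super().__init__()
        # Mirrors DbrxBlock passing prefix=f"{prefix}.ffn".
        self.ffn = MoE(prefix=f"{prefix}.ffn")

block = Block(prefix="model.layers.0")
assert block.ffn.experts.prefix == "model.layers.0.ffn.experts"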