[V1] EP/TP MoE + DP Attention (#13931)

Author: Tyler Michael Smith
Date:   2025-03-05 00:27:26 -05:00 (committed by GitHub)
Parent: 0a995d5434
Commit: 72c62eae5f

17 changed files with 250 additions and 75 deletions

vllm/model_executor/models/dbrx.py

@@ -65,6 +65,7 @@ class DbrxExperts(FusedMoE):
         config: DbrxConfig,
         quant_config: Optional[QuantizationConfig] = None,
         params_dtype: Optional[torch.dtype] = None,
+        prefix: str = "",
     ):
         super().__init__(
             num_experts=config.ffn_config.moe_num_experts,
@@ -76,6 +77,7 @@ class DbrxExperts(FusedMoE):
             renormalize=True,
             quant_config=quant_config,
             tp_size=get_tensor_model_parallel_world_size(),
+            prefix=prefix,
         )
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -139,6 +141,7 @@ class DbrxMoE(nn.Module):
         config: DbrxConfig,
         quant_config: Optional[QuantizationConfig] = None,
         params_dtype: Optional[torch.dtype] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.d_model = config.d_model
@@ -150,7 +153,8 @@ class DbrxMoE(nn.Module):
         self.experts = DbrxExperts(config=config,
                                    quant_config=quant_config,
-                                   params_dtype=self.params_dtype)
+                                   params_dtype=self.params_dtype,
+                                   prefix=f"{prefix}.experts")
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         orig_shape = hidden_states.shape
@@ -291,7 +295,7 @@ class DbrxBlock(nn.Module):
             cache_config,
             quant_config,
             prefix=f"{prefix}.norm_attn_norm")
-        self.ffn = DbrxMoE(config, quant_config)
+        self.ffn = DbrxMoE(config, quant_config, prefix=f"{prefix}.ffn")
 
     def forward(
         self,
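
What this diff does: it threads a dotted prefix string down the module tree (DbrxBlock -> DbrxMoE -> DbrxExperts -> FusedMoE), so each submodule learns its fully qualified name (e.g. "...ffn.experts") at construction time. A module's qualified name is what lets code such as quantization-method lookup or expert-parallel weight mapping address a specific submodule without the parents knowing those details. Below is a minimal, self-contained sketch of the prefix-threading pattern; Block, MoE, and Experts are illustrative stand-ins, not vLLM's classes.

    # Sketch of the prefix-threading pattern applied in this diff.
    # Each constructor appends its attribute name to the prefix it
    # receives, so leaf modules know their dotted qualified name at
    # construction time.

    class Experts:
        def __init__(self, prefix: str = ""):
            # A real fused-MoE layer could use this name to match
            # per-module quantization rules or shard expert weights.
            self.prefix = prefix

    class MoE:
        def __init__(self, prefix: str = ""):
            self.experts = Experts(prefix=f"{prefix}.experts")

    class Block:
        def __init__(self, prefix: str = ""):
            self.ffn = MoE(prefix=f"{prefix}.ffn")

    blocks = [Block(prefix=f"blocks.{i}") for i in range(2)]
    print(blocks[0].ffn.experts.prefix)  # -> "blocks.0.ffn.experts"
    print(blocks[1].ffn.experts.prefix)  # -> "blocks.1.ffn.experts"

Running the sketch prints "blocks.0.ffn.experts" and "blocks.1.ffn.experts"; the same kind of dotted name is what a checkpoint loader or quantization config uses to address one particular expert module out of many identical blocks.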