[V1] EP/TP MoE + DP Attention (#13931)

This commit is contained in:
Tyler Michael Smith
2025-03-05 00:27:26 -05:00
committed by GitHub
parent 0a995d5434
commit 72c62eae5f
17 changed files with 250 additions and 75 deletions

View File

@@ -46,7 +46,7 @@ class AriaImagePixelInputs(TypedDict):
pixel_values: torch.Tensor
pixel_mask: Optional[torch.Tensor]
"""
Shape:
Shape:
pixel_values: `(batch_size * num_images, num_channels, height, width)`
pixel_mask: `(batch_size * num_images, height, width)`
"""
@@ -135,11 +135,11 @@ class AriaProjector(nn.Module):
query numbers,
e.g., {1225: 128, 4900: 256}. This allows for different query sizes
based on image resolution.
embed_dim (int): Embedding dimension.
num_heads (int): Number of attention heads.
kv_dim (int): Dimension of key and value.
ff_dim (int): Hidden dimension of the feed-forward network.
output_dim (int): Output dimension.
embed_dim (int): Embedding dimension.
num_heads (int): Number of attention heads.
kv_dim (int): Dimension of key and value.
ff_dim (int): Hidden dimension of the feed-forward network.
output_dim (int): Output dimension.
norm_layer (nn.Module): Normalization layer. Default is nn.LayerNorm.
Outputs:
@@ -239,6 +239,7 @@ class AriaTextMoELayer(nn.Module):
self,
config: AriaTextConfig,
quant_config: Optional[QuantizationConfig],
prefix: str = "",
) -> None:
super().__init__()
self.config = config
@@ -254,6 +255,7 @@ class AriaTextMoELayer(nn.Module):
intermediate_size=config.intermediate_size,
quant_config=quant_config,
reduce_results=True,
prefix=f"{prefix}.experts",
)
self.shared_experts = LlamaMLP(
config.hidden_size,
@@ -301,7 +303,9 @@ class AriaTextDecoderLayer(LlamaDecoderLayer):
prefix: str = "",
) -> None:
super().__init__(config, cache_config, quant_config, prefix)
self.mlp = AriaTextMoELayer(config, quant_config=quant_config)
self.mlp = AriaTextMoELayer(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
class AriaTextModel(LlamaModel, SupportsQuant):