[V1] EP/TP MoE + DP Attention (#13931)
This commit is contained in:
committed by
GitHub
parent
0a995d5434
commit
72c62eae5f
@@ -46,7 +46,7 @@ class AriaImagePixelInputs(TypedDict):
|
||||
pixel_values: torch.Tensor
|
||||
pixel_mask: Optional[torch.Tensor]
|
||||
"""
|
||||
- Shape:
|
||||
+ Shape:
|
||||
pixel_values: `(batch_size * num_images, num_channels, height, width)`
|
||||
pixel_mask: `(batch_size * num_images, height, width)`
|
||||
"""
|
||||
@@ -135,11 +135,11 @@ class AriaProjector(nn.Module):
|
||||
query numbers,
|
||||
e.g., {1225: 128, 4900: 256}. This allows for different query sizes
|
||||
based on image resolution.
|
||||
- embed_dim (int): Embedding dimension.
|
||||
- num_heads (int): Number of attention heads.
|
||||
- kv_dim (int): Dimension of key and value.
|
||||
- ff_dim (int): Hidden dimension of the feed-forward network.
|
||||
- output_dim (int): Output dimension.
|
||||
+ embed_dim (int): Embedding dimension.
|
||||
+ num_heads (int): Number of attention heads.
|
||||
+ kv_dim (int): Dimension of key and value.
|
||||
+ ff_dim (int): Hidden dimension of the feed-forward network.
|
||||
+ output_dim (int): Output dimension.
|
||||
norm_layer (nn.Module): Normalization layer. Default is nn.LayerNorm.
|
||||
|
||||
Outputs:
|
||||
@@ -239,6 +239,7 @@ class AriaTextMoELayer(nn.Module):
|
||||
self,
|
||||
config: AriaTextConfig,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
+ prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
@@ -254,6 +255,7 @@ class AriaTextMoELayer(nn.Module):
|
||||
intermediate_size=config.intermediate_size,
|
||||
quant_config=quant_config,
|
||||
reduce_results=True,
|
||||
+ prefix=f"{prefix}.experts",
|
||||
)
|
||||
self.shared_experts = LlamaMLP(
|
||||
config.hidden_size,
|
||||
@@ -301,7 +303,9 @@ class AriaTextDecoderLayer(LlamaDecoderLayer):
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(config, cache_config, quant_config, prefix)
|
||||
- self.mlp = AriaTextMoELayer(config, quant_config=quant_config)
|
||||
+ self.mlp = AriaTextMoELayer(config,
|
||||
+ quant_config=quant_config,
|
||||
+ prefix=f"{prefix}.mlp")
|
||||
|
||||
|
||||
class AriaTextModel(LlamaModel, SupportsQuant):
|
||||
|
||||
Reference in New Issue
Block a user