[GPTOSS][DP/EP][Marlin] Enable GPTOSS DP/EP using Marlin kernels (#25488)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
committed by
GitHub
parent
767cbb011d
commit
7ef40bb983
@@ -187,6 +187,16 @@ def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \
|
||||
supports_router_weight and supports_activation
|
||||
|
||||
|
||||
def marlin_moe_intermediate_size(w1_packed: torch.Tensor,
                                 w2_packed: torch.Tensor) -> int:
    """
    Return the MoE intermediate size N for Marlin packed expert weights.

    The value is derived solely from ``w2_packed``: the extent of its
    dimension 1 multiplied by the Marlin tile size (16).

    Args:
        w1_packed: Marlin packed first expert weight matrix. It is part of
            the signature but is not read by this function.
        w2_packed: Marlin packed second expert weight matrix.
            NOTE(review): presumably dimension 1 holds N // 16 tile rows
            (e.g. a (num_experts, N // 16, ...) layout) -- confirm against
            the Marlin repacking code.

    Returns:
        ``w2_packed.size(1) * 16`` as a Python int.
    """
    # Each unit along dim 1 of the packed tensor stands for one tile's worth
    # of the intermediate dimension.
    marlin_tile_size = 16
    return w2_packed.size(1) * marlin_tile_size
|
||||
|
||||
|
||||
def marlin_make_workspace(output_size_per_partition: int,
|
||||
device: torch.device) -> torch.Tensor:
|
||||
max_workspace_size = (output_size_per_partition //
|
||||
|
||||
Reference in New Issue
Block a user