[GPTOSS][DP/EP][Marlin] Enable GPTOSS DP/EP using Marlin kernels (#25488)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Varun Sundar Rabindranath
2025-10-03 20:13:13 -04:00
committed by GitHub
parent 767cbb011d
commit 7ef40bb983
9 changed files with 264 additions and 101 deletions

View File

@@ -187,6 +187,16 @@ def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \
supports_router_weight and supports_activation
def marlin_moe_intermediate_size(w1_packed: torch.Tensor,
w2_packed: torch.Tensor):
"""
Given Marlin packed weight matrices w1_packed, and w2_packed,
return the MoE intermediate size N
"""
marlin_tile_size = 16
return w2_packed.size(1) * marlin_tile_size
def marlin_make_workspace(output_size_per_partition: int,
device: torch.device) -> torch.Tensor:
max_workspace_size = (output_size_per_partition //