[GPTOSS][DP/EP][Marlin] Enable GPTOSS DP/EP using Marlin kernels (#25488)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Varun Sundar Rabindranath
2025-10-03 20:13:13 -04:00
committed by GitHub
parent 767cbb011d
commit 7ef40bb983
9 changed files with 264 additions and 101 deletions

View File

@@ -55,46 +55,6 @@ from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
#
def _moe_problem_size(
a1: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
) -> tuple[int, int, int, int, int]:
"""
Extract the MoE problem size from the given tensor arguments:
- a: The hidden states, input to the MoE layer.
- w1: The first set of expert weights.
- w2: The second set of expert weights.
- topk_ids: The topk ids.
Note: extracting the problem shape from the weight and activation tensors is
not obvious. It needs to be done this way specifically due to subtle issues
with particular kernels, e.g. the int4 kernels divide the trailing dimension
by two, so it's not "correct" to extract N or K from the trailing dimension
of w1 or w2. Similarly, some kernels transpose the weights, so this needs
to be kept in mind.
"""
assert w1.dim() == 3 and w2.dim() == 3
E, N, _ = w1.size()
K = a1.size(-1)
if a1.dim() == 2:
# Make sure we are using the correct a1 (pre-permute).
assert topk_ids.size(0) == a1.size(0), \
f"{topk_ids.size(0)} != {a1.size(0)}"
M = a1.size(0)
else:
assert a1.dim() == 3
assert a1.size(0) == E, f"{a1.size(0)} == {E}"
M = a1.size(1) # This is max_num_tokens
assert topk_ids.dim() == 2
topk = topk_ids.size(1)
return E, M, N, K, topk
class FusedMoEActivationFormat(Enum):
"""
The standard activation format (num_tokens, hidden dim).
@@ -391,6 +351,50 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
"""
raise NotImplementedError
def moe_problem_size(
self,
a1: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
) -> tuple[int, int, int, int, int]:
"""
Extract the MoE problem size from the given tensor arguments:
- a: The hidden states, input to the MoE layer.
- w1: The first set of expert weights.
- w2: The second set of expert weights.
- topk_ids: The topk ids.
Note: extracting the problem shape from the weight and activation
tensors is not obvious. It needs to be done this way specifically
due to subtle issues with particular kernels, e.g. the int4 kernels
divide the trailing dimension by two, so it's not "correct" to
extract N or K from the trailing dimension of w1 or w2. Similarly,
some kernels transpose the weights, so this needs to be kept in mind.
Note: This implementation covers most cases. However, if experts
require a specialized implementation, like MarlinExperts, they are free
to override this function.
"""
assert w1.dim() == 3 and w2.dim() == 3
E, N, _ = w1.size()
K = a1.size(-1)
if a1.dim() == 2:
# Make sure we are using the correct a1 (pre-permute).
assert topk_ids.size(0) == a1.size(0), \
f"{topk_ids.size(0)} != {a1.size(0)}"
M = a1.size(0)
else:
assert a1.dim() == 3
assert a1.size(0) == E, f"{a1.size(0)} == {E}"
M = a1.size(1) # This is max_num_tokens
assert topk_ids.dim() == 2
topk = topk_ids.size(1)
return E, M, N, K, topk
#
# Various helpers for accessing quantization parameters from the
# quant_config.
@@ -674,7 +678,8 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input: bool,
) -> torch.Tensor:
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
_, M, N, K, top_k = self.fused_experts.moe_problem_size(
a1q, w1, w2, topk_ids)
(workspace13_shape, workspace2_shape, fused_out_shape,
workspace_dtype) = self.fused_experts.workspace_shapes(
@@ -737,7 +742,8 @@ class FusedMoEModularKernel(torch.nn.Module):
apply_router_weight_on_input: bool,
) -> torch.Tensor:
_, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids)
_, M, N, K, top_k = self.fused_experts.moe_problem_size(
a1q, w1, w2, topk_ids)
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
num_chunks = cdiv(M, CHUNK_SIZE)