[GPTOSS][DP/EP][Marlin] Enable GPTOSS Batched DP/EP using Marlin kernels (#25997)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Authored by Varun Sundar Rabindranath on 2025-10-16 15:53:11 -04:00; committed by GitHub.
parent 2ed8b6b3d0, commit fb0571b077
12 changed files with 1174 additions and 335 deletions
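This change enables the Marlin MoE kernels on the batched activation format used by the data-parallel / expert-parallel (DP/EP) path for GPT-OSS. In the test-file hunks shown below, that means a new batched_fused_marlin_moe entry point and a batched_moe_align_block_size custom op are imported and exercised, the per-expert quantize-and-stack boilerplate is factored into a MarlinMoEWeightData helper, and a new test compares the batched path against the existing fused_marlin_moe kernel on the mxfp4 (float4_e2m1f, group_size 32) configuration used by GPT-OSS. In the batched format, activations are grouped per expert instead of per token: hidden_states has shape (num_experts, max_tokens_per_batch, K), and an expert_num_tokens tensor records how many rows of each expert's slot are valid. The sketch below builds that layout on CPU from a top-k routing; it mirrors the BatchedRun._scatter helper in the new test and is illustrative only (tensor names follow the test, not the kernel API).

import torch


def scatter_to_batched(
    hidden_states: torch.Tensor,  # (num_tokens, K)
    topk_ids: torch.Tensor,       # (num_tokens, topk)
    num_experts: int,
    max_tokens_per_batch: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Group token activations per expert, as the batched DP/EP path expects."""
    num_tokens, K = hidden_states.shape
    batched = torch.zeros(
        (num_experts, max_tokens_per_batch, K), dtype=hidden_states.dtype
    )
    expert_num_tokens = torch.zeros((num_experts,), dtype=torch.int32)
    for t in range(num_tokens):
        for expert in topk_ids[t].tolist():
            slot = int(expert_num_tokens[expert])
            assert slot < max_tokens_per_batch, "expert batch overflow"
            batched[expert, slot] = hidden_states[t]
            expert_num_tokens[expert] += 1
    return batched, expert_num_tokens

An expert assigned more than max_tokens_per_batch tokens cannot be represented in this layout; the new test skips such cases via BatchedRun.is_valid().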


@@ -7,6 +7,8 @@ Run `pytest tests/kernels/test_moe.py`.
import functools
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
import pytest
import torch
@@ -26,7 +28,10 @@ from vllm.model_executor.layers.fused_moe.config import (
int4_w4a16_moe_quant_config,
int8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
batched_fused_marlin_moe,
fused_marlin_moe,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk,
modular_triton_fused_moe,
@@ -564,6 +569,105 @@ def marlin_moe_generate_valid_test_cases():
return cases
@dataclass
class MarlinMoEWeightData:
    w_ref: torch.Tensor
    qweight: torch.Tensor
    scales: torch.Tensor
    global_scale: torch.Tensor | None
    g_idx: torch.Tensor | None
    zeros: torch.Tensor | None
    sort_indices: torch.Tensor | None
    marlin_bias: torch.Tensor | None

    @staticmethod
    def make(
        w: torch.Tensor,
        quant_type: ScalarType,
        group_size: int,
        act_order: bool | None = None,
        bias: torch.Tensor | None = None,
    ) -> "MarlinMoEWeightData":
        assert w.ndim == 3

        has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
        k = w.shape[-1]

        w_ref_l: list[torch.Tensor] = []
        qweight_l: list[torch.Tensor] = []
        scales_l: list[torch.Tensor] = []
        global_scale_l: list[torch.Tensor] = []
        zeros_l: list[torch.Tensor] = []
        g_idx_l: list[torch.Tensor] = []
        sort_indices_l: list[torch.Tensor] = []
        bias_l: list[torch.Tensor] = []

        for i in range(w.shape[0]):
            if quant_type == scalar_types.float4_e2m1f:
                if group_size == 16:
                    w_ref, qweight, scales, global_scale = (
                        rand_marlin_weight_nvfp4_like(w[i], group_size)
                    )
                else:
                    w_ref, qweight, scales = rand_marlin_weight_mxfp4_like(
                        w[i], group_size
                    )
                    global_scale = None

                w_ref_l.append(w_ref.T)
                qweight_l.append(qweight)
                scales_l.append(scales)
                if global_scale is not None:
                    global_scale_l.append(global_scale)
            elif quant_type == scalar_types.float8_e4m3fn:
                w_ref, qweight, scales = marlin_quant_fp8_torch(w[i], group_size)
                w_ref_l.append(w_ref.T)
                qweight_l.append(qweight)
                scales_l.append(scales)
            elif has_zp:
                w_ref, qweight, scales, zeros = awq_marlin_quantize(
                    w[i].transpose(1, 0), quant_type, group_size
                )
                w_ref_l.append(w_ref.T)
                qweight_l.append(qweight)
                scales_l.append(scales)
                zeros_l.append(zeros)
            else:
                test_perm = torch.randperm(k)
                w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
                    w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
                )
                w_ref_l.append(w_ref.T)
                qweight_l.append(qweight)
                scales_l.append(scales)
                g_idx_l.append(g_idx)
                sort_indices_l.append(sort_indices)

            if bias is not None:
                bias_l.append(marlin_permute_bias(bias[i]))

        w_ref = stack_and_dev(w_ref_l)
        qweight = stack_and_dev(qweight_l).contiguous()
        scales = stack_and_dev(scales_l)
        global_scale = stack_and_dev(global_scale_l) if global_scale_l else None
        g_idx = stack_and_dev(g_idx_l) if g_idx_l else None
        zeros = stack_and_dev(zeros_l) if zeros_l else None
        sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None
        marlin_bias = stack_and_dev(bias_l) if bias_l else None

        return MarlinMoEWeightData(
            w_ref=w_ref,
            qweight=qweight,
            scales=scales,
            global_scale=global_scale,
            g_idx=g_idx,
            zeros=zeros,
            sort_indices=sort_indices,
            marlin_bias=marlin_bias,
        )
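MarlinMoEWeightData.make replaces the per-expert quantize-and-stack loops that previously appeared in each test: it quantizes every expert's weight with the scheme selected by quant_type, stacks the per-expert results, and moves them to the GPU. A minimal usage sketch, assuming this file's imports and helpers are in scope and using the shapes from the tests below (e experts, gate/up weight of shape (e, 2*n, k), down weight of shape (e, k, n)):

e, n, k = 8, 128, 128
dtype = torch.bfloat16
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20

# Quantize and stack all experts; fields that do not apply to the chosen
# scheme (e.g. zeros for fp4, global_scale for mxfp4) come back as None.
w1_data = MarlinMoEWeightData.make(
    w=w1, quant_type=scalar_types.float4_e2m1f, group_size=32, act_order=None
)
w2_data = MarlinMoEWeightData.make(
    w=w2, quant_type=scalar_types.float4_e2m1f, group_size=32, act_order=None
)
# w1_data.w_ref is the dequantized reference; w1_data.qweight / .scales feed
# fused_marlin_moe and batched_fused_marlin_moe.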
@pytest.mark.flaky(reruns=2)
@pytest.mark.parametrize(
("m, n, k, e, topk, ep_size, dtype, group_size,act_order, quant_type, is_k_full"),
@@ -584,7 +688,6 @@ def test_fused_marlin_moe(
is_k_full: bool,
):
torch.cuda.manual_seed(0)
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
@@ -600,152 +703,44 @@ def test_fused_marlin_moe(
else:
e_map = None
w_ref1_l = []
qweight1_l = []
scales1_l = []
global_scale1_l = []
zeros1_l = []
g_idx1_l = []
sort_indices1_l = []
w1_data = MarlinMoEWeightData.make(
w=w1, quant_type=quant_type, group_size=group_size, act_order=act_order
)
for i in range(w1.shape[0]):
if quant_type == scalar_types.float4_e2m1f:
if group_size == 16:
w_ref1, qweight1, scales1, global_scale1 = (
rand_marlin_weight_nvfp4_like(w1[i], group_size)
)
else:
w_ref1, qweight1, scales1 = rand_marlin_weight_mxfp4_like(
w1[i], group_size
)
global_scale1 = None
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
if global_scale1 is not None:
global_scale1_l.append(global_scale1)
elif quant_type == scalar_types.float8_e4m3fn:
w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(w1[i], group_size)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
elif has_zp:
w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size
)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
zeros1_l.append(zeros1)
else:
test_perm = torch.randperm(k)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
g_idx1_l.append(g_idx1)
sort_indices1_l.append(sort_indices1)
w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweight1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
global_scale1 = stack_and_dev(global_scale1_l) if global_scale1_l else None
g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
w_ref2_l = []
qweight2_l = []
scales2_l = []
global_scale2_l = []
zeros2_l = []
g_idx2_l = []
sort_indices2_l = []
for i in range(w2.shape[0]):
if quant_type == scalar_types.float4_e2m1f:
if group_size == 16:
w_ref2, qweight2, scales2, global_scale2 = (
rand_marlin_weight_nvfp4_like(w2[i], group_size)
)
else:
w_ref2, qweight2, scales2 = rand_marlin_weight_mxfp4_like(
w2[i], group_size
)
global_scale2 = None
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
if global_scale2 is not None:
global_scale2_l.append(global_scale2)
elif quant_type == scalar_types.float8_e4m3fn:
w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(w2[i], group_size)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
elif has_zp:
w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size
)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
zeros2_l.append(zeros2)
else:
test_perm = torch.randperm(n)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
g_idx2_l.append(g_idx2)
sort_indices2_l.append(sort_indices2)
w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweight2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
global_scale2 = stack_and_dev(global_scale2_l) if global_scale2_l else None
g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
w2_data = MarlinMoEWeightData.make(
w=w2, quant_type=quant_type, group_size=group_size, act_order=act_order
)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map)
torch_output = torch_moe(
a, w1_data.w_ref, w2_data.w_ref, score, topk, expert_map=e_map
)
marlin_output = fused_marlin_moe(
a,
qweight1,
qweight2,
w1_data.qweight,
w2_data.qweight,
None,
None,
scales1,
scales2,
w1_data.scales,
w2_data.scales,
score,
topk_weights,
topk_ids,
global_num_experts=e,
expert_map=e_map,
global_scale1=global_scale1,
global_scale2=global_scale2,
g_idx1=g_idx1,
g_idx2=g_idx2,
sort_indices1=sort_indices1,
sort_indices2=sort_indices2,
w1_zeros=zeros1,
w2_zeros=zeros2,
global_scale1=w1_data.global_scale,
global_scale2=w2_data.global_scale,
g_idx1=w1_data.g_idx,
g_idx2=w2_data.g_idx,
sort_indices1=w1_data.sort_indices,
sort_indices2=w2_data.sort_indices,
w1_zeros=w1_data.zeros,
w2_zeros=w2_data.zeros,
quant_type_id=quant_type.id,
is_k_full=is_k_full,
)
@@ -773,92 +768,52 @@ def test_fused_marlin_moe_with_bias(m):
b_bias1 = torch.randn((e, 2 * n), device="cuda", dtype=dtype) / 10
b_bias2 = torch.randn((e, k), device="cuda", dtype=dtype) / 10
b_bias1_l = []
w_ref1_l = []
qweight1_l = []
scales1_l = []
g_idx1_l = []
sort_indices1_l = []
w1_data = MarlinMoEWeightData.make(
w=w1,
quant_type=quant_type,
group_size=group_size,
act_order=act_order,
bias=b_bias1,
)
for i in range(w1.shape[0]):
test_perm = torch.randperm(k)
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
)
w_ref1_l.append(w_ref1.T)
qweight1_l.append(qweight1)
scales1_l.append(scales1)
g_idx1_l.append(g_idx1)
sort_indices1_l.append(sort_indices1)
b_bias1_l.append(marlin_permute_bias(b_bias1[i]))
w_ref1 = stack_and_dev(w_ref1_l)
qweight1 = stack_and_dev(qweight1_l).contiguous()
scales1 = stack_and_dev(scales1_l)
global_scale1 = None
g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
zeros1 = None
sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
marlin_bias1 = stack_and_dev(b_bias1_l) if b_bias1_l else None
b_bias2_l = []
w_ref2_l = []
qweight2_l = []
scales2_l = []
g_idx2_l = []
sort_indices2_l = []
for i in range(w2.shape[0]):
test_perm = torch.randperm(n)
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
)
w_ref2_l.append(w_ref2.T)
qweight2_l.append(qweight2)
scales2_l.append(scales2)
g_idx2_l.append(g_idx2)
sort_indices2_l.append(sort_indices2)
b_bias2_l.append(marlin_permute_bias(b_bias2[i]))
w_ref2 = stack_and_dev(w_ref2_l)
qweight2 = stack_and_dev(qweight2_l).contiguous()
scales2 = stack_and_dev(scales2_l)
global_scale2 = None
g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
zeros2 = None
sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
marlin_bias2 = stack_and_dev(b_bias2_l) if b_bias2_l else None
w2_data = MarlinMoEWeightData.make(
w=w2,
quant_type=quant_type,
group_size=group_size,
act_order=act_order,
bias=b_bias2,
)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1, b_bias2)
torch_output = torch_moe(
a, w1_data.w_ref, w2_data.w_ref, score, topk, b_bias1, b_bias2
)
marlin_output = fused_marlin_moe(
a,
qweight1,
qweight2,
marlin_bias1,
marlin_bias2,
scales1,
scales2,
w1_data.qweight,
w2_data.qweight,
w1_data.marlin_bias,
w2_data.marlin_bias,
w1_data.scales,
w2_data.scales,
score,
topk_weights,
topk_ids,
global_num_experts=e,
expert_map=None,
global_scale1=global_scale1,
global_scale2=global_scale2,
g_idx1=g_idx1,
g_idx2=g_idx2,
sort_indices1=sort_indices1,
sort_indices2=sort_indices2,
w1_zeros=zeros1,
w2_zeros=zeros2,
global_scale1=w1_data.global_scale,
global_scale2=w2_data.global_scale,
g_idx1=w1_data.g_idx,
g_idx2=w2_data.g_idx,
sort_indices1=w1_data.sort_indices,
sort_indices2=w2_data.sort_indices,
w1_zeros=w1_data.zeros,
w2_zeros=w2_data.zeros,
quant_type_id=quant_type.id,
is_k_full=is_k_full,
)
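With biases, each expert's bias must be permuted into Marlin layout via marlin_permute_bias before stacking; MarlinMoEWeightData.make now does this whenever a bias tensor is passed. A minimal sketch of that step, equivalent to the loop this hunk removes (it assumes the module-level helpers and the e, dtype, quant_type, group_size, act_order, w1, b_bias1 names defined in this test):

# Hand-rolled version of what MarlinMoEWeightData.make(..., bias=b_bias1) does
# with the bias: permute each expert's bias into Marlin layout, then stack.
marlin_bias1 = stack_and_dev([marlin_permute_bias(b_bias1[i]) for i in range(e)])

# The refactored equivalent used by the test:
w1_data = MarlinMoEWeightData.make(
    w=w1,
    quant_type=quant_type,
    group_size=group_size,
    act_order=act_order,
    bias=b_bias1,
)
assert w1_data.marlin_bias is not None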
@@ -895,6 +850,41 @@ def test_moe_align_block_size_opcheck():
)
def test_batched_moe_align_block_size_opcheck():
    max_tokens_per_batch = 512
    num_experts = 4
    block_size = 16

    expert_num_tokens = torch.randint(
        low=0,
        high=max_tokens_per_batch,
        size=(num_experts,),
        dtype=torch.int32,
        device="cuda",
    )

    max_num_tokens_padded = num_experts * max(max_tokens_per_batch, block_size)
    sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda")

    assert max_num_tokens_padded % block_size == 0
    max_num_m_blocks = max_num_tokens_padded // block_size
    expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")
    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device="cuda")

    opcheck(
        torch.ops._moe_C.batched_moe_align_block_size,
        (
            max_tokens_per_batch,
            block_size,
            expert_num_tokens,
            sorted_ids,
            expert_ids,
            num_tokens_post_pad,
        ),
    )
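The buffer sizes in this opcheck come from worst-case per-expert padding: each expert can hold at most max_tokens_per_batch tokens and, presumably, has its token count rounded up to a multiple of block_size, so sorted_ids needs num_experts * max(max_tokens_per_batch, block_size) slots and expert_ids needs one entry per block. A small standalone sanity check of that arithmetic; the per-expert round-up is an assumption inferred from the buffer sizing above, not taken from the kernel itself:

import math

import torch

num_experts, max_tokens_per_batch, block_size = 4, 512, 16
expert_num_tokens = torch.randint(0, max_tokens_per_batch, (num_experts,))

# Round each expert's token count up to a multiple of block_size.
padded = [math.ceil(int(n) / block_size) * block_size for n in expert_num_tokens]
num_blocks = sum(p // block_size for p in padded)

# The padded totals always fit in the buffers allocated by the opcheck above.
max_num_tokens_padded = num_experts * max(max_tokens_per_batch, block_size)
assert sum(padded) <= max_num_tokens_padded
assert num_blocks <= max_num_tokens_padded // block_size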
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("k", [128, 511, 1024])
@@ -979,3 +969,171 @@ def test_cpu_fused_moe_basic(m, n, k, e, topk, dtype, with_bias, activation):
else:
atol = 5e-2
torch.testing.assert_close(out, ref, atol=atol, rtol=0)
@pytest.mark.parametrize("m", [16, 32, 64])
@pytest.mark.parametrize("n", [128])
@pytest.mark.parametrize("k", [128])
@pytest.mark.parametrize("e", [8, 12, 16, 32])
@pytest.mark.parametrize("topk", [2, 4])
@pytest.mark.parametrize("max_tokens_per_batch", [16, 32, 64])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_batched_fused_marlin_moe(
m: int, n: int, k: int, e: int, topk: int, max_tokens_per_batch: int
):
print(
f"testing m={m}, n={n}, k={k}, e={e}, "
f"topk={topk}, "
f"max_tokens_per_batch={max_tokens_per_batch}"
)
torch.cuda.manual_seed(0)
dtype = torch.bfloat16
quant_dtype = scalar_types.float4_e2m1f
group_size = 32
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20
w1_data = MarlinMoEWeightData.make(
w=w1, quant_type=quant_dtype, group_size=group_size, act_order=None
)
w2_data = MarlinMoEWeightData.make(
w=w2, quant_type=quant_dtype, group_size=group_size, act_order=None
)
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
class BatchedRun:
@staticmethod
def _make_expert_num_tokens_cpu(
e: int, # num_experts
topk_ids_cpu: torch.Tensor,
) -> torch.Tensor:
expert_num_tokens_cpu = torch.zeros((e,), dtype=torch.int32, device="cpu")
for topk_id in torch.flatten(topk_ids_cpu):
expert_num_tokens_cpu[topk_id] += 1
return expert_num_tokens_cpu
def __init__(
self,
max_tokens_per_batch: int,
num_experts: int,
_topk_ids: torch.Tensor,
_topk_weights: torch.Tensor,
):
self.max_tokens_per_batch = max_tokens_per_batch
self.e = num_experts
self.topk_ids_cpu = _topk_ids.to("cpu")
self.topk_weights_cpu = _topk_weights.to("cpu")
self.expert_num_tokens_cpu = self._make_expert_num_tokens_cpu(
self.e, self.topk_ids_cpu
)
def is_valid(self):
"""
Return True only if the input can be represented in a Batched
format.
"""
return torch.all(self.expert_num_tokens_cpu <= self.max_tokens_per_batch)
def _scatter(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states_cpu = hidden_states.to("cpu")
K = hidden_states_cpu.size(1)
batched_hidden_states_cpu = torch.empty(
(e, max_tokens_per_batch, K),
dtype=hidden_states_cpu.dtype,
device="cpu",
)
counter_cpu = torch.zeros_like(self.expert_num_tokens_cpu)
for t_idx, token in enumerate(hidden_states_cpu):
for topk_id in self.topk_ids_cpu[t_idx]:
pos_in_batch = counter_cpu[topk_id]
batched_hidden_states_cpu[topk_id, pos_in_batch] = token
counter_cpu[topk_id] += 1
assert torch.allclose(counter_cpu, self.expert_num_tokens_cpu)
return batched_hidden_states_cpu.to("cuda")
def _gather(
self, batched_outputs: torch.Tensor, gather_outputs: torch.Tensor
) -> torch.Tensor:
batched_outputs_cpu = batched_outputs.to("cpu")
gather_outputs_cpu = torch.zeros_like(gather_outputs)
counter_cpu = torch.zeros((e,), device="cpu", dtype=torch.int32)
md = gather_outputs_cpu.size(0)
for t_idx in range(md):
token = None
for topk_id, topk_weight in zip(
self.topk_ids_cpu[t_idx], self.topk_weights_cpu[t_idx]
):
pos_in_batch = counter_cpu[topk_id]
t = batched_outputs_cpu[topk_id, pos_in_batch] * topk_weight
if token is None:
token = t
else:
token += t
counter_cpu[topk_id] += 1
assert token is not None
gather_outputs_cpu[t_idx] = token
gather_outputs.copy_(gather_outputs_cpu)
return gather_outputs
def run(
self, hidden_states: torch.Tensor, fused_marlin_moe_kwargs: dict[Any, Any]
) -> torch.Tensor:
assert hidden_states.ndim == 2
assert self.is_valid()
batched_hidden_states = self._scatter(hidden_states)
kwargs = fused_marlin_moe_kwargs | {
"hidden_states": batched_hidden_states,
"expert_num_tokens": self.expert_num_tokens_cpu.to("cuda"),
}
batched_outputs = batched_fused_marlin_moe(**kwargs)
output = torch.zeros_like(hidden_states)
output = self._gather(batched_outputs, output)
return output
kwargs = {
"w1": w1_data.qweight,
"w2": w2_data.qweight,
"bias1": None,
"bias2": None,
"w1_scale": w1_data.scales,
"w2_scale": w2_data.scales,
"gating_output": score,
"global_num_experts": e,
"expert_map": None,
"global_scale1": w1_data.global_scale,
"global_scale2": w2_data.global_scale,
"g_idx1": w1_data.g_idx,
"g_idx2": w2_data.g_idx,
"sort_indices1": w1_data.sort_indices,
"sort_indices2": w2_data.sort_indices,
"w1_zeros": w1_data.zeros,
"w2_zeros": w2_data.zeros,
"quant_type_id": quant_dtype.id,
"is_k_full": True,
}
# Reference
fused_marlin_moe_kwargs = kwargs | {
"hidden_states": a,
"topk_ids": topk_ids,
"topk_weights": topk_weights,
}
ref_marlin_output = fused_marlin_moe(**fused_marlin_moe_kwargs)
# Batched
br = BatchedRun(max_tokens_per_batch, e, topk_ids, topk_weights)
if not br.is_valid():
pytest.skip("Cannot represent data in Batched Format.")
marlin_output = br.run(a, kwargs)
torch.testing.assert_close(marlin_output, ref_marlin_output, atol=1e-3, rtol=0)
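The new test validates the batched path end to end: BatchedRun scatters tokens into the per-expert (e, max_tokens_per_batch, K) layout on CPU, runs batched_fused_marlin_moe, then gathers and top-k-weights the per-expert outputs back into token order, and the result is compared against the unbatched fused_marlin_moe reference. Note that the batched entry point takes expert_num_tokens in place of topk_ids/topk_weights, since routing has already been applied by the scatter step. To run just this test on a CUDA build of vLLM with the Marlin MoE ops compiled, something like `pytest tests/kernels/test_moe.py -k test_batched_fused_marlin_moe` should work; the exact test-file path may differ across vLLM versions.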