[GPTOSS][DP/EP][Marlin] Enable GPTOSS Batched DP/EP using Marlin kernels (#25997)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
committed by
GitHub
parent
2ed8b6b3d0
commit
fb0571b077
@@ -7,6 +7,8 @@ Run `pytest tests/kernels/test_moe.py`.
|
||||
|
||||
import functools
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -26,7 +28,10 @@ from vllm.model_executor.layers.fused_moe.config import (
|
||||
int4_w4a16_moe_quant_config,
|
||||
int8_w8a16_moe_quant_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
batched_fused_marlin_moe,
|
||||
fused_marlin_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
fused_topk,
|
||||
modular_triton_fused_moe,
|
||||
@@ -564,6 +569,105 @@ def marlin_moe_generate_valid_test_cases():
|
||||
return cases
|
||||
|
||||
|
||||
@dataclass
class MarlinMoEWeightData:
    """Marlin-quantized MoE weight bundle for one stack of experts.

    Produced by ``make()``. Every tensor is stacked along dim 0 with one
    slice per expert; fields that the chosen quantization scheme does not
    emit are left as ``None``.
    """

    # Dequantized reference weights (each expert appended as w_ref.T so the
    # torch_moe reference path can consume them).
    w_ref: torch.Tensor
    # Marlin-packed quantized weights.
    qweight: torch.Tensor
    # Per-group quantization scales.
    scales: torch.Tensor
    # Tensor-wide scale — only produced by the NVFP4 path (group_size == 16).
    global_scale: torch.Tensor | None
    # Activation-order indices — only produced by the marlin_quantize path.
    g_idx: torch.Tensor | None
    # Zero points — only produced by the AWQ-style uint4/uint8 path.
    zeros: torch.Tensor | None
    # Sort indices paired with g_idx — only from the marlin_quantize path.
    sort_indices: torch.Tensor | None
    # Marlin-permuted bias, present only when a bias was passed to make().
    marlin_bias: torch.Tensor | None

    @staticmethod
    def make(
        w: torch.Tensor,
        quant_type: ScalarType,
        group_size: int,
        act_order: bool | None = None,
        bias: torch.Tensor | None = None,
    ) -> "MarlinMoEWeightData":
        """Quantize a 3-D (num_experts, N, K) weight tensor into Marlin format.

        The quantization path is chosen from ``quant_type``:
        fp4 (NVFP4 when group_size == 16, else MXFP4), fp8, AWQ-style
        uint4/uint8 (with zero points), or the default GPTQ-style
        ``marlin_quantize`` (with optional act_order).

        Returns a MarlinMoEWeightData with per-expert results stacked on dim 0.
        """
        assert w.ndim == 3
        # uint4/uint8 identify the zero-point (AWQ-style) schemes.
        has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
        k = w.shape[-1]

        w_ref_l: list[torch.Tensor] = []
        qweight_l: list[torch.Tensor] = []
        scales_l: list[torch.Tensor] = []
        global_scale_l: list[torch.Tensor] = []
        zeros_l: list[torch.Tensor] = []
        g_idx_l: list[torch.Tensor] = []
        sort_indices_l: list[torch.Tensor] = []
        bias_l: list[torch.Tensor] = []

        # Quantize each expert independently; stack the results afterwards.
        for i in range(w.shape[0]):
            if quant_type == scalar_types.float4_e2m1f:
                # group_size 16 -> NVFP4 (carries a global scale),
                # otherwise MXFP4 (no global scale).
                if group_size == 16:
                    w_ref, qweight, scales, global_scale = (
                        rand_marlin_weight_nvfp4_like(w[i], group_size)
                    )
                else:
                    w_ref, qweight, scales = rand_marlin_weight_mxfp4_like(
                        w[i], group_size
                    )
                    global_scale = None

                w_ref_l.append(w_ref.T)
                qweight_l.append(qweight)
                scales_l.append(scales)
                if global_scale is not None:
                    global_scale_l.append(global_scale)
            elif quant_type == scalar_types.float8_e4m3fn:
                w_ref, qweight, scales = marlin_quant_fp8_torch(w[i], group_size)
                w_ref_l.append(w_ref.T)
                qweight_l.append(qweight)
                scales_l.append(scales)
            elif has_zp:
                w_ref, qweight, scales, zeros = awq_marlin_quantize(
                    w[i].transpose(1, 0), quant_type, group_size
                )

                w_ref_l.append(w_ref.T)
                qweight_l.append(qweight)
                scales_l.append(scales)
                zeros_l.append(zeros)
            else:
                # GPTQ-style path; a random permutation exercises the
                # act_order handling inside marlin_quantize.
                test_perm = torch.randperm(k)
                w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
                    w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
                )

                w_ref_l.append(w_ref.T)
                qweight_l.append(qweight)
                scales_l.append(scales)
                g_idx_l.append(g_idx)
                sort_indices_l.append(sort_indices)

            if bias is not None:
                bias_l.append(marlin_permute_bias(bias[i]))

        # Stack per-expert lists (stack_and_dev presumably also moves the
        # result to the test device — confirm against the helper). Lists a
        # given path never filled collapse to None.
        w_ref = stack_and_dev(w_ref_l)
        qweight = stack_and_dev(qweight_l).contiguous()
        scales = stack_and_dev(scales_l)
        global_scale = stack_and_dev(global_scale_l) if global_scale_l else None
        g_idx = stack_and_dev(g_idx_l) if g_idx_l else None
        zeros = stack_and_dev(zeros_l) if zeros_l else None
        sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None
        marlin_bias = stack_and_dev(bias_l) if bias_l else None

        return MarlinMoEWeightData(
            w_ref=w_ref,
            qweight=qweight,
            scales=scales,
            global_scale=global_scale,
            g_idx=g_idx,
            zeros=zeros,
            sort_indices=sort_indices,
            marlin_bias=marlin_bias,
        )
|
||||
|
||||
|
||||
@pytest.mark.flaky(reruns=2)
|
||||
@pytest.mark.parametrize(
|
||||
("m, n, k, e, topk, ep_size, dtype, group_size,act_order, quant_type, is_k_full"),
|
||||
@@ -584,7 +688,6 @@ def test_fused_marlin_moe(
|
||||
is_k_full: bool,
|
||||
):
|
||||
torch.cuda.manual_seed(0)
|
||||
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
|
||||
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
|
||||
@@ -600,152 +703,44 @@ def test_fused_marlin_moe(
|
||||
else:
|
||||
e_map = None
|
||||
|
||||
w_ref1_l = []
|
||||
qweight1_l = []
|
||||
scales1_l = []
|
||||
global_scale1_l = []
|
||||
zeros1_l = []
|
||||
g_idx1_l = []
|
||||
sort_indices1_l = []
|
||||
w1_data = MarlinMoEWeightData.make(
|
||||
w=w1, quant_type=quant_type, group_size=group_size, act_order=act_order
|
||||
)
|
||||
|
||||
for i in range(w1.shape[0]):
|
||||
if quant_type == scalar_types.float4_e2m1f:
|
||||
if group_size == 16:
|
||||
w_ref1, qweight1, scales1, global_scale1 = (
|
||||
rand_marlin_weight_nvfp4_like(w1[i], group_size)
|
||||
)
|
||||
else:
|
||||
w_ref1, qweight1, scales1 = rand_marlin_weight_mxfp4_like(
|
||||
w1[i], group_size
|
||||
)
|
||||
global_scale1 = None
|
||||
|
||||
w_ref1_l.append(w_ref1.T)
|
||||
qweight1_l.append(qweight1)
|
||||
scales1_l.append(scales1)
|
||||
if global_scale1 is not None:
|
||||
global_scale1_l.append(global_scale1)
|
||||
elif quant_type == scalar_types.float8_e4m3fn:
|
||||
w_ref1, qweight1, scales1 = marlin_quant_fp8_torch(w1[i], group_size)
|
||||
w_ref1_l.append(w_ref1.T)
|
||||
qweight1_l.append(qweight1)
|
||||
scales1_l.append(scales1)
|
||||
elif has_zp:
|
||||
w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize(
|
||||
w1[i].transpose(1, 0), quant_type, group_size
|
||||
)
|
||||
|
||||
w_ref1_l.append(w_ref1.T)
|
||||
qweight1_l.append(qweight1)
|
||||
scales1_l.append(scales1)
|
||||
zeros1_l.append(zeros1)
|
||||
else:
|
||||
test_perm = torch.randperm(k)
|
||||
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
|
||||
w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
|
||||
)
|
||||
|
||||
w_ref1_l.append(w_ref1.T)
|
||||
qweight1_l.append(qweight1)
|
||||
scales1_l.append(scales1)
|
||||
g_idx1_l.append(g_idx1)
|
||||
sort_indices1_l.append(sort_indices1)
|
||||
|
||||
w_ref1 = stack_and_dev(w_ref1_l)
|
||||
qweight1 = stack_and_dev(qweight1_l).contiguous()
|
||||
scales1 = stack_and_dev(scales1_l)
|
||||
global_scale1 = stack_and_dev(global_scale1_l) if global_scale1_l else None
|
||||
g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
|
||||
zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
|
||||
sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
|
||||
|
||||
w_ref2_l = []
|
||||
qweight2_l = []
|
||||
scales2_l = []
|
||||
global_scale2_l = []
|
||||
zeros2_l = []
|
||||
g_idx2_l = []
|
||||
sort_indices2_l = []
|
||||
|
||||
for i in range(w2.shape[0]):
|
||||
if quant_type == scalar_types.float4_e2m1f:
|
||||
if group_size == 16:
|
||||
w_ref2, qweight2, scales2, global_scale2 = (
|
||||
rand_marlin_weight_nvfp4_like(w2[i], group_size)
|
||||
)
|
||||
else:
|
||||
w_ref2, qweight2, scales2 = rand_marlin_weight_mxfp4_like(
|
||||
w2[i], group_size
|
||||
)
|
||||
global_scale2 = None
|
||||
|
||||
w_ref2_l.append(w_ref2.T)
|
||||
qweight2_l.append(qweight2)
|
||||
scales2_l.append(scales2)
|
||||
if global_scale2 is not None:
|
||||
global_scale2_l.append(global_scale2)
|
||||
elif quant_type == scalar_types.float8_e4m3fn:
|
||||
w_ref2, qweight2, scales2 = marlin_quant_fp8_torch(w2[i], group_size)
|
||||
w_ref2_l.append(w_ref2.T)
|
||||
qweight2_l.append(qweight2)
|
||||
scales2_l.append(scales2)
|
||||
elif has_zp:
|
||||
w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize(
|
||||
w2[i].transpose(1, 0), quant_type, group_size
|
||||
)
|
||||
|
||||
w_ref2_l.append(w_ref2.T)
|
||||
qweight2_l.append(qweight2)
|
||||
scales2_l.append(scales2)
|
||||
zeros2_l.append(zeros2)
|
||||
else:
|
||||
test_perm = torch.randperm(n)
|
||||
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
|
||||
w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
|
||||
)
|
||||
|
||||
w_ref2_l.append(w_ref2.T)
|
||||
qweight2_l.append(qweight2)
|
||||
scales2_l.append(scales2)
|
||||
g_idx2_l.append(g_idx2)
|
||||
sort_indices2_l.append(sort_indices2)
|
||||
|
||||
w_ref2 = stack_and_dev(w_ref2_l)
|
||||
qweight2 = stack_and_dev(qweight2_l).contiguous()
|
||||
scales2 = stack_and_dev(scales2_l)
|
||||
global_scale2 = stack_and_dev(global_scale2_l) if global_scale2_l else None
|
||||
g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
|
||||
zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
|
||||
sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
|
||||
w2_data = MarlinMoEWeightData.make(
|
||||
w=w2, quant_type=quant_type, group_size=group_size, act_order=act_order
|
||||
)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, expert_map=e_map)
|
||||
torch_output = torch_moe(
|
||||
a, w1_data.w_ref, w2_data.w_ref, score, topk, expert_map=e_map
|
||||
)
|
||||
|
||||
marlin_output = fused_marlin_moe(
|
||||
a,
|
||||
qweight1,
|
||||
qweight2,
|
||||
w1_data.qweight,
|
||||
w2_data.qweight,
|
||||
None,
|
||||
None,
|
||||
scales1,
|
||||
scales2,
|
||||
w1_data.scales,
|
||||
w2_data.scales,
|
||||
score,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts=e,
|
||||
expert_map=e_map,
|
||||
global_scale1=global_scale1,
|
||||
global_scale2=global_scale2,
|
||||
g_idx1=g_idx1,
|
||||
g_idx2=g_idx2,
|
||||
sort_indices1=sort_indices1,
|
||||
sort_indices2=sort_indices2,
|
||||
w1_zeros=zeros1,
|
||||
w2_zeros=zeros2,
|
||||
global_scale1=w1_data.global_scale,
|
||||
global_scale2=w2_data.global_scale,
|
||||
g_idx1=w1_data.g_idx,
|
||||
g_idx2=w2_data.g_idx,
|
||||
sort_indices1=w1_data.sort_indices,
|
||||
sort_indices2=w2_data.sort_indices,
|
||||
w1_zeros=w1_data.zeros,
|
||||
w2_zeros=w2_data.zeros,
|
||||
quant_type_id=quant_type.id,
|
||||
is_k_full=is_k_full,
|
||||
)
|
||||
@@ -773,92 +768,52 @@ def test_fused_marlin_moe_with_bias(m):
|
||||
b_bias1 = torch.randn((e, 2 * n), device="cuda", dtype=dtype) / 10
|
||||
b_bias2 = torch.randn((e, k), device="cuda", dtype=dtype) / 10
|
||||
|
||||
b_bias1_l = []
|
||||
w_ref1_l = []
|
||||
qweight1_l = []
|
||||
scales1_l = []
|
||||
g_idx1_l = []
|
||||
sort_indices1_l = []
|
||||
w1_data = MarlinMoEWeightData.make(
|
||||
w=w1,
|
||||
quant_type=quant_type,
|
||||
group_size=group_size,
|
||||
act_order=act_order,
|
||||
bias=b_bias1,
|
||||
)
|
||||
|
||||
for i in range(w1.shape[0]):
|
||||
test_perm = torch.randperm(k)
|
||||
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize(
|
||||
w1[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
|
||||
)
|
||||
|
||||
w_ref1_l.append(w_ref1.T)
|
||||
qweight1_l.append(qweight1)
|
||||
scales1_l.append(scales1)
|
||||
g_idx1_l.append(g_idx1)
|
||||
sort_indices1_l.append(sort_indices1)
|
||||
b_bias1_l.append(marlin_permute_bias(b_bias1[i]))
|
||||
|
||||
w_ref1 = stack_and_dev(w_ref1_l)
|
||||
qweight1 = stack_and_dev(qweight1_l).contiguous()
|
||||
scales1 = stack_and_dev(scales1_l)
|
||||
global_scale1 = None
|
||||
g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
|
||||
zeros1 = None
|
||||
sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
|
||||
marlin_bias1 = stack_and_dev(b_bias1_l) if b_bias1_l else None
|
||||
|
||||
b_bias2_l = []
|
||||
w_ref2_l = []
|
||||
qweight2_l = []
|
||||
scales2_l = []
|
||||
g_idx2_l = []
|
||||
sort_indices2_l = []
|
||||
|
||||
for i in range(w2.shape[0]):
|
||||
test_perm = torch.randperm(n)
|
||||
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize(
|
||||
w2[i].transpose(1, 0), quant_type, group_size, act_order, test_perm
|
||||
)
|
||||
|
||||
w_ref2_l.append(w_ref2.T)
|
||||
qweight2_l.append(qweight2)
|
||||
scales2_l.append(scales2)
|
||||
g_idx2_l.append(g_idx2)
|
||||
sort_indices2_l.append(sort_indices2)
|
||||
b_bias2_l.append(marlin_permute_bias(b_bias2[i]))
|
||||
|
||||
w_ref2 = stack_and_dev(w_ref2_l)
|
||||
qweight2 = stack_and_dev(qweight2_l).contiguous()
|
||||
scales2 = stack_and_dev(scales2_l)
|
||||
global_scale2 = None
|
||||
g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
|
||||
zeros2 = None
|
||||
sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
|
||||
marlin_bias2 = stack_and_dev(b_bias2_l) if b_bias2_l else None
|
||||
w2_data = MarlinMoEWeightData.make(
|
||||
w=w2,
|
||||
quant_type=quant_type,
|
||||
group_size=group_size,
|
||||
act_order=act_order,
|
||||
bias=b_bias2,
|
||||
)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
|
||||
topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, b_bias1, b_bias2)
|
||||
torch_output = torch_moe(
|
||||
a, w1_data.w_ref, w2_data.w_ref, score, topk, b_bias1, b_bias2
|
||||
)
|
||||
|
||||
marlin_output = fused_marlin_moe(
|
||||
a,
|
||||
qweight1,
|
||||
qweight2,
|
||||
marlin_bias1,
|
||||
marlin_bias2,
|
||||
scales1,
|
||||
scales2,
|
||||
w1_data.qweight,
|
||||
w2_data.qweight,
|
||||
w1_data.marlin_bias,
|
||||
w2_data.marlin_bias,
|
||||
w1_data.scales,
|
||||
w2_data.scales,
|
||||
score,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
global_scale1=global_scale1,
|
||||
global_scale2=global_scale2,
|
||||
g_idx1=g_idx1,
|
||||
g_idx2=g_idx2,
|
||||
sort_indices1=sort_indices1,
|
||||
sort_indices2=sort_indices2,
|
||||
w1_zeros=zeros1,
|
||||
w2_zeros=zeros2,
|
||||
global_scale1=w1_data.global_scale,
|
||||
global_scale2=w2_data.global_scale,
|
||||
g_idx1=w1_data.g_idx,
|
||||
g_idx2=w2_data.g_idx,
|
||||
sort_indices1=w1_data.sort_indices,
|
||||
sort_indices2=w2_data.sort_indices,
|
||||
w1_zeros=w1_data.zeros,
|
||||
w2_zeros=w2_data.zeros,
|
||||
quant_type_id=quant_type.id,
|
||||
is_k_full=is_k_full,
|
||||
)
|
||||
@@ -895,6 +850,41 @@ def test_moe_align_block_size_opcheck():
|
||||
)
|
||||
|
||||
|
||||
def test_batched_moe_align_block_size_opcheck():
    """Opcheck the ``_moe_C.batched_moe_align_block_size`` custom op.

    Builds random per-expert token counts plus correctly-sized output
    buffers and lets ``opcheck`` validate the op's registration (schema,
    fake/meta kernels, etc.).
    """
    max_tokens_per_batch = 512
    num_experts = 4
    block_size = 16

    # Random token count per expert in [0, max_tokens_per_batch).
    expert_num_tokens = torch.randint(
        low=0,
        high=max_tokens_per_batch,
        size=(num_experts,),
        dtype=torch.int32,
        device="cuda",
    )

    # Worst-case padded size: every expert filled to capacity (and at least
    # one full block each).
    max_num_tokens_padded = num_experts * max(max_tokens_per_batch, block_size)
    sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda")

    assert max_num_tokens_padded % block_size == 0
    max_num_m_blocks = max_num_tokens_padded // block_size
    expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")

    # Fix: the original passed (1), which is just the int 1, not a tuple —
    # it worked by accident. (1,) states the 1-element-tensor intent.
    num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")

    opcheck(
        torch.ops._moe_C.batched_moe_align_block_size,
        (
            max_tokens_per_batch,
            block_size,
            expert_num_tokens,
            sorted_ids,
            expert_ids,
            num_tokens_post_pad,
        ),
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 33, 64, 222])
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("k", [128, 511, 1024])
|
||||
@@ -979,3 +969,171 @@ def test_cpu_fused_moe_basic(m, n, k, e, topk, dtype, with_bias, activation):
|
||||
else:
|
||||
atol = 5e-2
|
||||
torch.testing.assert_close(out, ref, atol=atol, rtol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [16, 32, 64])
@pytest.mark.parametrize("n", [128])
@pytest.mark.parametrize("k", [128])
@pytest.mark.parametrize("e", [8, 12, 16, 32])
@pytest.mark.parametrize("topk", [2, 4])
@pytest.mark.parametrize("max_tokens_per_batch", [16, 32, 64])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_batched_fused_marlin_moe(
    m: int, n: int, k: int, e: int, topk: int, max_tokens_per_batch: int
):
    """Check batched_fused_marlin_moe against the non-batched fused_marlin_moe.

    Tokens are scattered on the CPU into a per-expert batched layout, run
    through the batched kernel, gathered back with topk weighting, and the
    result is compared with the regular (non-batched) Marlin MoE output.
    """
    print(
        f"testing m={m}, n={n}, k={k}, e={e}, "
        f"topk={topk}, "
        f"max_tokens_per_batch={max_tokens_per_batch}"
    )
    torch.cuda.manual_seed(0)

    # Fixed fp4 (mxfp4, since group_size != 16) quantization config.
    dtype = torch.bfloat16
    quant_dtype = scalar_types.float4_e2m1f
    group_size = 32

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20
    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20

    w1_data = MarlinMoEWeightData.make(
        w=w1, quant_type=quant_dtype, group_size=group_size, act_order=None
    )
    w2_data = MarlinMoEWeightData.make(
        w=w2, quant_type=quant_dtype, group_size=group_size, act_order=None
    )

    score = torch.randn((m, e), device="cuda", dtype=dtype)
    topk_weights, topk_ids, _ = fused_topk(a, score, topk, False)

    class BatchedRun:
        """CPU-side scatter/gather harness around batched_fused_marlin_moe.

        NOTE(review): _scatter and _gather read ``e`` and
        ``max_tokens_per_batch`` from the enclosing test scope (closure),
        not ``self.e`` / ``self.max_tokens_per_batch`` — same values here,
        but worth tightening.
        """

        @staticmethod
        def _make_expert_num_tokens_cpu(
            e: int,  # num_experts
            topk_ids_cpu: torch.Tensor,
        ) -> torch.Tensor:
            # Histogram of how many tokens were routed to each expert.
            expert_num_tokens_cpu = torch.zeros((e,), dtype=torch.int32, device="cpu")
            for topk_id in torch.flatten(topk_ids_cpu):
                expert_num_tokens_cpu[topk_id] += 1
            return expert_num_tokens_cpu

        def __init__(
            self,
            max_tokens_per_batch: int,
            num_experts: int,
            _topk_ids: torch.Tensor,
            _topk_weights: torch.Tensor,
        ):
            self.max_tokens_per_batch = max_tokens_per_batch
            self.e = num_experts
            # Routing data is kept on CPU for the reference scatter/gather.
            self.topk_ids_cpu = _topk_ids.to("cpu")
            self.topk_weights_cpu = _topk_weights.to("cpu")
            self.expert_num_tokens_cpu = self._make_expert_num_tokens_cpu(
                self.e, self.topk_ids_cpu
            )

        def is_valid(self):
            """
            Return True only if the input can be represented in a Batched
            format.
            """
            return torch.all(self.expert_num_tokens_cpu <= self.max_tokens_per_batch)

        def _scatter(self, hidden_states: torch.Tensor) -> torch.Tensor:
            # Copy each token into every expert batch it was routed to,
            # producing an (e, max_tokens_per_batch, K) layout on CUDA.
            hidden_states_cpu = hidden_states.to("cpu")
            K = hidden_states_cpu.size(1)
            batched_hidden_states_cpu = torch.empty(
                (e, max_tokens_per_batch, K),
                dtype=hidden_states_cpu.dtype,
                device="cpu",
            )

            counter_cpu = torch.zeros_like(self.expert_num_tokens_cpu)
            for t_idx, token in enumerate(hidden_states_cpu):
                for topk_id in self.topk_ids_cpu[t_idx]:
                    pos_in_batch = counter_cpu[topk_id]
                    batched_hidden_states_cpu[topk_id, pos_in_batch] = token
                    counter_cpu[topk_id] += 1
            # Every routed slot must have been filled exactly once.
            assert torch.allclose(counter_cpu, self.expert_num_tokens_cpu)
            return batched_hidden_states_cpu.to("cuda")

        def _gather(
            self, batched_outputs: torch.Tensor, gather_outputs: torch.Tensor
        ) -> torch.Tensor:
            # Inverse of _scatter: weighted-sum each token's topk expert
            # outputs back into a flat (m, K) tensor.
            batched_outputs_cpu = batched_outputs.to("cpu")
            # NOTE(review): despite the name, zeros_like puts this on the
            # same device as gather_outputs (CUDA); the CPU-tensor row
            # assignments below work via cross-device copy — confirm intent.
            gather_outputs_cpu = torch.zeros_like(gather_outputs)

            counter_cpu = torch.zeros((e,), device="cpu", dtype=torch.int32)
            md = gather_outputs_cpu.size(0)
            for t_idx in range(md):
                token = None
                for topk_id, topk_weight in zip(
                    self.topk_ids_cpu[t_idx], self.topk_weights_cpu[t_idx]
                ):
                    # Consume the expert batches in the same order _scatter
                    # filled them (counter tracks position per expert).
                    pos_in_batch = counter_cpu[topk_id]
                    t = batched_outputs_cpu[topk_id, pos_in_batch] * topk_weight
                    if token is None:
                        token = t
                    else:
                        token += t
                    counter_cpu[topk_id] += 1
                assert token is not None
                gather_outputs_cpu[t_idx] = token
            gather_outputs.copy_(gather_outputs_cpu)
            return gather_outputs

        def run(
            self, hidden_states: torch.Tensor, fused_marlin_moe_kwargs: dict[Any, Any]
        ) -> torch.Tensor:
            """Scatter -> batched_fused_marlin_moe -> gather."""
            assert hidden_states.ndim == 2
            assert self.is_valid()

            batched_hidden_states = self._scatter(hidden_states)

            kwargs = fused_marlin_moe_kwargs | {
                "hidden_states": batched_hidden_states,
                "expert_num_tokens": self.expert_num_tokens_cpu.to("cuda"),
            }
            batched_outputs = batched_fused_marlin_moe(**kwargs)

            output = torch.zeros_like(hidden_states)
            output = self._gather(batched_outputs, output)
            return output

    # Arguments shared by both the batched and non-batched kernels.
    kwargs = {
        "w1": w1_data.qweight,
        "w2": w2_data.qweight,
        "bias1": None,
        "bias2": None,
        "w1_scale": w1_data.scales,
        "w2_scale": w2_data.scales,
        "gating_output": score,
        "global_num_experts": e,
        "expert_map": None,
        "global_scale1": w1_data.global_scale,
        "global_scale2": w2_data.global_scale,
        "g_idx1": w1_data.g_idx,
        "g_idx2": w2_data.g_idx,
        "sort_indices1": w1_data.sort_indices,
        "sort_indices2": w2_data.sort_indices,
        "w1_zeros": w1_data.zeros,
        "w2_zeros": w2_data.zeros,
        "quant_type_id": quant_dtype.id,
        "is_k_full": True,
    }

    # Reference
    fused_marlin_moe_kwargs = kwargs | {
        "hidden_states": a,
        "topk_ids": topk_ids,
        "topk_weights": topk_weights,
    }
    ref_marlin_output = fused_marlin_moe(**fused_marlin_moe_kwargs)

    # Batched
    br = BatchedRun(max_tokens_per_batch, e, topk_ids, topk_weights)
    if not br.is_valid():
        # Some expert got more tokens than a batch can hold for this seed /
        # parameter combination — the batched layout can't represent it.
        pytest.skip("Cannot represent data in Batched Format.")
    marlin_output = br.run(a, kwargs)

    torch.testing.assert_close(marlin_output, ref_marlin_output, atol=1e-3, rtol=0)
|
||||
|
||||
Reference in New Issue
Block a user