Categorize tests/kernels/ based on kernel type (#16799)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
364
tests/kernels/moe/test_cutlass_moe.py
Normal file
364
tests/kernels/moe/test_cutlass_moe.py
Normal file
@@ -0,0 +1,364 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import dataclasses
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_experts,
|
||||
fused_topk)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
NUM_EXPERTS = [40, 64]
|
||||
TOP_KS = [6, 8]
|
||||
|
||||
MNK_FACTORS = [
|
||||
(2, 1024, 1024),
|
||||
(2, 1024, 1536),
|
||||
(2, 3072, 1024),
|
||||
(2, 3072, 1536),
|
||||
(64, 1024, 1024),
|
||||
(64, 1024, 1536),
|
||||
(64, 3072, 1024),
|
||||
(64, 3072, 1536),
|
||||
(224, 1024, 1024),
|
||||
(224, 1024, 1536),
|
||||
(224, 3072, 1024),
|
||||
(224, 3072, 1536),
|
||||
]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MOETensors:
|
||||
a: torch.Tensor
|
||||
w1: torch.Tensor
|
||||
w2: torch.Tensor
|
||||
ab_strides1: torch.Tensor
|
||||
c_strides1: torch.Tensor
|
||||
ab_strides2: torch.Tensor
|
||||
c_strides2: torch.Tensor
|
||||
|
||||
@staticmethod
|
||||
def make_moe_tensors(m: int, k: int, n: int, e: int,
|
||||
dtype: torch.dtype) -> "MOETensors":
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
ab_strides1 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
c_strides1 = torch.full((e, ), 2 * n, device="cuda", dtype=torch.int64)
|
||||
ab_strides2 = torch.full((e, ), n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e, ), k, device="cuda", dtype=torch.int64)
|
||||
return MOETensors(a=a,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
ab_strides1=ab_strides1,
|
||||
c_strides1=c_strides1,
|
||||
ab_strides2=ab_strides2,
|
||||
c_strides2=c_strides2)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class MOETensors8Bit(MOETensors):
|
||||
# quantized
|
||||
a_q: Optional[torch.Tensor] = None # a -> a_q
|
||||
w1_q: Optional[torch.Tensor] = None # w1 -> w1_q
|
||||
w2_q: Optional[torch.Tensor] = None # w2 -> w2_q
|
||||
a_scale: Optional[torch.Tensor] = None
|
||||
w1_scale: Optional[torch.Tensor] = None
|
||||
w2_scale: Optional[torch.Tensor] = None
|
||||
# dequantized
|
||||
a_d: Optional[torch.Tensor] = None # a -> a_q -> a_d
|
||||
w1_d: Optional[torch.Tensor] = None # w1 -> w1_q -> w1_d
|
||||
w2_d: Optional[torch.Tensor] = None # w2 -> w2_q -> w2_d
|
||||
|
||||
@staticmethod
|
||||
def make_moe_tensors_8bit(m: int, k: int, n: int, e: int,
|
||||
per_act_token: bool,
|
||||
per_out_channel: bool) -> "MOETensors8Bit":
|
||||
dtype = torch.half
|
||||
q_dtype = torch.float8_e4m3fn
|
||||
|
||||
moe_tensors_fp16 = MOETensors.make_moe_tensors(m, k, n, e, dtype)
|
||||
|
||||
# a -> a_q, w1 -> w1_q, w2 -> w2_q
|
||||
n_b_scales = 2 * n if per_out_channel else 1
|
||||
k_b_scales = k if per_out_channel else 1
|
||||
# Get the right scale for tests.
|
||||
_, a_scale = ops.scaled_fp8_quant(
|
||||
moe_tensors_fp16.a, use_per_token_if_dynamic=per_act_token)
|
||||
a_q, _ = ops.scaled_fp8_quant(moe_tensors_fp16.a,
|
||||
a_scale,
|
||||
use_per_token_if_dynamic=per_act_token)
|
||||
w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
|
||||
w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)
|
||||
|
||||
w1_scale = torch.empty((e, n_b_scales, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
w2_scale = torch.empty((e, k_b_scales, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
for expert in range(e):
|
||||
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
|
||||
moe_tensors_fp16.w1[expert],
|
||||
use_per_token_if_dynamic=per_out_channel)
|
||||
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
|
||||
moe_tensors_fp16.w2[expert],
|
||||
use_per_token_if_dynamic=per_out_channel)
|
||||
|
||||
# a_q -> a_d, w1_q -> w1_d, w2_q -> w2_d
|
||||
a_d = a_q.float().mul(a_scale).to(dtype)
|
||||
w1_d = torch.empty_like(moe_tensors_fp16.w1)
|
||||
w2_d = torch.empty_like(moe_tensors_fp16.w2)
|
||||
for expert in range(e):
|
||||
w1_d[expert] = (w1_q[expert].float() * w1_scale[expert]).half()
|
||||
w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half()
|
||||
|
||||
return MOETensors8Bit(a=moe_tensors_fp16.a,
|
||||
w1=moe_tensors_fp16.w1,
|
||||
w2=moe_tensors_fp16.w2,
|
||||
ab_strides1=moe_tensors_fp16.ab_strides1,
|
||||
c_strides1=moe_tensors_fp16.c_strides1,
|
||||
ab_strides2=moe_tensors_fp16.ab_strides2,
|
||||
c_strides2=moe_tensors_fp16.c_strides2,
|
||||
a_q=a_q,
|
||||
w1_q=w1_q,
|
||||
w2_q=w2_q,
|
||||
a_scale=a_scale,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a_d=a_d,
|
||||
w1_d=w1_d,
|
||||
w2_d=w2_d)
|
||||
|
||||
|
||||
def run_with_expert_maps(num_experts: int, num_local_experts: int,
|
||||
**cutlass_moe_kwargs):
|
||||
|
||||
def slice_experts():
|
||||
slice_params = [
|
||||
"w1_q", "w2_q", "ab_strides1", "ab_strides2", "c_strides1",
|
||||
"c_strides2", "w1_scale", "w2_scale"
|
||||
]
|
||||
full_tensors = {
|
||||
k: v
|
||||
for k, v in cutlass_moe_kwargs.items()
|
||||
if k in slice_params and k in cutlass_moe_kwargs
|
||||
}
|
||||
|
||||
for i in range(0, num_experts, num_local_experts):
|
||||
s, e = i, i + num_local_experts
|
||||
|
||||
# make expert map
|
||||
expert_map = [-1] * num_experts
|
||||
expert_map[s:e] = list(range(num_local_experts))
|
||||
expert_map = torch.tensor(expert_map,
|
||||
dtype=torch.int32,
|
||||
device="cuda")
|
||||
|
||||
# update cutlass moe arg with expert_map
|
||||
cutlass_moe_kwargs["expert_map"] = expert_map
|
||||
# update cutlass moe arg tensors
|
||||
for k, t in full_tensors.items():
|
||||
cutlass_moe_kwargs[k] = t[s:e]
|
||||
|
||||
yield cutlass_moe_kwargs
|
||||
|
||||
out_tensor = torch.zeros_like(cutlass_moe_kwargs["a"])
|
||||
for kwargs in slice_experts():
|
||||
out_tensor = out_tensor + cutlass_moe_fp8(**kwargs)
|
||||
|
||||
return out_tensor
|
||||
|
||||
|
||||
def run_8_bit(moe_tensors: MOETensors8Bit,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
num_local_experts: Optional[int] = None) -> torch.Tensor:
|
||||
assert not any([
|
||||
t is None for t in [
|
||||
moe_tensors.w1_q, moe_tensors.w2_q, moe_tensors.w1_scale,
|
||||
moe_tensors.w2_scale, moe_tensors.a_scale
|
||||
]
|
||||
])
|
||||
|
||||
kwargs = {
|
||||
'a': moe_tensors.a,
|
||||
'w1_q': moe_tensors.w1_q.transpose(1, 2), # type: ignore[union-attr]
|
||||
'w2_q': moe_tensors.w2_q.transpose(1, 2), # type: ignore[union-attr]
|
||||
'topk_weights': topk_weights,
|
||||
'topk_ids_': topk_ids,
|
||||
'ab_strides1': moe_tensors.ab_strides1,
|
||||
'c_strides1': moe_tensors.c_strides1,
|
||||
'ab_strides2': moe_tensors.ab_strides2,
|
||||
'c_strides2': moe_tensors.c_strides2,
|
||||
'w1_scale': moe_tensors.w1_scale,
|
||||
'w2_scale': moe_tensors.w2_scale,
|
||||
'a1_scale': moe_tensors.a_scale
|
||||
}
|
||||
|
||||
num_experts = moe_tensors.w1.size(0)
|
||||
with_ep = num_local_experts is not None or num_local_experts == num_experts
|
||||
if not with_ep:
|
||||
return cutlass_moe_fp8(**kwargs)
|
||||
|
||||
assert num_local_experts is not None
|
||||
return run_with_expert_maps(
|
||||
num_experts,
|
||||
num_local_experts, # type: ignore[arg-type]
|
||||
**kwargs)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()),
|
||||
reason="Grouped gemm is not supported on this GPU type.")
|
||||
def test_cutlass_moe_8_bit_no_graph(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(
|
||||
pipeline_parallel_size=1))):
|
||||
|
||||
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
|
||||
per_out_ch)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=torch.half)
|
||||
topk_weights, topk_ids = fused_topk(mt.a,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False)
|
||||
|
||||
# Note that we are using the dequantized versions of the tensors.
|
||||
# Using a, w1 and w2 directly results in minor output differences.
|
||||
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
|
||||
topk_ids)
|
||||
|
||||
cutlass_output = run_8_bit(mt, topk_weights, topk_ids)
|
||||
|
||||
torch.testing.assert_close(triton_output,
|
||||
cutlass_output,
|
||||
atol=5e-2,
|
||||
rtol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("per_act_token", [True, False])
|
||||
@pytest.mark.parametrize("per_out_ch", [True, False])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()),
|
||||
reason="Grouped gemm is not supported on this GPU type.")
|
||||
def test_cutlass_moe_8_bit_cuda_graph(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
per_act_token: bool,
|
||||
per_out_ch: bool,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(
|
||||
pipeline_parallel_size=1))):
|
||||
|
||||
dtype = torch.half
|
||||
|
||||
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
|
||||
per_out_ch)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
topk_weights, topk_ids = fused_topk(mt.a,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False)
|
||||
|
||||
# Note that we are using the dequantized versions of the tensors.
|
||||
# Using a, w1 and w2 directly results in minor output differences.
|
||||
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
|
||||
topk_ids)
|
||||
|
||||
stream = torch.cuda.Stream()
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph, stream=stream):
|
||||
cutlass_output = run_8_bit(mt, topk_weights, topk_ids)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
graph.replay()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
torch.testing.assert_close(triton_output,
|
||||
cutlass_output,
|
||||
atol=9e-2,
|
||||
rtol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [64])
|
||||
@pytest.mark.parametrize("n", [1024])
|
||||
@pytest.mark.parametrize("k", [4096])
|
||||
@pytest.mark.parametrize("e", [16])
|
||||
@pytest.mark.parametrize("topk", [1, 8])
|
||||
@pytest.mark.parametrize("per_act_token", [True])
|
||||
@pytest.mark.parametrize("per_out_channel", [True])
|
||||
@pytest.mark.parametrize("ep_size", [1, 2, 4, 8, 16])
|
||||
@pytest.mark.skipif(
|
||||
(lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))(
|
||||
current_platform.get_device_capability()),
|
||||
reason="Grouped gemm is not supported on this GPU type.")
|
||||
def test_cutlass_moe_8_bit_EP(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
per_act_token: bool,
|
||||
per_out_channel: bool,
|
||||
ep_size: int,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
with set_current_vllm_config(
|
||||
VllmConfig(parallel_config=ParallelConfig(
|
||||
pipeline_parallel_size=1))):
|
||||
|
||||
mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token,
|
||||
per_out_channel)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=torch.half)
|
||||
topk_weights, topk_ids = fused_topk(mt.a,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False)
|
||||
|
||||
# Note that we are using the dequantized versions of the tensors.
|
||||
# Using a, w1 and w2 directly results in minor output differences.
|
||||
triton_output = fused_experts(mt.a_d, mt.w1_d, mt.w2_d, topk_weights,
|
||||
topk_ids)
|
||||
|
||||
assert e % ep_size == 0, "Cannot distribute experts evenly"
|
||||
cutlass_output = run_8_bit(mt,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
num_local_experts=e // ep_size)
|
||||
|
||||
torch.testing.assert_close(triton_output,
|
||||
cutlass_output,
|
||||
atol=5e-2,
|
||||
rtol=1e-2)
|
||||
565
tests/kernels/moe/test_moe.py
Normal file
565
tests/kernels/moe/test_moe.py
Normal file
@@ -0,0 +1,565 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Tests for the MOE layers.
|
||||
|
||||
Run `pytest tests/kernels/test_moe.py`.
|
||||
"""
|
||||
import pytest
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
from torch.nn import functional as F
|
||||
from transformers import MixtralConfig
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe,
|
||||
torch_moe_single)
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
|
||||
fused_moe as iterative_moe)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
awq_marlin_quantize, marlin_quantize)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
quantize_weights)
|
||||
from vllm.model_executor.models.mixtral import MixtralMoE
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
NUM_EXPERTS = [8, 64]
|
||||
EP_SIZE = [1, 4]
|
||||
TOP_KS = [2, 6]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
|
||||
@pytest.mark.parametrize("n", [128, 1024, 2048])
|
||||
@pytest.mark.parametrize("k", [128, 511, 1024])
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("ep_size", EP_SIZE)
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("padding", [True, False])
|
||||
def test_fused_moe(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
ep_size: int,
|
||||
dtype: torch.dtype,
|
||||
padding: bool,
|
||||
):
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
|
||||
if ep_size > 1:
|
||||
local_e = e // ep_size
|
||||
e_ids = torch.randint(0,
|
||||
e, (local_e, ),
|
||||
device="cuda",
|
||||
dtype=torch.int32)
|
||||
e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
|
||||
e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
|
||||
w1 = w1[e_ids]
|
||||
w2 = w2[e_ids]
|
||||
else:
|
||||
e_map = None
|
||||
|
||||
torch_output = torch_moe(a, w1, w2, score, topk, e_map)
|
||||
iterative_output = iterative_moe(a,
|
||||
w1,
|
||||
w2,
|
||||
score,
|
||||
topk,
|
||||
global_num_experts=e,
|
||||
expert_map=e_map,
|
||||
renormalize=False)
|
||||
|
||||
# Pad the weight if moe padding is enabled
|
||||
if padding:
|
||||
w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
|
||||
torch.cuda.empty_cache()
|
||||
w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
triton_output = fused_moe(a,
|
||||
w1,
|
||||
w2,
|
||||
score,
|
||||
topk,
|
||||
global_num_experts=e,
|
||||
expert_map=e_map,
|
||||
renormalize=False)
|
||||
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
|
||||
torch.testing.assert_close(iterative_output,
|
||||
torch_output,
|
||||
atol=2e-2,
|
||||
rtol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 32, 222])
|
||||
@pytest.mark.parametrize("n", [128, 1024, 2048])
|
||||
@pytest.mark.parametrize("k", [128, 1024])
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("ep_size", EP_SIZE)
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("group_size", [64, 128])
|
||||
@pytest.mark.parametrize("has_zp", [True, False])
|
||||
@pytest.mark.parametrize("weight_bits", [4, 8])
|
||||
def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
|
||||
ep_size: int, dtype: torch.dtype, group_size: int,
|
||||
has_zp: bool, weight_bits: int):
|
||||
print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
|
||||
if weight_bits == 4:
|
||||
pack_factor = 2
|
||||
quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8
|
||||
elif weight_bits == 8:
|
||||
pack_factor = 1
|
||||
quant_type = scalar_types.uint8 if has_zp else scalar_types.uint8b128
|
||||
|
||||
w1_ref = w1.clone()
|
||||
w2_ref = w2.clone()
|
||||
w1_qweight = torch.empty((e, 2 * n, k // pack_factor),
|
||||
device="cuda",
|
||||
dtype=torch.uint8)
|
||||
w2_qweight = torch.empty((e, k, n // pack_factor),
|
||||
device="cuda",
|
||||
dtype=torch.uint8)
|
||||
w1_scales = torch.empty((e, 2 * n, k // group_size),
|
||||
device="cuda",
|
||||
dtype=dtype)
|
||||
w2_scales = torch.empty((e, k, n // group_size),
|
||||
device="cuda",
|
||||
dtype=dtype)
|
||||
w1_qzeros = torch.empty((e, 2 * n // pack_factor, k // group_size),
|
||||
device="cuda",
|
||||
dtype=torch.uint8)
|
||||
w2_qzeros = torch.empty((e, k // pack_factor, n // group_size),
|
||||
device="cuda",
|
||||
dtype=torch.uint8)
|
||||
|
||||
for i in range(e * 2):
|
||||
expert_id = i % e
|
||||
if i // e == 0:
|
||||
w, w_ref, w_qweight, w_scales, w_qzeros = \
|
||||
w1, w1_ref, w1_qweight, w1_scales, w1_qzeros
|
||||
else:
|
||||
w, w_ref, w_qweight, w_scales, w_qzeros = \
|
||||
w2, w2_ref, w2_qweight, w2_scales, w2_qzeros
|
||||
weight, qweight, scales, qzeros = quantize_weights(
|
||||
w[expert_id].T, quant_type, group_size, has_zp, False)
|
||||
weight = weight.T
|
||||
qweight = qweight.T.contiguous().to(torch.uint8)
|
||||
scales = scales.T
|
||||
if has_zp:
|
||||
qzeros = qzeros.T.contiguous().to(torch.uint8)
|
||||
if weight_bits == 4:
|
||||
qweight = qweight[:, 1::2] * 16 + qweight[:, ::2]
|
||||
if has_zp:
|
||||
qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :]
|
||||
|
||||
w_ref[expert_id] = weight
|
||||
w_qweight[expert_id] = qweight
|
||||
w_scales[expert_id] = scales
|
||||
if has_zp:
|
||||
w_qzeros[expert_id] = qzeros
|
||||
|
||||
if ep_size > 1:
|
||||
local_e = e // ep_size
|
||||
e_ids = torch.randint(0,
|
||||
e, (local_e, ),
|
||||
device="cuda",
|
||||
dtype=torch.int32)
|
||||
e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
|
||||
e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
|
||||
w1_ref = w1_ref[e_ids]
|
||||
w2_ref = w2_ref[e_ids]
|
||||
w1_qweight = w1_qweight[e_ids]
|
||||
w2_qweight = w2_qweight[e_ids]
|
||||
w1_scales = w1_scales[e_ids]
|
||||
w2_scales = w2_scales[e_ids]
|
||||
w1_qzeros = w1_qzeros[e_ids]
|
||||
w2_qzeros = w2_qzeros[e_ids]
|
||||
else:
|
||||
e_map = None
|
||||
|
||||
triton_output = fused_moe(a,
|
||||
w1_qweight,
|
||||
w2_qweight,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False,
|
||||
use_int4_w4a16=weight_bits == 4,
|
||||
use_int8_w8a16=weight_bits == 8,
|
||||
global_num_experts=e,
|
||||
expert_map=e_map,
|
||||
w1_scale=w1_scales,
|
||||
w2_scale=w2_scales,
|
||||
w1_zp=w1_qzeros if has_zp else None,
|
||||
w2_zp=w2_qzeros if has_zp else None,
|
||||
block_shape=[0, group_size])
|
||||
torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map)
|
||||
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype",
|
||||
[torch.float32, torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("padding", [True, False])
|
||||
@pytest.mark.parametrize(
|
||||
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
|
||||
@torch.inference_mode()
|
||||
def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool,
|
||||
monkeypatch):
|
||||
"""Make sure our Mixtral MoE implementation agrees with the one from
|
||||
huggingface."""
|
||||
|
||||
if use_rocm_aiter:
|
||||
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
|
||||
|
||||
# Instantiate our and huggingface's MoE blocks
|
||||
config = MixtralConfig()
|
||||
hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda")
|
||||
vllm_moe = MixtralMoE(
|
||||
num_experts=config.num_local_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
params_dtype=dtype,
|
||||
tp_size=1,
|
||||
dp_size=1,
|
||||
).cuda()
|
||||
|
||||
# Load the weights
|
||||
vllm_moe.gate.weight.data[:] = hf_moe.gate.weight.data
|
||||
for i in range(config.num_local_experts):
|
||||
weights = (hf_moe.experts[i].w1.weight.data,
|
||||
hf_moe.experts[i].w3.weight.data)
|
||||
vllm_moe.experts.w13_weight[i][:] = torch.cat(weights, dim=0)
|
||||
vllm_moe.experts.w2_weight[i][:] = hf_moe.experts[i].w2.weight.data
|
||||
|
||||
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim]
|
||||
hf_inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda")
|
||||
# vLLM uses 1D query [num_tokens, hidden_dim]
|
||||
vllm_inputs = hf_inputs.flatten(0, 1)
|
||||
|
||||
# Pad the weight if moe padding is enabled
|
||||
if padding:
|
||||
vllm_moe.experts.w13_weight = Parameter(F.pad(
|
||||
vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., 0:-128],
|
||||
requires_grad=False)
|
||||
torch.cuda.empty_cache()
|
||||
vllm_moe.experts.w2_weight = Parameter(F.pad(
|
||||
vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128],
|
||||
requires_grad=False)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Run forward passes for both MoE blocks
|
||||
hf_states, _ = hf_moe.forward(hf_inputs)
|
||||
vllm_states = vllm_moe.forward(vllm_inputs)
|
||||
|
||||
mixtral_moe_tol = {
|
||||
torch.float32: 1e-3,
|
||||
torch.float16: 1e-3,
|
||||
torch.bfloat16: 1e-2,
|
||||
}
|
||||
|
||||
if use_rocm_aiter:
|
||||
# The values of rtol and atol are set based on the tests in ROCM AITER package. # noqa: E501
|
||||
# https://github.com/ROCm/aiter/blob/dfed377f4be7da96ca2d75ac0761f569676f7240/op_tests/test_moe.py#L174 # noqa: E501
|
||||
torch.testing.assert_close(hf_states.flatten(0, 1),
|
||||
vllm_states,
|
||||
rtol=0.01,
|
||||
atol=100)
|
||||
else:
|
||||
torch.testing.assert_close(hf_states.flatten(0, 1),
|
||||
vllm_states,
|
||||
rtol=mixtral_moe_tol[dtype],
|
||||
atol=mixtral_moe_tol[dtype])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 33, 123])
|
||||
@pytest.mark.parametrize("n", [128, 1024])
|
||||
@pytest.mark.parametrize("k", [256, 2048])
|
||||
@pytest.mark.parametrize("e", [4, 12])
|
||||
@pytest.mark.parametrize("topk", [2, 3])
|
||||
@pytest.mark.parametrize("ep_size", [1, 4])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("group_size", [-1, 32, 128])
|
||||
@pytest.mark.parametrize("act_order", [True, False])
|
||||
@pytest.mark.parametrize("num_bits", [4, 8])
|
||||
@pytest.mark.parametrize("has_zp", [True, False])
|
||||
@pytest.mark.parametrize("is_k_full", [True, False])
|
||||
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
|
||||
def test_fused_marlin_moe(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
ep_size: int,
|
||||
dtype: torch.dtype,
|
||||
group_size: int,
|
||||
act_order: bool,
|
||||
num_bits: int,
|
||||
has_zp: bool,
|
||||
is_k_full: bool,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
|
||||
# Filter act_order
|
||||
if act_order:
|
||||
if group_size == -1:
|
||||
return
|
||||
if group_size in (k, n):
|
||||
return
|
||||
if has_zp:
|
||||
return
|
||||
else:
|
||||
if not is_k_full:
|
||||
return
|
||||
|
||||
if has_zp:
|
||||
# we don't build kernel for int8 with zero
|
||||
if num_bits == 8:
|
||||
return
|
||||
quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
|
||||
else:
|
||||
quant_type = scalar_types.uint4b8 \
|
||||
if num_bits == 4 else scalar_types.uint8b128
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
|
||||
if ep_size > 1:
|
||||
local_e = e // ep_size
|
||||
e_ids = torch.randperm(e, device="cuda", dtype=torch.int32)[:local_e]
|
||||
e_map = torch.full((e, ), -1, device="cuda", dtype=torch.int32)
|
||||
e_map[e_ids] = torch.arange(local_e, device="cuda", dtype=torch.int32)
|
||||
w1 = w1[e_ids]
|
||||
w2 = w2[e_ids]
|
||||
else:
|
||||
e_map = None
|
||||
|
||||
w_ref1_l = []
|
||||
qweight1_l = []
|
||||
scales1_l = []
|
||||
zeros1_l = []
|
||||
g_idx1_l = []
|
||||
sort_indices1_l = []
|
||||
|
||||
for i in range(w1.shape[0]):
|
||||
if has_zp:
|
||||
w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize(
|
||||
w1[i].transpose(1, 0), quant_type, group_size)
|
||||
|
||||
w_ref1_l.append(w_ref1.T)
|
||||
qweight1_l.append(qweight1)
|
||||
scales1_l.append(scales1)
|
||||
zeros1_l.append(zeros1)
|
||||
else:
|
||||
test_perm = torch.randperm(k)
|
||||
quant_res = marlin_quantize(w1[i].transpose(1, 0), quant_type,
|
||||
group_size, act_order, test_perm)
|
||||
w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = quant_res
|
||||
|
||||
w_ref1_l.append(w_ref1.T)
|
||||
qweight1_l.append(qweight1)
|
||||
scales1_l.append(scales1)
|
||||
g_idx1_l.append(g_idx1)
|
||||
sort_indices1_l.append(sort_indices1)
|
||||
|
||||
w_ref1 = stack_and_dev(w_ref1_l)
|
||||
qweight1 = stack_and_dev(qweight1_l).contiguous()
|
||||
scales1 = stack_and_dev(scales1_l)
|
||||
g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None
|
||||
zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None
|
||||
sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None
|
||||
|
||||
w_ref2_l = []
|
||||
qweight2_l = []
|
||||
scales2_l = []
|
||||
zeros2_l = []
|
||||
g_idx2_l = []
|
||||
sort_indices2_l = []
|
||||
|
||||
for i in range(w2.shape[0]):
|
||||
if has_zp:
|
||||
w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize(
|
||||
w2[i].transpose(1, 0), quant_type, group_size)
|
||||
|
||||
w_ref2_l.append(w_ref2.T)
|
||||
qweight2_l.append(qweight2)
|
||||
scales2_l.append(scales2)
|
||||
zeros2_l.append(zeros2)
|
||||
else:
|
||||
test_perm = torch.randperm(n)
|
||||
quant_res = marlin_quantize(w2[i].transpose(1, 0), quant_type,
|
||||
group_size, act_order, test_perm)
|
||||
w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = quant_res
|
||||
|
||||
w_ref2_l.append(w_ref2.T)
|
||||
qweight2_l.append(qweight2)
|
||||
scales2_l.append(scales2)
|
||||
g_idx2_l.append(g_idx2)
|
||||
sort_indices2_l.append(sort_indices2)
|
||||
|
||||
w_ref2 = stack_and_dev(w_ref2_l)
|
||||
qweight2 = stack_and_dev(qweight2_l).contiguous()
|
||||
scales2 = stack_and_dev(scales2_l)
|
||||
g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None
|
||||
zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None
|
||||
sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
|
||||
topk_weights, topk_ids = fused_topk(a, score, topk, False)
|
||||
|
||||
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)
|
||||
|
||||
marlin_output = torch.ops.vllm.fused_marlin_moe(
|
||||
a,
|
||||
qweight1,
|
||||
qweight2,
|
||||
scales1,
|
||||
scales2,
|
||||
score,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
global_num_experts=e,
|
||||
expert_map=e_map,
|
||||
g_idx1=g_idx1,
|
||||
g_idx2=g_idx2,
|
||||
sort_indices1=sort_indices1,
|
||||
sort_indices2=sort_indices2,
|
||||
w1_zeros=zeros1,
|
||||
w2_zeros=zeros2,
|
||||
num_bits=num_bits,
|
||||
is_k_full=is_k_full)
|
||||
|
||||
torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0)
|
||||
|
||||
|
||||
@pytest.mark.skip("This test is here for the sake of debugging, "
|
||||
"don't run it in automated tests.")
|
||||
@pytest.mark.parametrize("m", [1, 33, 123])
|
||||
@pytest.mark.parametrize("n", [128, 1024])
|
||||
@pytest.mark.parametrize("k", [256, 2048])
|
||||
@pytest.mark.parametrize("e", [4, 12])
|
||||
@pytest.mark.parametrize("topk", [2, 3])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.parametrize("group_size", [-1, 32, 128])
|
||||
@pytest.mark.parametrize("act_order", [True, False])
|
||||
@pytest.mark.parametrize("num_bits", [4, 8])
|
||||
@pytest.mark.parametrize("has_zp", [True, False])
|
||||
@pytest.mark.parametrize("is_k_full", [True, False])
|
||||
def test_single_marlin_moe_multiply(m: int, n: int, k: int, e: int, topk: int,
|
||||
dtype: torch.dtype, group_size: int,
|
||||
act_order: bool, num_bits: int,
|
||||
has_zp: bool, is_k_full: bool):
|
||||
# Filter act_order
|
||||
if act_order:
|
||||
if group_size == -1:
|
||||
return
|
||||
if group_size in (k, n):
|
||||
return
|
||||
if has_zp:
|
||||
return
|
||||
else:
|
||||
if not is_k_full:
|
||||
return
|
||||
|
||||
if has_zp:
|
||||
quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8
|
||||
else:
|
||||
quant_type = scalar_types.uint4b8 \
|
||||
if num_bits == 4 else scalar_types.uint8b128
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
|
||||
|
||||
w_ref_l = []
|
||||
qweight_l = []
|
||||
scales_l = []
|
||||
zeros_l = []
|
||||
g_idx_l = []
|
||||
sort_indices_l = []
|
||||
|
||||
for i in range(w.shape[0]):
|
||||
if has_zp:
|
||||
w_ref, qweight, scales, zeros = awq_marlin_quantize(
|
||||
w[i].transpose(1, 0), quant_type, group_size)
|
||||
|
||||
w_ref_l.append(w_ref.T)
|
||||
qweight_l.append(qweight)
|
||||
scales_l.append(scales)
|
||||
zeros_l.append(zeros)
|
||||
else:
|
||||
test_perm = torch.randperm(k)
|
||||
w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize(
|
||||
w[i].transpose(1, 0), quant_type, group_size, act_order,
|
||||
test_perm)
|
||||
|
||||
w_ref_l.append(w_ref.T)
|
||||
qweight_l.append(qweight)
|
||||
scales_l.append(scales)
|
||||
g_idx_l.append(g_idx)
|
||||
sort_indices_l.append(sort_indices)
|
||||
|
||||
w_ref = stack_and_dev(w_ref_l)
|
||||
qweight = stack_and_dev(qweight_l).contiguous()
|
||||
scales = stack_and_dev(scales_l)
|
||||
g_idx = stack_and_dev(g_idx_l) if g_idx_l else None
|
||||
zeros = stack_and_dev(zeros_l) if zeros_l else None
|
||||
sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
marlin_output = torch.ops.vllm.single_marlin_moe(
|
||||
a,
|
||||
qweight,
|
||||
scales,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False,
|
||||
g_idx=g_idx,
|
||||
sort_indices=sort_indices,
|
||||
w_zeros=zeros,
|
||||
num_bits=num_bits,
|
||||
is_k_full=is_k_full,
|
||||
)
|
||||
|
||||
torch_output = torch_moe_single(a, w_ref, score, topk)
|
||||
|
||||
torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0)
|
||||
|
||||
|
||||
def test_moe_align_block_size_opcheck():
|
||||
num_experts = 4
|
||||
block_size = 4
|
||||
topk_ids = torch.randint(0,
|
||||
num_experts, (3, 4),
|
||||
dtype=torch.int32,
|
||||
device='cuda')
|
||||
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
sorted_ids = torch.empty((max_num_tokens_padded, ),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
sorted_ids.fill_(topk_ids.numel())
|
||||
max_num_m_blocks = max_num_tokens_padded // block_size
|
||||
expert_ids = torch.empty((max_num_m_blocks, ),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
num_tokens_post_pad = torch.empty((1),
|
||||
dtype=torch.int32,
|
||||
device=topk_ids.device)
|
||||
|
||||
opcheck(torch.ops._moe_C.moe_align_block_size,
|
||||
(topk_ids, num_experts, block_size, sorted_ids, expert_ids,
|
||||
num_tokens_post_pad))
|
||||
159
tests/kernels/moe/test_triton_moe_ptpc_fp8.py
Normal file
159
tests/kernels/moe/test_triton_moe_ptpc_fp8.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Adapted from https://github.com/sgl-project/sglang/blob/main/test/srt/test_triton_moe_channel_fp8_kernel.py
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.get_device_capability() < (9, 0):
|
||||
pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
|
||||
allow_module_level=True)
|
||||
|
||||
|
||||
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
|
||||
"""Matrix multiplication function that supports per-token input
|
||||
quantization and per-column weight quantization"""
|
||||
A = A.to(torch.float32)
|
||||
B = B.to(torch.float32)
|
||||
|
||||
assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
|
||||
assert B.ndim == 2 and B.is_contiguous(
|
||||
), "B must be a 2D contiguous tensor"
|
||||
|
||||
# Reshape input
|
||||
M = A.numel() // A.shape[-1]
|
||||
B = B.t() # Transpose weight matrix
|
||||
N, K = B.shape
|
||||
origin_C_shape = A.shape[:-1] + (K, )
|
||||
A = A.reshape(M, N)
|
||||
|
||||
# As is per-token [M, 1], Bs is per-column [1, K]
|
||||
C = torch.matmul(A, B) # [M, K]
|
||||
C = As * C * Bs.view(1, -1) # Broadcast per-column scale
|
||||
|
||||
return C.reshape(origin_C_shape).to(output_dtype)
|
||||
|
||||
|
||||
def fp8_mask(a, mask):
|
||||
dtype = a.dtype
|
||||
return a.view(torch.int8)[mask].view(dtype)
|
||||
|
||||
|
||||
def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
|
||||
"""This function performs fused moe with per-column int8
|
||||
quantization using native torch."""
|
||||
|
||||
B, D = a.shape
|
||||
# Perform per-token quantization
|
||||
a_q, a_s = ops.scaled_fp8_quant(a, use_per_token_if_dynamic=True)
|
||||
# Repeat tokens to match topk
|
||||
a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
|
||||
# Also repeat the scale
|
||||
a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1) # [B*topk, 1]
|
||||
|
||||
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
|
||||
|
||||
# Calculate routing
|
||||
score = torch.softmax(score, dim=-1, dtype=torch.float32)
|
||||
topk_weight, topk_ids = torch.topk(score, topk)
|
||||
topk_weight = topk_weight.view(-1)
|
||||
topk_ids = topk_ids.view(-1)
|
||||
# Process each expert
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
# First MLP layer: note that a_s is now per-token
|
||||
inter_out = native_w8a8_per_token_matmul(
|
||||
fp8_mask(a_q, mask),
|
||||
w1[i],
|
||||
fp8_mask(a_s, mask),
|
||||
w1_s[i],
|
||||
output_dtype=a.dtype,
|
||||
)
|
||||
# Activation function
|
||||
act_out = SiluAndMul().forward_native(inter_out)
|
||||
# Quantize activation output with per-token
|
||||
act_out_q, act_out_s = ops.scaled_fp8_quant(
|
||||
act_out, use_per_token_if_dynamic=True)
|
||||
|
||||
# Second MLP layer
|
||||
out[mask] = native_w8a8_per_token_matmul(act_out_q,
|
||||
w2[i],
|
||||
act_out_s,
|
||||
w2_s[i],
|
||||
output_dtype=a.dtype)
|
||||
# Apply routing weights and sum
|
||||
return (out.view(B, -1, w2.shape[1]) *
|
||||
topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="module")
|
||||
def setup_cuda():
|
||||
"""Sets the default CUDA device for all tests in this module."""
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
|
||||
DTYPES = [torch.half, torch.bfloat16]
|
||||
M = [1, 33]
|
||||
N = [128, 1024]
|
||||
K = [256, 4096]
|
||||
E = [8]
|
||||
TOP_KS = [2, 6]
|
||||
SEEDS = [0]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M, N, K, E, topk, dtype, seed",
|
||||
itertools.product(M, N, K, E, TOP_KS, DTYPES, SEEDS))
|
||||
@torch.inference_mode()
|
||||
def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed):
|
||||
torch.manual_seed(seed)
|
||||
# Initialize int8 quantization parameters
|
||||
factor_for_scale = 1e-2
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max = finfo.max
|
||||
fp8_min = finfo.min
|
||||
|
||||
# Input tensor
|
||||
# M * K
|
||||
a = torch.randn((M, K), dtype=dtype) / 10
|
||||
|
||||
# Generate int8 weights
|
||||
w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2
|
||||
w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min,
|
||||
max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2
|
||||
w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min,
|
||||
max=fp8_max).to(torch.float8_e4m3fn)
|
||||
|
||||
# Generate scale for each column (per-column quantization)
|
||||
w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * factor_for_scale
|
||||
w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale
|
||||
score = torch.randn((M, E), dtype=dtype)
|
||||
|
||||
ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk)
|
||||
out = fused_moe(
|
||||
a,
|
||||
w1,
|
||||
w2,
|
||||
score,
|
||||
topk,
|
||||
renormalize=False,
|
||||
use_fp8_w8a8=True, # using fp8
|
||||
per_channel_quant=True,
|
||||
w1_scale=w1_s,
|
||||
w2_scale=w2_s,
|
||||
block_shape=None, # Not using block quantization
|
||||
)
|
||||
|
||||
# Check results
|
||||
rel_diff = (torch.mean(
|
||||
torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) /
|
||||
torch.mean(torch.abs(ref_out.to(torch.float32))))
|
||||
assert rel_diff < 0.05
|
||||
Reference in New Issue
Block a user