Signed-off-by: EdalatiAli <aliedalati@cohere.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
238 lines
7.3 KiB
Python
238 lines
7.3 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
# Adapted from SGLang:
|
|
# https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/tests/test_es_fp8_blockwise_moe.py
|
|
|
|
"""Tests for SM100 CUTLASS MXFP8 grouped MoE kernels."""
|
|
|
|
import random
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from tests.kernels.utils import torch_moe_single
|
|
from vllm import _custom_ops as ops
|
|
from vllm.platforms import current_platform
|
|
from vllm.utils.torch_utils import set_random_seed
|
|
|
|
# Seed both Python's RNG (used for problem-shape selection) and torch's RNG
# (via vLLM's helper) so the randomized test cases are reproducible across runs.
random.seed(42)
set_random_seed(42)
|
|
|
|
|
|
def align(val: int, alignment: int = 128) -> int:
    """Round ``val`` up to the nearest multiple of ``alignment``."""
    # Ceil-division via negation, then scale back up to the multiple.
    groups = -(-val // alignment)
    return int(groups * alignment)
|
|
|
|
|
|
# Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
def calc_diff(x, y):
    """Relative difference between two tensors based on cosine-style similarity.

    Computes ``1 - 2 * <x, y> / (|x|^2 + |y|^2)`` in float64; identical inputs
    give 0, with larger values indicating a bigger mismatch.
    """
    a = x.double()
    b = y.double()
    cross = (a * b).sum()
    energy = (a * a + b * b).sum()
    similarity = 2 * cross / energy
    return 1 - similarity
|
|
|
|
|
|
def is_sm100_supported() -> bool:
    """Return True when running on CUDA hardware in the SM100 capability family."""
    if not current_platform.is_cuda():
        return False
    return current_platform.is_device_capability_family(100)
|
|
|
|
|
|
def compute_ref_output(
    input_tensor: torch.Tensor,
    weight_list: list[torch.Tensor],
    expert_offsets: list[int],
    expert_offset: int,
    num_experts: int,
) -> torch.Tensor:
    """Compute the reference MoE output with deterministic top-1 routing.

    Tokens are laid out contiguously per expert in ``input_tensor``; a score
    matrix is built so that each token's owning expert scores 0.0 and every
    other expert scores -1e9, making top-1 selection pick the owner.
    ``expert_offset`` is the total token count (one past the last row).
    """
    total_tokens = expert_offset
    scores = torch.full(
        (total_tokens, num_experts),
        -1e9,
        device=input_tensor.device,
        dtype=torch.float32,
    )

    # Append the final boundary so each expert's row range is [bounds[g], bounds[g+1]).
    bounds = list(expert_offsets) + [total_tokens]
    for expert_id in range(num_experts):
        row_start = bounds[expert_id]
        row_end = bounds[expert_id + 1]
        scores[row_start:row_end, expert_id] = 0.0

    stacked_weights = torch.stack(weight_list, dim=0)
    return torch_moe_single(input_tensor, stacked_weights, scores, topk=1)
|
|
|
|
|
|
def compute_kernel_output(
    input_tensor: torch.Tensor,
    weight_tensor: torch.Tensor,
    problem_sizes: list[list[int]],
    aux_problem_sizes: list[list[int]],
    expert_offsets: list[int],
    aux_expert_offsets: list[int],
    input_blockscale_offsets: list[int],
    weight_blockscale_offsets: list[int],
    input_blockscale_offset: int,
    n_g: int,
    k_g: int,
    num_experts: int,
    expert_offset: int,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    """Quantize activations and weights to MXFP8, then run the grouped GEMM.

    Mirrors the kernel path under test: ``mxfp8_experts_quant`` produces
    fp8 values plus uint8 block scales (one scale per 32 elements along K),
    and ``cutlass_mxfp8_grouped_mm`` consumes them to produce the output of
    shape ``(expert_offset, n_g)`` in ``out_dtype``.
    """
    device = input_tensor.device

    def as_i32(values) -> torch.Tensor:
        # Kernel metadata must live on-device as int32 tensors.
        return torch.tensor(values).to(device=device, dtype=torch.int32)

    _problem_sizes = as_i32(problem_sizes)
    _aux_problem_sizes = as_i32(aux_problem_sizes)
    _expert_offsets = as_i32(expert_offsets)
    _aux_expert_offsets = as_i32(aux_expert_offsets)
    _input_blockscale_offsets = as_i32(input_blockscale_offsets)
    _weight_blockscale_offsets = as_i32(weight_blockscale_offsets)

    # Pre-allocated outputs for activation quantization: fp8 values plus one
    # uint8 scale per 32 elements along K, over the padded row count.
    input_quant = torch.zeros_like(
        input_tensor, dtype=torch.float8_e4m3fn, device=device
    )
    input_scale_factor = torch.zeros(
        (input_blockscale_offset, k_g // 32), dtype=torch.uint8, device=device
    )

    # Pre-allocated outputs for weight quantization, one (n_g, k_g/32) scale
    # slab per expert.
    weight_quant = torch.zeros_like(
        weight_tensor, dtype=torch.float8_e4m3fn, device=device
    )
    weight_scale_factor = torch.zeros(
        (num_experts, n_g, k_g // 32), dtype=torch.uint8, device=device
    )

    ops.mxfp8_experts_quant(
        input_tensor,
        _problem_sizes,
        _expert_offsets,
        _input_blockscale_offsets,
        input_quant,
        input_scale_factor,
    )

    ops.mxfp8_experts_quant(
        weight_tensor,
        _aux_problem_sizes,
        _aux_expert_offsets,
        _weight_blockscale_offsets,
        weight_quant,
        weight_scale_factor,
    )

    # The grouped GEMM expects weights (and their scales) viewed per-expert
    # with K and N swapped relative to how they were quantized.
    weight_quant = weight_quant.view(num_experts, n_g, k_g).transpose(1, 2)
    weight_scale_factor = weight_scale_factor.view(
        num_experts, n_g, k_g // 32
    ).transpose(1, 2)

    output = torch.empty((expert_offset, n_g), device=device, dtype=out_dtype)
    ops.cutlass_mxfp8_grouped_mm(
        input_quant,
        weight_quant,
        input_scale_factor,
        weight_scale_factor,
        output,
        _problem_sizes,
        _expert_offsets,
        _input_blockscale_offsets,
    )
    return output
|
|
|
|
|
|
@pytest.mark.skipif(
    not is_sm100_supported(),
    reason=(
        "cutlass_mxfp8_grouped_mm and mxfp8_experts_quant "
        "are only supported on CUDA SM100"
    ),
)
@pytest.mark.parametrize("num_experts", [8, 16, 32, 64])
@pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
def test_cutlass_mxfp8_grouped_mm(num_experts, out_dtype):
    """End-to-end check of the MXFP8 grouped-GEMM kernel path against a
    torch reference, over randomized per-expert problem sizes."""
    device = "cuda"
    alignment = 128
    # N and K are shared across experts and kept 128-aligned by construction.
    n_g = random.randint(1, 64) * alignment
    k_g = random.randint(1, 64) * alignment

    # Running offsets / metadata accumulated per expert:
    expert_offset = 0  # total token rows (M) so far
    expert_offsets = []  # start row of each expert in the concatenated input
    aux_expert_offset = 0  # total weight rows (N) so far
    aux_expert_offsets = []  # start row of each expert in the concatenated weights
    input_blockscale_offset = 0  # padded (128-aligned) activation-scale rows
    input_blockscale_offsets = []
    weight_blockscale_offset = 0  # weight-scale rows
    weight_blockscale_offsets = []
    problem_sizes = []  # per-expert [m, n, k] for the GEMM
    aux_problem_sizes = []  # per-expert [n, m, k] for weight quantization
    input_list = []
    weight_list = []

    for g in range(num_experts):
        # Token count per expert is random and NOT 128-aligned, which is why
        # the activation blockscale offsets are padded with align() below.
        m_g = random.randint(1, 512)
        expert_offsets.append(expert_offset)
        expert_offset += m_g
        aux_expert_offsets.append(aux_expert_offset)
        aux_expert_offset += n_g
        input_blockscale_offsets.append(input_blockscale_offset)
        input_blockscale_offset += align(m_g, 128)
        weight_blockscale_offsets.append(weight_blockscale_offset)
        weight_blockscale_offset += n_g  # n_g already align to 128
        problem_sizes.append([m_g, n_g, k_g])
        aux_problem_sizes.append([n_g, m_g, k_g])

        input_tensor = torch.normal(
            0.0, std=1.0, size=(m_g, k_g), device=device, dtype=out_dtype
        )  # (M, K):(K, 1)
        weight_tensor = torch.normal(
            0.0, std=1.0, size=(n_g, k_g), device=device, dtype=out_dtype
        )  # (N, K):(K, 1)

        input_list.append(input_tensor)
        weight_list.append(weight_tensor)
    # Concatenate all experts' tokens/weights into single contiguous tensors.
    input_tensor = torch.concat(input_list, dim=0)
    weight_tensor = torch.concat(weight_list, dim=0)

    ref_output = compute_ref_output(
        input_tensor=input_tensor,
        weight_list=weight_list,
        expert_offsets=expert_offsets,
        expert_offset=expert_offset,
        num_experts=num_experts,
    )
    output = compute_kernel_output(
        input_tensor=input_tensor,
        weight_tensor=weight_tensor,
        problem_sizes=problem_sizes,
        aux_problem_sizes=aux_problem_sizes,
        expert_offsets=expert_offsets,
        aux_expert_offsets=aux_expert_offsets,
        input_blockscale_offsets=input_blockscale_offsets,
        weight_blockscale_offsets=weight_blockscale_offsets,
        input_blockscale_offset=input_blockscale_offset,
        n_g=n_g,
        k_g=k_g,
        num_experts=num_experts,
        expert_offset=expert_offset,
        out_dtype=out_dtype,
    )

    # Compare kernel output to the reference expert-by-expert, using the
    # similarity-based diff (tolerant of fp8 quantization noise).
    for g in range(num_experts):
        baseline = ref_output[
            expert_offsets[g] : (expert_offsets[g] + problem_sizes[g][0])
        ]
        actual = output[expert_offsets[g] : (expert_offsets[g] + problem_sizes[g][0])]
        diff = calc_diff(actual, baseline)
        assert diff < 0.001
        print(
            f"m_g={baseline.shape[0]} n_g={n_g} k_g={k_g} num_experts={num_experts}, "
            f"out_dtype={out_dtype}, diff={diff:.5f}: OK"
        )
|
|
|
|
|
|
# Allow running this file directly (python <file>) in addition to pytest's CLI.
if __name__ == "__main__":
    pytest.main([__file__])
|