[NVIDIA] Add SM100 Flashinfer Cutlass MoE fp8 backend (#22357)
Signed-off-by: Amir Klein <203507526+amirkl94@users.noreply.github.com>
This commit is contained in:
248
tests/kernels/moe/test_flashinfer.py
Normal file
248
tests/kernels/moe/test_flashinfer.py
Normal file
@@ -0,0 +1,248 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
||||
apply_flashinfer_per_tensor_scale_fp8, flashinfer_cutlass_moe_fp8,
|
||||
register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights,
|
||||
swap_w13_to_w31)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
input_to_float8)
|
||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
|
||||
if not has_flashinfer_cutlass_fused_moe(
|
||||
) or not current_platform.has_device_capability(100):
|
||||
pytest.skip("Requires flashinfer_cutlass_fused_moe and nvfp4 support",
|
||||
allow_module_level=True)
|
||||
|
||||
NUM_EXPERTS = [16]
|
||||
TOP_KS = [1]
|
||||
|
||||
MNK_FACTORS = [
|
||||
(256, 8192, 5120),
|
||||
(256, 4096, 5120),
|
||||
(127, 8192, 5120),
|
||||
(127, 4096, 5120),
|
||||
(10, 8192, 5120),
|
||||
(10, 4096, 5120),
|
||||
(1, 8192, 5120),
|
||||
(1, 4096, 5120),
|
||||
]
|
||||
|
||||
vllm_config = VllmConfig(parallel_config=ParallelConfig(
|
||||
pipeline_parallel_size=1))
|
||||
vllm_config.scheduler_config.max_num_seqs = 128
|
||||
vllm_config.scheduler_config.max_model_len = 8192
|
||||
|
||||
|
||||
def quant_fp8_per_tensor_batches(a):
|
||||
num_batches = a.size(0)
|
||||
a_quant = []
|
||||
a_scales = []
|
||||
|
||||
for i in range(num_batches):
|
||||
a_fp8, a_global_sf = input_to_float8(a[i])
|
||||
a_global_sf = 1.0 / a_global_sf
|
||||
a_quant.append(a_fp8)
|
||||
a_scales.append(a_global_sf)
|
||||
|
||||
result_a_quant = torch.stack(a_quant)
|
||||
result_a_scales = torch.stack(a_scales)
|
||||
|
||||
return result_a_quant, result_a_scales
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestData:
|
||||
hidden_states: torch.Tensor
|
||||
w13_quantized: torch.Tensor
|
||||
w2_quantized: torch.Tensor
|
||||
a1_scale: torch.Tensor
|
||||
a2_scale: torch.Tensor
|
||||
w13_weight_scale: torch.Tensor
|
||||
w2_weight_scale: torch.Tensor
|
||||
layer: torch.nn.Module
|
||||
|
||||
@staticmethod
|
||||
def make_moe_tensors_8bit(m: int, k: int, n: int, e: int,
|
||||
reorder: bool) -> "TestData":
|
||||
hidden_states = torch.randn(
|
||||
(m, k), device="cuda", dtype=torch.bfloat16) / 10
|
||||
w13 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16)
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Scale to fp8
|
||||
_, a1_scale = input_to_float8(hidden_states)
|
||||
a1_scale = 1.0 / a1_scale
|
||||
a2_scale = torch.scalar_tensor(1.0).to(device="cuda").to(
|
||||
dtype=torch.float32)
|
||||
w13_quantized, w13_weight_scale = quant_fp8_per_tensor_batches(w13)
|
||||
w2_quantized, w2_weight_scale = quant_fp8_per_tensor_batches(w2)
|
||||
|
||||
layer = torch.nn.Module()
|
||||
layer.w13_weight = w13_quantized.clone()
|
||||
layer.w2_weight = w2_quantized.clone()
|
||||
layer.w13_input_scale = a1_scale
|
||||
layer.w2_input_scale = a2_scale
|
||||
layer.w13_weight_scale = w13_weight_scale
|
||||
layer.w2_weight_scale = w2_weight_scale
|
||||
|
||||
register_moe_scaling_factors(layer)
|
||||
|
||||
# flashinfer expects swapped rows for w13
|
||||
layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data)
|
||||
if reorder:
|
||||
rotate_flashinfer_fp8_moe_weights(layer.w13_weight,
|
||||
layer.w2_weight)
|
||||
layer.custom_routing_function = Llama4MoE.custom_routing_function
|
||||
layer.intermediate_size_per_partition = n
|
||||
layer.ep_rank = 0
|
||||
layer.local_num_experts = e
|
||||
|
||||
return TestData(
|
||||
hidden_states=hidden_states,
|
||||
w13_quantized=w13_quantized,
|
||||
w2_quantized=w2_quantized,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
w13_weight_scale=w13_weight_scale,
|
||||
w2_weight_scale=w2_weight_scale,
|
||||
layer=layer,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
def test_flashinfer_per_tensor_moe_fp8_no_graph(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
monkeypatch,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
with set_current_vllm_config(vllm_config):
|
||||
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=True)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=td.hidden_states,
|
||||
router_logits=score,
|
||||
use_grouped_topk=False,
|
||||
top_k=topk,
|
||||
renormalize=False,
|
||||
custom_routing_function=Llama4MoE.custom_routing_function,
|
||||
scoring_func="softmax")
|
||||
|
||||
output = fused_experts(
|
||||
td.hidden_states,
|
||||
td.w13_quantized,
|
||||
td.w2_quantized,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
use_fp8_w8a8=True,
|
||||
per_channel_quant=False,
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
w1_scale=td.w13_weight_scale,
|
||||
w2_scale=td.w2_weight_scale,
|
||||
a1_scale=td.a1_scale,
|
||||
a2_scale=td.a2_scale,
|
||||
apply_router_weight_on_input=True,
|
||||
)
|
||||
|
||||
flashinfer_output = apply_flashinfer_per_tensor_scale_fp8(
|
||||
layer=td.layer,
|
||||
hidden_states=td.hidden_states,
|
||||
router_logits=score,
|
||||
routing_bias=None,
|
||||
global_num_experts=e,
|
||||
top_k=topk,
|
||||
num_expert_group=None,
|
||||
topk_group=None,
|
||||
apply_router_weight_on_input=True)
|
||||
|
||||
torch.testing.assert_close(output,
|
||||
flashinfer_output,
|
||||
atol=5.5e-2,
|
||||
rtol=1e-2)
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
"Requires flashinfer version that contains https://github.com/flashinfer-ai/flashinfer/pull/1472"
|
||||
)
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
monkeypatch,
|
||||
):
|
||||
current_platform.seed_everything(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
with set_current_vllm_config(vllm_config):
|
||||
td = TestData.make_moe_tensors_8bit(m, k, n, e, reorder=False)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=td.hidden_states,
|
||||
router_logits=score,
|
||||
use_grouped_topk=False,
|
||||
top_k=topk,
|
||||
renormalize=False,
|
||||
custom_routing_function=Llama4MoE.custom_routing_function,
|
||||
scoring_func="softmax")
|
||||
|
||||
output = fused_experts(
|
||||
td.hidden_states,
|
||||
td.w13_quantized,
|
||||
td.w2_quantized,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
use_fp8_w8a8=True,
|
||||
per_channel_quant=False,
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
w1_scale=td.w13_weight_scale,
|
||||
w2_scale=td.w2_weight_scale,
|
||||
a1_scale=td.a1_scale,
|
||||
a2_scale=td.a2_scale,
|
||||
apply_router_weight_on_input=True,
|
||||
)
|
||||
|
||||
td.layer.dp_size = 1
|
||||
|
||||
flashinfer_cutlass_output = flashinfer_cutlass_moe_fp8(
|
||||
td.hidden_states,
|
||||
td.layer,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation="silu",
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
apply_router_weight_on_input=True,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(output,
|
||||
flashinfer_cutlass_output,
|
||||
atol=5.5e-2,
|
||||
rtol=1e-2)
|
||||
Reference in New Issue
Block a user