[Kernel] [Helion] [4/N] Add silu_mul_fp8 Helion kernel (#33373)
Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
This commit is contained in:
331
tests/kernels/helion/test_silu_mul_fp8.py
Normal file
331
tests/kernels/helion/test_silu_mul_fp8.py
Normal file
@@ -0,0 +1,331 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.utils.import_utils import has_helion
|
||||
|
||||
if not has_helion():
|
||||
pytest.skip(
|
||||
"Helion is not installed. Install with: pip install vllm[helion]",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
from vllm.kernels.helion.config_manager import ConfigManager
|
||||
from vllm.kernels.helion.ops.silu_mul_fp8 import (
|
||||
pick_silu_mul_fp8_config,
|
||||
silu_mul_fp8,
|
||||
silu_mul_fp8_baseline,
|
||||
)
|
||||
|
||||
|
||||
def skip_if_platform_unsupported():
|
||||
try:
|
||||
from vllm.kernels.helion.utils import get_canonical_gpu_name
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA not available")
|
||||
|
||||
platform = get_canonical_gpu_name()
|
||||
|
||||
try:
|
||||
config_manager = ConfigManager.get_instance()
|
||||
except RuntimeError:
|
||||
config_manager = ConfigManager()
|
||||
|
||||
configs = config_manager.get_platform_configs("silu_mul_fp8", platform)
|
||||
if len(configs) == 0:
|
||||
pytest.skip("Current GPU platform not supported for silu_mul_fp8 kernel")
|
||||
|
||||
except (ImportError, RuntimeError, KeyError):
|
||||
pytest.skip("Error detecting platform support for silu_mul_fp8 kernel")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_config_manager_singleton():
|
||||
ConfigManager.reset_instance()
|
||||
ConfigManager()
|
||||
yield
|
||||
ConfigManager.reset_instance()
|
||||
|
||||
|
||||
class TestSiluMulFp8ConfigPicker:
|
||||
def test_config_picker_exact_match(self):
|
||||
config_keys = [
|
||||
"intermediate_2048_batchsize_256",
|
||||
"intermediate_4096_batchsize_256",
|
||||
]
|
||||
|
||||
input_tensor = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda")
|
||||
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
|
||||
args = (input_tensor, scale)
|
||||
|
||||
selected_key = pick_silu_mul_fp8_config(args, config_keys)
|
||||
assert selected_key == "intermediate_2048_batchsize_256"
|
||||
|
||||
def test_config_picker_closest_match(self):
|
||||
config_keys = [
|
||||
"intermediate_2048_batchsize_256",
|
||||
"intermediate_4096_batchsize_256",
|
||||
]
|
||||
# Use 7000 (intermediate_size=3500) which is closer to 4096 than 2048
|
||||
input_tensor = torch.randn(32, 7000, dtype=torch.bfloat16, device="cuda")
|
||||
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
|
||||
args = (input_tensor, scale)
|
||||
|
||||
selected_key = pick_silu_mul_fp8_config(args, config_keys)
|
||||
assert selected_key == "intermediate_4096_batchsize_256"
|
||||
|
||||
def test_config_picker_fallback_to_default(self):
|
||||
config_keys = ["default", "some_other_key"]
|
||||
|
||||
input_tensor = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda")
|
||||
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
|
||||
args = (input_tensor, scale)
|
||||
|
||||
selected_key = pick_silu_mul_fp8_config(args, config_keys)
|
||||
assert selected_key == "default"
|
||||
|
||||
def test_config_picker_no_configs(self):
|
||||
config_keys: list[str] = []
|
||||
|
||||
input_tensor = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda")
|
||||
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
|
||||
args = (input_tensor, scale)
|
||||
|
||||
selected_key = pick_silu_mul_fp8_config(args, config_keys)
|
||||
assert selected_key is None
|
||||
|
||||
@pytest.mark.parametrize("intermediate_size", [2048, 4096, 5120])
|
||||
def test_config_picker_different_sizes(self, intermediate_size):
|
||||
config_keys = [
|
||||
"intermediate_2048_batchsize_256",
|
||||
"intermediate_4096_batchsize_256",
|
||||
"intermediate_5120_batchsize_256",
|
||||
]
|
||||
|
||||
input_tensor = torch.randn(
|
||||
32, 2 * intermediate_size, dtype=torch.bfloat16, device="cuda"
|
||||
)
|
||||
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
|
||||
args = (input_tensor, scale)
|
||||
|
||||
selected_key = pick_silu_mul_fp8_config(args, config_keys)
|
||||
expected_key = f"intermediate_{intermediate_size}_batchsize_256"
|
||||
assert selected_key == expected_key
|
||||
|
||||
|
||||
class TestSiluMulFp8Correctness:
|
||||
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
|
||||
@pytest.mark.parametrize("intermediate_size", [2048, 3000, 3500, 4096, 5000])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
def test_silu_mul_fp8_correctness(self, batch_size, intermediate_size, dtype):
|
||||
skip_if_platform_unsupported()
|
||||
|
||||
input_size = 2 * intermediate_size
|
||||
input_tensor = torch.randn(batch_size, input_size, dtype=dtype, device="cuda")
|
||||
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
|
||||
|
||||
reference_output = silu_mul_fp8_baseline(input_tensor, scale)
|
||||
helion_output = silu_mul_fp8(input_tensor, scale)
|
||||
|
||||
assert helion_output.shape == reference_output.shape
|
||||
assert helion_output.dtype == torch.float8_e4m3fn
|
||||
assert reference_output.dtype == torch.float8_e4m3fn
|
||||
|
||||
ref_f32 = reference_output.to(torch.float32)
|
||||
helion_f32 = helion_output.to(torch.float32)
|
||||
# FP8 E4M3 has limited precision. Values near quantization boundaries
|
||||
# can round differently due to intermediate precision differences.
|
||||
torch.testing.assert_close(
|
||||
helion_f32,
|
||||
ref_f32,
|
||||
atol=0.05,
|
||||
rtol=0.05,
|
||||
msg=f"Mismatch at batch={batch_size}, size={intermediate_size}",
|
||||
)
|
||||
|
||||
def test_silu_mul_fp8_shape_inference(self):
|
||||
skip_if_platform_unsupported()
|
||||
batch_size, input_size = 32, 8192
|
||||
intermediate_size = input_size // 2
|
||||
|
||||
input_tensor = torch.randn(
|
||||
batch_size, input_size, dtype=torch.bfloat16, device="cuda"
|
||||
)
|
||||
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
|
||||
|
||||
output = silu_mul_fp8(input_tensor, scale)
|
||||
|
||||
expected_shape = (batch_size, intermediate_size)
|
||||
assert output.shape == expected_shape
|
||||
assert output.dtype == torch.float8_e4m3fn
|
||||
|
||||
def test_silu_mul_fp8_scale_variations(self):
|
||||
skip_if_platform_unsupported()
|
||||
batch_size, input_size = 16, 4096
|
||||
|
||||
input_tensor = torch.randn(
|
||||
batch_size, input_size, dtype=torch.bfloat16, device="cuda"
|
||||
)
|
||||
|
||||
scales = [0.1, 0.5, 1.0, 2.0, 10.0]
|
||||
|
||||
for scale_val in scales:
|
||||
scale = torch.tensor([scale_val], dtype=torch.float32, device="cuda")
|
||||
|
||||
reference_output = silu_mul_fp8_baseline(input_tensor, scale)
|
||||
helion_output = silu_mul_fp8(input_tensor, scale)
|
||||
ref_f32 = reference_output.to(torch.float32)
|
||||
helion_f32 = helion_output.to(torch.float32)
|
||||
|
||||
torch.testing.assert_close(
|
||||
helion_f32,
|
||||
ref_f32,
|
||||
atol=0.05,
|
||||
rtol=0.05,
|
||||
msg=f"Mismatch for scale={scale_val}",
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"shape",
|
||||
[
|
||||
(1, 4096),
|
||||
(16, 4096),
|
||||
(128, 4096),
|
||||
(1024, 4096),
|
||||
(1, 8192),
|
||||
(16, 8192),
|
||||
(128, 8192),
|
||||
],
|
||||
)
|
||||
def test_silu_mul_fp8_various_shapes(self, shape):
|
||||
skip_if_platform_unsupported()
|
||||
|
||||
input_tensor = torch.randn(*shape, dtype=torch.bfloat16, device="cuda")
|
||||
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
|
||||
|
||||
reference_output = silu_mul_fp8_baseline(input_tensor, scale)
|
||||
helion_output = silu_mul_fp8(input_tensor, scale)
|
||||
|
||||
assert helion_output.shape == reference_output.shape
|
||||
|
||||
ref_f32 = reference_output.to(torch.float32)
|
||||
helion_f32 = helion_output.to(torch.float32)
|
||||
|
||||
torch.testing.assert_close(
|
||||
helion_f32, ref_f32, atol=0.05, rtol=0.05, msg=f"Mismatch for shape={shape}"
|
||||
)
|
||||
|
||||
|
||||
def silu_mul_fp8_pytorch(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
||||
"""Pure PyTorch reference using F.silu.
|
||||
|
||||
This matches vLLM's SiluAndMul.forward_native exactly:
|
||||
F.silu(x[..., :d]) * x[..., d:]
|
||||
"""
|
||||
d = input.shape[-1] // 2
|
||||
result = F.silu(input[..., :d]) * input[..., d:]
|
||||
return (result.to(torch.float32) / scale).to(torch.float8_e4m3fn)
|
||||
|
||||
|
||||
class TestSiluMulFp8PytorchReference:
|
||||
"""Tests comparing Helion kernel against pure PyTorch implementation.
|
||||
|
||||
Uses tighter tolerance since both use PyTorch's FP8 conversion
|
||||
(same rounding mode), unlike the vLLM C++ baseline which uses
|
||||
NVIDIA's hardware FP8 conversion with different rounding.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 256])
|
||||
@pytest.mark.parametrize("intermediate_size", [1024, 2048, 4096])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
def test_silu_mul_fp8_vs_pytorch(self, batch_size, intermediate_size, dtype):
|
||||
skip_if_platform_unsupported()
|
||||
|
||||
input_tensor = torch.randn(
|
||||
batch_size, 2 * intermediate_size, dtype=dtype, device="cuda"
|
||||
)
|
||||
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
|
||||
|
||||
pytorch_output = silu_mul_fp8_pytorch(input_tensor, scale)
|
||||
helion_output = silu_mul_fp8(input_tensor, scale)
|
||||
|
||||
assert helion_output.shape == pytorch_output.shape
|
||||
assert helion_output.dtype == torch.float8_e4m3fn
|
||||
|
||||
pytorch_f32 = pytorch_output.to(torch.float32)
|
||||
helion_f32 = helion_output.to(torch.float32)
|
||||
|
||||
# Tolerance accounts for FP8 quantization boundary effects
|
||||
torch.testing.assert_close(
|
||||
helion_f32,
|
||||
pytorch_f32,
|
||||
atol=0.05,
|
||||
rtol=0.05,
|
||||
msg=(
|
||||
f"Mismatch at batch={batch_size}, size={intermediate_size}, "
|
||||
f"dtype={dtype}"
|
||||
),
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"shape",
|
||||
[
|
||||
(1, 2, 4096), # 3D input
|
||||
(2, 4, 2048), # 3D input
|
||||
(1, 1, 1, 8192), # 4D input
|
||||
],
|
||||
)
|
||||
def test_silu_mul_fp8_multidim_vs_pytorch(self, shape):
|
||||
skip_if_platform_unsupported()
|
||||
|
||||
input_tensor = torch.randn(*shape, dtype=torch.bfloat16, device="cuda")
|
||||
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
|
||||
|
||||
pytorch_output = silu_mul_fp8_pytorch(input_tensor, scale)
|
||||
helion_output = silu_mul_fp8(input_tensor, scale)
|
||||
|
||||
assert helion_output.shape == pytorch_output.shape
|
||||
|
||||
pytorch_f32 = pytorch_output.to(torch.float32)
|
||||
helion_f32 = helion_output.to(torch.float32)
|
||||
|
||||
torch.testing.assert_close(
|
||||
helion_f32,
|
||||
pytorch_f32,
|
||||
atol=0.05,
|
||||
rtol=0.05,
|
||||
msg=f"Mismatch for shape={shape}",
|
||||
)
|
||||
|
||||
|
||||
class TestSiluMulFp8Integration:
|
||||
def test_kernel_registration_integration(self):
|
||||
from vllm.kernels.helion.register import get_registered_kernels
|
||||
|
||||
registered_kernels = get_registered_kernels()
|
||||
assert "silu_mul_fp8" in registered_kernels
|
||||
|
||||
kernel_wrapper = registered_kernels["silu_mul_fp8"]
|
||||
assert kernel_wrapper.op_name == "silu_mul_fp8"
|
||||
assert kernel_wrapper._config_picker is not None
|
||||
|
||||
def test_fake_impl_functionality(self):
|
||||
skip_if_platform_unsupported()
|
||||
from vllm.kernels.helion.register import get_registered_kernels
|
||||
|
||||
input_tensor = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda")
|
||||
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
|
||||
registered_kernels = get_registered_kernels()
|
||||
kernel_wrapper = registered_kernels["silu_mul_fp8"]
|
||||
fake_impl = kernel_wrapper._fake_impl
|
||||
|
||||
fake_output = fake_impl(input_tensor, scale)
|
||||
|
||||
expected_shape = (32, 2048)
|
||||
assert fake_output.shape == expected_shape
|
||||
assert fake_output.dtype == torch.float8_e4m3fn
|
||||
assert fake_output.device == input_tensor.device
|
||||
Reference in New Issue
Block a user