396 lines
14 KiB
Python
396 lines
14 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import pytest
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from vllm.utils.import_utils import has_helion
|
|
|
|
# Helion is an optional extra. Without it nothing below can run, so the whole
# module is reported as skipped instead of failing at import/collection time.
if not has_helion():
    _missing_helion_msg = (
        "Helion is not installed. Install with: pip install vllm[helion]"
    )
    pytest.skip(_missing_helion_msg, allow_module_level=True)
|
|
|
|
from vllm.kernels.helion.config_manager import ConfigManager
|
|
from vllm.kernels.helion.ops.silu_mul_fp8 import (
|
|
pick_silu_mul_fp8_config,
|
|
silu_mul_fp8,
|
|
silu_mul_fp8_baseline,
|
|
)
|
|
|
|
|
|
def skip_if_platform_unsupported():
    """Skip the calling test unless the silu_mul_fp8 Helion kernel can run here.

    The test is skipped when:
      * CUDA is not available;
      * probing platform support raises ImportError, RuntimeError or KeyError
        (the exception text is included in the skip reason); or
      * ConfigManager returns no ``silu_mul_fp8`` configs for the current
        canonical GPU name.
    """
    # Cheap, non-raising check first; it does not belong inside the try below.
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    try:
        from vllm.kernels.helion.utils import get_canonical_gpu_name

        platform = get_canonical_gpu_name()

        try:
            config_manager = ConfigManager.get_instance()
        except RuntimeError:
            # get_instance() raised, presumably because no singleton exists
            # yet; fall back to constructing one.
            config_manager = ConfigManager()

        configs = config_manager.get_platform_configs("silu_mul_fp8", platform)
    except (ImportError, RuntimeError, KeyError) as exc:
        # Surface the underlying error so a skip is diagnosable from the log.
        pytest.skip(
            f"Error detecting platform support for silu_mul_fp8 kernel: {exc}"
        )

    # `not configs` covers both an empty collection and a None result.
    if not configs:
        pytest.skip("Current GPU platform not supported for silu_mul_fp8 kernel")
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def reset_config_manager_singleton():
    """Isolate every test in this module from ConfigManager singleton state.

    Setup clears any existing singleton and then constructs a fresh
    ConfigManager (presumably registering it as the new singleton — confirm
    against ConfigManager). Teardown clears the singleton again so nothing
    leaks into the next test.
    """
    # --- setup ---
    ConfigManager.reset_instance()
    ConfigManager()

    yield  # the test body runs here

    # --- teardown ---
    ConfigManager.reset_instance()
|
|
|
|
|
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
class TestSiluMulFp8ConfigPicker:
    """Unit tests for ``pick_silu_mul_fp8_config``.

    Config keys have the form ``intermediate_<I>_numtokens_<T>`` (or
    ``"default"``). The picker is called with ``(input, scale)``, where
    ``input`` has shape ``(num_tokens, 2 * intermediate_size)``.

    The tensors are allocated on CUDA, so the whole class is skipped on hosts
    without a GPU instead of erroring during tensor creation.
    """

    @staticmethod
    def _make_args(
        num_tokens: int, input_size: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(input, scale)`` tuple handed to the picker."""
        input_tensor = torch.randn(
            num_tokens, input_size, dtype=torch.bfloat16, device="cuda"
        )
        scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
        return input_tensor, scale

    def test_config_picker_exact_match(self):
        """intermediate_size 2048 matches exactly; 32 tokens rounds up to 256."""
        config_keys = [
            "intermediate_2048_numtokens_256",
            "intermediate_4096_numtokens_256",
        ]
        args = self._make_args(32, 4096)

        selected_key = pick_silu_mul_fp8_config(args, config_keys)
        assert selected_key == "intermediate_2048_numtokens_256"

    def test_config_picker_closest_match(self):
        """Without an exact match, the nearest intermediate_size is chosen."""
        config_keys = [
            "intermediate_2048_numtokens_256",
            "intermediate_4096_numtokens_256",
        ]
        # Use 7000 (intermediate_size=3500) which is closer to 4096 than 2048
        args = self._make_args(32, 7000)

        selected_key = pick_silu_mul_fp8_config(args, config_keys)
        assert selected_key == "intermediate_4096_numtokens_256"

    def test_config_picker_fallback_to_default(self):
        """With only a 'default' key available, 'default' is returned."""
        config_keys = ["default"]
        args = self._make_args(32, 4096)

        selected_key = pick_silu_mul_fp8_config(args, config_keys)
        assert selected_key == "default"

    def test_config_picker_no_configs(self):
        """An empty key list yields None."""
        config_keys: list[str] = []
        args = self._make_args(32, 4096)

        selected_key = pick_silu_mul_fp8_config(args, config_keys)
        assert selected_key is None

    @pytest.mark.parametrize("intermediate_size", [2048, 4096, 5120])
    def test_config_picker_different_sizes(self, intermediate_size):
        """Each available intermediate_size is selected for its own input."""
        config_keys = [
            "intermediate_2048_numtokens_256",
            "intermediate_4096_numtokens_256",
            "intermediate_5120_numtokens_256",
        ]
        args = self._make_args(32, 2 * intermediate_size)

        selected_key = pick_silu_mul_fp8_config(args, config_keys)
        expected_key = f"intermediate_{intermediate_size}_numtokens_256"
        assert selected_key == expected_key

    def test_config_picker_numtokens_ceiling(self):
        """Pick the smallest numtokens >= input num_tokens."""
        config_keys = [
            "intermediate_4096_numtokens_8",
            "intermediate_4096_numtokens_32",
            "intermediate_4096_numtokens_128",
            "intermediate_4096_numtokens_256",
        ]
        # 20 tokens -> should pick numtokens_32 (smallest >= 20)
        args = self._make_args(20, 8192)

        selected_key = pick_silu_mul_fp8_config(args, config_keys)
        assert selected_key == "intermediate_4096_numtokens_32"

    def test_config_picker_numtokens_exact(self):
        """When num_tokens equals an available bucket, that bucket is chosen."""
        config_keys = [
            "intermediate_4096_numtokens_8",
            "intermediate_4096_numtokens_32",
            "intermediate_4096_numtokens_128",
        ]
        args = self._make_args(32, 8192)

        selected_key = pick_silu_mul_fp8_config(args, config_keys)
        assert selected_key == "intermediate_4096_numtokens_32"

    def test_config_picker_numtokens_fallback_to_largest(self):
        """Fall back to the largest numtokens when input exceeds all."""
        config_keys = [
            "intermediate_4096_numtokens_8",
            "intermediate_4096_numtokens_32",
            "intermediate_4096_numtokens_128",
        ]
        # 512 tokens -> exceeds all available, should pick largest (128)
        args = self._make_args(512, 8192)

        selected_key = pick_silu_mul_fp8_config(args, config_keys)
        assert selected_key == "intermediate_4096_numtokens_128"

    def test_config_picker_malformed_key_raises(self):
        """Malformed config keys should raise ValueError."""
        config_keys = ["intermediate_4096_badformat_256"]
        args = self._make_args(32, 8192)

        with pytest.raises(ValueError, match="Malformed config key"):
            pick_silu_mul_fp8_config(args, config_keys)

    def test_config_picker_default_ignored_when_valid_keys_exist(self):
        """'default' is skipped in favor of a real match."""
        config_keys = [
            "default",
            "intermediate_4096_numtokens_32",
            "intermediate_4096_numtokens_128",
        ]
        args = self._make_args(64, 8192)

        selected_key = pick_silu_mul_fp8_config(args, config_keys)
        assert selected_key == "intermediate_4096_numtokens_128"
|
|
|
|
|
|
class TestSiluMulFp8Correctness:
    """Compare the Helion ``silu_mul_fp8`` kernel with ``silu_mul_fp8_baseline``.

    Inputs have shape ``(num_tokens, 2 * intermediate_size)`` and outputs
    ``(num_tokens, intermediate_size)`` in ``torch.float8_e4m3fn``. Both
    outputs are upcast to float32 before the tolerance comparison.
    """

    # FP8 E4M3 has limited precision. Values near quantization boundaries
    # can round differently due to intermediate precision differences.
    # NOTE(review): E4M3 keeps 3 mantissa bits, so neighbouring values in
    # [1, 2) are 0.125 apart, while the allowed error at 1.0 is
    # ATOL + RTOL * 1.0 = 0.1. A single rounding step at |x| >= 1 therefore
    # fails these tolerances. Confirm near-exact agreement is really intended.
    ATOL = 0.05
    RTOL = 0.05

    @pytest.fixture(autouse=True)
    def _seed_rng(self):
        """Seed torch's RNG so random inputs (and any failures) are reproducible."""
        torch.manual_seed(0)

    def _assert_fp8_close(self, actual, expected, msg):
        """Upcast both FP8 tensors to float32 and compare within tolerance."""
        torch.testing.assert_close(
            actual.to(torch.float32),
            expected.to(torch.float32),
            atol=self.ATOL,
            rtol=self.RTOL,
            msg=msg,
        )

    @pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
    @pytest.mark.parametrize("intermediate_size", [2048, 3000, 3500, 4096, 5000])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_silu_mul_fp8_correctness(self, batch_size, intermediate_size, dtype):
        skip_if_platform_unsupported()

        input_size = 2 * intermediate_size
        input_tensor = torch.randn(batch_size, input_size, dtype=dtype, device="cuda")
        scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")

        reference_output = silu_mul_fp8_baseline(input_tensor, scale)
        helion_output = silu_mul_fp8(input_tensor, scale)

        assert helion_output.shape == reference_output.shape
        assert helion_output.dtype == torch.float8_e4m3fn
        assert reference_output.dtype == torch.float8_e4m3fn

        self._assert_fp8_close(
            helion_output,
            reference_output,
            msg=f"Mismatch at batch={batch_size}, size={intermediate_size}",
        )

    def test_silu_mul_fp8_shape_inference(self):
        skip_if_platform_unsupported()
        batch_size, input_size = 32, 8192
        intermediate_size = input_size // 2

        input_tensor = torch.randn(
            batch_size, input_size, dtype=torch.bfloat16, device="cuda"
        )
        scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")

        output = silu_mul_fp8(input_tensor, scale)

        expected_shape = (batch_size, intermediate_size)
        assert output.shape == expected_shape
        assert output.dtype == torch.float8_e4m3fn

    # Parametrized (rather than looping inside one test) so every scale is
    # reported independently and one failure does not hide the others.
    @pytest.mark.parametrize("scale_val", [0.1, 0.5, 1.0, 2.0, 10.0])
    def test_silu_mul_fp8_scale_variations(self, scale_val):
        skip_if_platform_unsupported()
        batch_size, input_size = 16, 4096

        input_tensor = torch.randn(
            batch_size, input_size, dtype=torch.bfloat16, device="cuda"
        )
        scale = torch.tensor([scale_val], dtype=torch.float32, device="cuda")

        reference_output = silu_mul_fp8_baseline(input_tensor, scale)
        helion_output = silu_mul_fp8(input_tensor, scale)

        self._assert_fp8_close(
            helion_output,
            reference_output,
            msg=f"Mismatch for scale={scale_val}",
        )

    @pytest.mark.parametrize(
        "shape",
        [
            (1, 4096),
            (16, 4096),
            (128, 4096),
            (1024, 4096),
            (1, 8192),
            (16, 8192),
            (128, 8192),
        ],
    )
    def test_silu_mul_fp8_various_shapes(self, shape):
        skip_if_platform_unsupported()

        input_tensor = torch.randn(*shape, dtype=torch.bfloat16, device="cuda")
        scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")

        reference_output = silu_mul_fp8_baseline(input_tensor, scale)
        helion_output = silu_mul_fp8(input_tensor, scale)

        assert helion_output.shape == reference_output.shape

        self._assert_fp8_close(
            helion_output, reference_output, msg=f"Mismatch for shape={shape}"
        )
|
|
|
|
|
|
def silu_mul_fp8_pytorch(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
|
"""Pure PyTorch reference using F.silu.
|
|
|
|
This matches vLLM's SiluAndMul.forward_native exactly:
|
|
F.silu(x[..., :d]) * x[..., d:]
|
|
"""
|
|
d = input.shape[-1] // 2
|
|
result = F.silu(input[..., :d]) * input[..., d:]
|
|
return (result.to(torch.float32) / scale).to(torch.float8_e4m3fn)
|
|
|
|
|
|
class TestSiluMulFp8PytorchReference:
    """Compare the Helion kernel against the pure-PyTorch ``silu_mul_fp8_pytorch``.

    Unlike ``silu_mul_fp8_baseline`` (the vLLM C++ path, whose hardware FP8
    conversion may round differently), this reference performs the FP8 cast
    with PyTorch. Both comparisons use the same tolerances,
    atol = rtol = 0.05.
    """

    # NOTE(review): E4M3 keeps 3 mantissa bits, so one rounding step at
    # |x| >= 1 (>= 0.125) exceeds ATOL + RTOL * |x|. These tolerances
    # effectively demand near-exact agreement for larger values; confirm.
    ATOL = 0.05
    RTOL = 0.05

    @pytest.fixture(autouse=True)
    def _seed_rng(self):
        """Seed torch's RNG so random inputs (and any failures) are reproducible."""
        torch.manual_seed(0)

    def _assert_fp8_close(self, actual, expected, msg):
        """Upcast both FP8 tensors to float32 and compare within tolerance."""
        torch.testing.assert_close(
            actual.to(torch.float32),
            expected.to(torch.float32),
            atol=self.ATOL,
            rtol=self.RTOL,
            msg=msg,
        )

    @pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 256])
    @pytest.mark.parametrize("intermediate_size", [1024, 2048, 4096])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_silu_mul_fp8_vs_pytorch(self, batch_size, intermediate_size, dtype):
        skip_if_platform_unsupported()

        input_tensor = torch.randn(
            batch_size, 2 * intermediate_size, dtype=dtype, device="cuda"
        )
        scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")

        pytorch_output = silu_mul_fp8_pytorch(input_tensor, scale)
        helion_output = silu_mul_fp8(input_tensor, scale)

        assert helion_output.shape == pytorch_output.shape
        assert helion_output.dtype == torch.float8_e4m3fn

        self._assert_fp8_close(
            helion_output,
            pytorch_output,
            msg=(
                f"Mismatch at batch={batch_size}, size={intermediate_size}, "
                f"dtype={dtype}"
            ),
        )

    @pytest.mark.parametrize(
        "shape",
        [
            (1, 2, 4096),  # 3D input
            (2, 4, 2048),  # 3D input
            (1, 1, 1, 8192),  # 4D input
        ],
    )
    def test_silu_mul_fp8_multidim_vs_pytorch(self, shape):
        skip_if_platform_unsupported()

        input_tensor = torch.randn(*shape, dtype=torch.bfloat16, device="cuda")
        scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")

        pytorch_output = silu_mul_fp8_pytorch(input_tensor, scale)
        helion_output = silu_mul_fp8(input_tensor, scale)

        assert helion_output.shape == pytorch_output.shape
        assert helion_output.dtype == torch.float8_e4m3fn

        self._assert_fp8_close(
            helion_output, pytorch_output, msg=f"Mismatch for shape={shape}"
        )
|
|
|
|
|
|
class TestSiluMulFp8Integration:
    """Check that ``silu_mul_fp8`` is wired into the Helion kernel registry."""

    def test_kernel_registration_integration(self):
        """The op is registered under its own name and carries a config picker."""
        from vllm.kernels.helion.register import get_registered_kernels

        kernels = get_registered_kernels()
        assert "silu_mul_fp8" in kernels

        wrapper = kernels["silu_mul_fp8"]
        assert wrapper.op_name == "silu_mul_fp8"
        assert wrapper._config_picker is not None

    def test_fake_impl_functionality(self):
        """The registered fake impl reports the expected shape, dtype and device."""
        skip_if_platform_unsupported()
        from vllm.kernels.helion.register import get_registered_kernels

        num_tokens, input_size = 32, 4096
        input_tensor = torch.randn(
            num_tokens, input_size, dtype=torch.bfloat16, device="cuda"
        )
        scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")

        fake_impl = get_registered_kernels()["silu_mul_fp8"]._fake_impl
        fake_output = fake_impl(input_tensor, scale)

        # Leading dim is preserved; the last dim is halved: (32, 2048).
        assert fake_output.shape == (num_tokens, input_size // 2)
        assert fake_output.dtype == torch.float8_e4m3fn
        assert fake_output.device == input_tensor.device