[Kernel] [Helion] [4/N] Add silu_mul_fp8 Helion kernel (#33373)

Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
This commit is contained in:
Yanan Cao
2026-02-12 18:13:12 -08:00
committed by GitHub
parent 4453ba8d9e
commit 96161fe978
7 changed files with 1002 additions and 4 deletions

View File

@@ -554,10 +554,18 @@ class TestKernelRegistry:
"""Test suite for kernel registry functionality."""
def setup_method(self):
    """Save and clear the registry before each test."""
    # Snapshot the current contents so teardown_method can restore any
    # kernels registered by earlier imports or other test modules.
    from vllm.kernels.helion.register import _REGISTERED_KERNELS

    self._saved_registry = dict(_REGISTERED_KERNELS)
    _REGISTERED_KERNELS.clear()
def teardown_method(self):
    """Restore the registry after each test."""
    # Clear whatever the test registered, then put back the snapshot
    # taken in setup_method.
    from vllm.kernels.helion.register import _REGISTERED_KERNELS

    _REGISTERED_KERNELS.clear()
    _REGISTERED_KERNELS.update(self._saved_registry)
def test_get_registered_kernels_returns_copy(self):
"""Test get_registered_kernels returns copy of registry."""

View File

@@ -0,0 +1,331 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
import torch.nn.functional as F
from vllm.utils.import_utils import has_helion
# Skip the entire module at collection time when the optional helion
# dependency is not installed.
if not has_helion():
    pytest.skip(
        "Helion is not installed. Install with: pip install vllm[helion]",
        allow_module_level=True,
    )
from vllm.kernels.helion.config_manager import ConfigManager
from vllm.kernels.helion.ops.silu_mul_fp8 import (
pick_silu_mul_fp8_config,
silu_mul_fp8,
silu_mul_fp8_baseline,
)
def skip_if_platform_unsupported():
    """Skip the current test when no silu_mul_fp8 config exists for this GPU.

    Any error raised while probing the platform also results in a skip
    rather than a failure, so unsupported environments never hard-fail.
    """
    try:
        from vllm.kernels.helion.utils import get_canonical_gpu_name

        if not torch.cuda.is_available():
            pytest.skip("CUDA not available")
        gpu_name = get_canonical_gpu_name()
        try:
            manager = ConfigManager.get_instance()
        except RuntimeError:
            # No singleton has been created yet; fall back to a fresh one.
            manager = ConfigManager()
        if not manager.get_platform_configs("silu_mul_fp8", gpu_name):
            pytest.skip("Current GPU platform not supported for silu_mul_fp8 kernel")
    except (ImportError, RuntimeError, KeyError):
        pytest.skip("Error detecting platform support for silu_mul_fp8 kernel")
@pytest.fixture(autouse=True)
def reset_config_manager_singleton():
    """Give every test in this module a fresh ConfigManager singleton."""
    ConfigManager.reset_instance()
    ConfigManager()
    yield
    # Teardown: drop the instance created above so state never leaks
    # between tests.
    ConfigManager.reset_instance()
class TestSiluMulFp8ConfigPicker:
    """Unit tests for pick_silu_mul_fp8_config key selection."""

    @staticmethod
    def _picker_args(last_dim):
        # Build the (input, scale) tuple the picker inspects. Only
        # input.shape[-1] influences selection; batch size is arbitrary.
        x = torch.randn(32, last_dim, dtype=torch.bfloat16, device="cuda")
        s = torch.tensor([0.5], dtype=torch.float32, device="cuda")
        return (x, s)

    def test_config_picker_exact_match(self):
        keys = [
            "intermediate_2048_batchsize_256",
            "intermediate_4096_batchsize_256",
        ]
        chosen = pick_silu_mul_fp8_config(self._picker_args(4096), keys)
        assert chosen == "intermediate_2048_batchsize_256"

    def test_config_picker_closest_match(self):
        keys = [
            "intermediate_2048_batchsize_256",
            "intermediate_4096_batchsize_256",
        ]
        # Use 7000 (intermediate_size=3500) which is closer to 4096 than 2048
        chosen = pick_silu_mul_fp8_config(self._picker_args(7000), keys)
        assert chosen == "intermediate_4096_batchsize_256"

    def test_config_picker_fallback_to_default(self):
        chosen = pick_silu_mul_fp8_config(
            self._picker_args(4096), ["default", "some_other_key"]
        )
        assert chosen == "default"

    def test_config_picker_no_configs(self):
        no_keys: list[str] = []
        assert pick_silu_mul_fp8_config(self._picker_args(4096), no_keys) is None

    @pytest.mark.parametrize("intermediate_size", [2048, 4096, 5120])
    def test_config_picker_different_sizes(self, intermediate_size):
        keys = [
            "intermediate_2048_batchsize_256",
            "intermediate_4096_batchsize_256",
            "intermediate_5120_batchsize_256",
        ]
        chosen = pick_silu_mul_fp8_config(
            self._picker_args(2 * intermediate_size), keys
        )
        assert chosen == f"intermediate_{intermediate_size}_batchsize_256"
class TestSiluMulFp8Correctness:
    """Compares the Helion kernel against the vLLM CUDA baseline kernel."""

    @staticmethod
    def _assert_fp8_close(actual, expected, msg):
        # FP8 E4M3 has limited precision. Values near quantization boundaries
        # can round differently due to intermediate precision differences,
        # hence the loose tolerance.
        torch.testing.assert_close(
            actual.to(torch.float32),
            expected.to(torch.float32),
            atol=0.05,
            rtol=0.05,
            msg=msg,
        )

    @pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
    @pytest.mark.parametrize("intermediate_size", [2048, 3000, 3500, 4096, 5000])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_silu_mul_fp8_correctness(self, batch_size, intermediate_size, dtype):
        skip_if_platform_unsupported()
        x = torch.randn(
            batch_size, 2 * intermediate_size, dtype=dtype, device="cuda"
        )
        scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
        expected = silu_mul_fp8_baseline(x, scale)
        actual = silu_mul_fp8(x, scale)
        assert actual.shape == expected.shape
        assert actual.dtype == torch.float8_e4m3fn
        assert expected.dtype == torch.float8_e4m3fn
        self._assert_fp8_close(
            actual,
            expected,
            f"Mismatch at batch={batch_size}, size={intermediate_size}",
        )

    def test_silu_mul_fp8_shape_inference(self):
        skip_if_platform_unsupported()
        batch_size, input_size = 32, 8192
        x = torch.randn(batch_size, input_size, dtype=torch.bfloat16, device="cuda")
        scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
        out = silu_mul_fp8(x, scale)
        # Output halves the last dimension and is always FP8 e4m3fn.
        assert out.shape == (batch_size, input_size // 2)
        assert out.dtype == torch.float8_e4m3fn

    def test_silu_mul_fp8_scale_variations(self):
        skip_if_platform_unsupported()
        x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
        for scale_val in (0.1, 0.5, 1.0, 2.0, 10.0):
            scale = torch.tensor([scale_val], dtype=torch.float32, device="cuda")
            self._assert_fp8_close(
                silu_mul_fp8(x, scale),
                silu_mul_fp8_baseline(x, scale),
                f"Mismatch for scale={scale_val}",
            )

    @pytest.mark.parametrize(
        "shape",
        [
            (1, 4096),
            (16, 4096),
            (128, 4096),
            (1024, 4096),
            (1, 8192),
            (16, 8192),
            (128, 8192),
        ],
    )
    def test_silu_mul_fp8_various_shapes(self, shape):
        skip_if_platform_unsupported()
        x = torch.randn(*shape, dtype=torch.bfloat16, device="cuda")
        scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
        expected = silu_mul_fp8_baseline(x, scale)
        actual = silu_mul_fp8(x, scale)
        assert actual.shape == expected.shape
        self._assert_fp8_close(actual, expected, f"Mismatch for shape={shape}")
def silu_mul_fp8_pytorch(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
"""Pure PyTorch reference using F.silu.
This matches vLLM's SiluAndMul.forward_native exactly:
F.silu(x[..., :d]) * x[..., d:]
"""
d = input.shape[-1] // 2
result = F.silu(input[..., :d]) * input[..., d:]
return (result.to(torch.float32) / scale).to(torch.float8_e4m3fn)
class TestSiluMulFp8PytorchReference:
    """Tests comparing Helion kernel against pure PyTorch implementation.

    Uses tighter tolerance since both use PyTorch's FP8 conversion
    (same rounding mode), unlike the vLLM C++ baseline which uses
    NVIDIA's hardware FP8 conversion with different rounding.
    """

    @pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 256])
    @pytest.mark.parametrize("intermediate_size", [1024, 2048, 4096])
    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_silu_mul_fp8_vs_pytorch(self, batch_size, intermediate_size, dtype):
        skip_if_platform_unsupported()
        x = torch.randn(
            batch_size, 2 * intermediate_size, dtype=dtype, device="cuda"
        )
        scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
        expected = silu_mul_fp8_pytorch(x, scale)
        actual = silu_mul_fp8(x, scale)
        assert actual.shape == expected.shape
        assert actual.dtype == torch.float8_e4m3fn
        # Tolerance accounts for FP8 quantization boundary effects
        torch.testing.assert_close(
            actual.to(torch.float32),
            expected.to(torch.float32),
            atol=0.05,
            rtol=0.05,
            msg=(
                f"Mismatch at batch={batch_size}, size={intermediate_size}, "
                f"dtype={dtype}"
            ),
        )

    @pytest.mark.parametrize(
        "shape",
        [
            (1, 2, 4096),  # 3D input
            (2, 4, 2048),  # 3D input
            (1, 1, 1, 8192),  # 4D input
        ],
    )
    def test_silu_mul_fp8_multidim_vs_pytorch(self, shape):
        skip_if_platform_unsupported()
        x = torch.randn(*shape, dtype=torch.bfloat16, device="cuda")
        scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
        expected = silu_mul_fp8_pytorch(x, scale)
        actual = silu_mul_fp8(x, scale)
        assert actual.shape == expected.shape
        torch.testing.assert_close(
            actual.to(torch.float32),
            expected.to(torch.float32),
            atol=0.05,
            rtol=0.05,
            msg=f"Mismatch for shape={shape}",
        )
class TestSiluMulFp8Integration:
    """End-to-end checks of kernel registration and the fake (meta) impl."""

    def test_kernel_registration_integration(self):
        from vllm.kernels.helion.register import get_registered_kernels

        kernels = get_registered_kernels()
        assert "silu_mul_fp8" in kernels
        wrapper = kernels["silu_mul_fp8"]
        assert wrapper.op_name == "silu_mul_fp8"
        assert wrapper._config_picker is not None

    def test_fake_impl_functionality(self):
        skip_if_platform_unsupported()
        from vllm.kernels.helion.register import get_registered_kernels

        x = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda")
        scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
        wrapper = get_registered_kernels()["silu_mul_fp8"]
        fake_out = wrapper._fake_impl(x, scale)
        # Fake impl must mirror the real kernel's metadata: half the last
        # dim, FP8 e4m3fn dtype, same device as the input.
        assert fake_out.shape == (32, 2048)
        assert fake_out.dtype == torch.float8_e4m3fn
        assert fake_out.device == x.device

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Helion integration for vLLM."""
import vllm.kernels.helion.ops # noqa: F401 Auto-register all Helion ops
from vllm.kernels.helion.config_manager import (
ConfigManager,
ConfigSet,

View File

@@ -104,9 +104,6 @@ class ConfigSet:
result[platform] = {}
for config_key, config in config_keys_dict.items():
# Convert helion.Config to dict using to_json() + json.loads()
import json
result[platform][config_key] = json.loads(config.to_json())
return result

View File

@@ -0,0 +1,550 @@
{
"nvidia_h200": {
"intermediate_2048_batchsize_256": {
"block_sizes": [
64,
128
],
"loop_orders": [
[
0,
1
]
],
"flatten_loops": [
true
],
"l2_groupings": [
2
],
"range_unroll_factors": [
0
],
"range_num_stages": [
0
],
"range_multi_buffers": [
null
],
"range_flattens": [
null
],
"load_eviction_policies": [
"",
"",
""
],
"num_warps": 32,
"num_stages": 1,
"indexing": [
"pointer",
"tensor_descriptor",
"pointer",
"pointer"
],
"pid_type": "flat",
"range_warp_specializes": []
},
"intermediate_4096_batchsize_256": {
"block_sizes": [
16,
64
],
"loop_orders": [
[
0,
1
]
],
"flatten_loops": [
false
],
"l2_groupings": [
1
],
"range_unroll_factors": [
0
],
"range_num_stages": [
0
],
"range_multi_buffers": [
null
],
"range_flattens": [
null
],
"load_eviction_policies": [
"",
"",
""
],
"num_warps": 2,
"num_stages": 1,
"indexing": [
"pointer",
"pointer",
"pointer",
"pointer"
],
"pid_type": "flat",
"range_warp_specializes": []
},
"default": {
"block_sizes": [
1,
512
],
"loop_orders": [
[
1,
0
]
],
"flatten_loops": [
false
],
"l2_groupings": [
4
],
"range_unroll_factors": [
0
],
"range_num_stages": [
0
],
"range_multi_buffers": [
null
],
"range_flattens": [
null
],
"load_eviction_policies": [
"",
"first",
""
],
"num_warps": 8,
"num_stages": 2,
"indexing": [
"tensor_descriptor",
"tensor_descriptor",
"tensor_descriptor",
"pointer"
],
"pid_type": "flat",
"range_warp_specializes": []
}
},
"nvidia_h100_pcie": {
"intermediate_2048_batchsize_256": {
"block_sizes": [
1,
512
],
"loop_orders": [
[
1,
0
]
],
"flatten_loops": [
false
],
"l2_groupings": [
4
],
"range_unroll_factors": [
0
],
"range_num_stages": [
0
],
"range_multi_buffers": [
null
],
"range_flattens": [
null
],
"load_eviction_policies": [
"",
"first",
""
],
"num_warps": 8,
"num_stages": 2,
"indexing": [
"tensor_descriptor",
"tensor_descriptor",
"tensor_descriptor",
"pointer"
],
"pid_type": "flat",
"range_warp_specializes": []
},
"intermediate_4096_batchsize_256": {
"block_sizes": [
256,
128
],
"loop_orders": [
[
0,
1
]
],
"flatten_loops": [
true
],
"l2_groupings": [
1
],
"range_unroll_factors": [
2
],
"range_num_stages": [
3
],
"range_multi_buffers": [
false
],
"range_flattens": [
true
],
"load_eviction_policies": [
"last",
"last",
""
],
"num_warps": 32,
"num_stages": 3,
"indexing": [
"tensor_descriptor",
"tensor_descriptor",
"tensor_descriptor",
"pointer"
],
"pid_type": "persistent_blocked",
"range_warp_specializes": []
},
"default": {
"block_sizes": [
1,
512
],
"loop_orders": [
[
1,
0
]
],
"flatten_loops": [
false
],
"l2_groupings": [
4
],
"range_unroll_factors": [
0
],
"range_num_stages": [
0
],
"range_multi_buffers": [
null
],
"range_flattens": [
null
],
"load_eviction_policies": [
"",
"first",
""
],
"num_warps": 8,
"num_stages": 2,
"indexing": [
"tensor_descriptor",
"tensor_descriptor",
"tensor_descriptor",
"pointer"
],
"pid_type": "flat",
"range_warp_specializes": []
}
},
"nvidia_h100_sxm5": {
"intermediate_2048_batchsize_256": {
"block_sizes": [
1,
512
],
"loop_orders": [
[
1,
0
]
],
"flatten_loops": [
false
],
"l2_groupings": [
4
],
"range_unroll_factors": [
0
],
"range_num_stages": [
0
],
"range_multi_buffers": [
null
],
"range_flattens": [
null
],
"load_eviction_policies": [
"",
"first",
""
],
"num_warps": 8,
"num_stages": 2,
"indexing": [
"tensor_descriptor",
"tensor_descriptor",
"tensor_descriptor",
"pointer"
],
"pid_type": "flat",
"range_warp_specializes": []
},
"intermediate_4096_batchsize_256": {
"block_sizes": [
256,
128
],
"loop_orders": [
[
0,
1
]
],
"flatten_loops": [
true
],
"l2_groupings": [
1
],
"range_unroll_factors": [
2
],
"range_num_stages": [
3
],
"range_multi_buffers": [
false
],
"range_flattens": [
true
],
"load_eviction_policies": [
"last",
"last",
""
],
"num_warps": 32,
"num_stages": 3,
"indexing": [
"tensor_descriptor",
"tensor_descriptor",
"tensor_descriptor",
"pointer"
],
"pid_type": "persistent_blocked",
"range_warp_specializes": []
},
"default": {
"block_sizes": [
1,
512
],
"loop_orders": [
[
1,
0
]
],
"flatten_loops": [
false
],
"l2_groupings": [
4
],
"range_unroll_factors": [
0
],
"range_num_stages": [
0
],
"range_multi_buffers": [
null
],
"range_flattens": [
null
],
"load_eviction_policies": [
"",
"first",
""
],
"num_warps": 8,
"num_stages": 2,
"indexing": [
"tensor_descriptor",
"tensor_descriptor",
"tensor_descriptor",
"pointer"
],
"pid_type": "flat",
"range_warp_specializes": []
}
},
"nvidia_h100": {
"intermediate_2048_batchsize_256": {
"block_sizes": [
1,
512
],
"loop_orders": [
[
1,
0
]
],
"flatten_loops": [
false
],
"l2_groupings": [
4
],
"range_unroll_factors": [
0
],
"range_num_stages": [
0
],
"range_multi_buffers": [
null
],
"range_flattens": [
null
],
"load_eviction_policies": [
"",
"first",
""
],
"num_warps": 8,
"num_stages": 2,
"indexing": [
"tensor_descriptor",
"tensor_descriptor",
"tensor_descriptor",
"pointer"
],
"pid_type": "flat",
"range_warp_specializes": []
},
"intermediate_4096_batchsize_256": {
"block_sizes": [
256,
128
],
"loop_orders": [
[
0,
1
]
],
"flatten_loops": [
true
],
"l2_groupings": [
1
],
"range_unroll_factors": [
2
],
"range_num_stages": [
3
],
"range_multi_buffers": [
false
],
"range_flattens": [
true
],
"load_eviction_policies": [
"last",
"last",
""
],
"num_warps": 32,
"num_stages": 3,
"indexing": [
"tensor_descriptor",
"tensor_descriptor",
"tensor_descriptor",
"pointer"
],
"pid_type": "persistent_blocked",
"range_warp_specializes": []
},
"default": {
"block_sizes": [
1,
512
],
"loop_orders": [
[
1,
0
]
],
"flatten_loops": [
false
],
"l2_groupings": [
4
],
"range_unroll_factors": [
0
],
"range_num_stages": [
0
],
"range_multi_buffers": [
null
],
"range_flattens": [
null
],
"load_eviction_policies": [
"",
"first",
""
],
"num_warps": 8,
"num_stages": 2,
"indexing": [
"tensor_descriptor",
"tensor_descriptor",
"tensor_descriptor",
"pointer"
],
"pid_type": "flat",
"range_warp_specializes": []
}
}
}

View File

@@ -0,0 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Auto-import all Helion op modules to trigger kernel registration."""
import importlib
import pkgutil
# Automatically import all submodules so that @register_kernel
# decorators execute and register ops with torch.ops.vllm_helion.
for _submodule_name in (info.name for info in pkgutil.iter_modules(__path__)):
    importlib.import_module(f"{__name__}.{_submodule_name}")

View File

@@ -0,0 +1,100 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
import torch
from vllm.logger import init_logger
from vllm.utils.import_utils import has_helion
# Fail fast at import time: this module is only usable when the optional
# helion dependency is present.
if not has_helion():
    raise ImportError(
        "silu_mul_fp8 Helion kernel requires helion to be installed. "
        "Install it with: pip install helion"
    )
import helion.language as hl
from vllm.kernels.helion.register import register_kernel
logger = init_logger(__name__)
@register_kernel  # type: ignore[misc]
def silu_mul_fp8(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Fused SiLU-and-mul with FP8 (e4m3fn) output quantization.

    Splits the last dimension in half, computes
    ``silu(input[..., :d]) * input[..., d:]``, divides by ``scale`` and
    casts the result to ``torch.float8_e4m3fn``.

    Args:
        input: Tensor whose last dimension is ``2 * d``. Leading
            dimensions are flattened for the kernel and restored on return.
        scale: Single-element float32 tensor; the fused result is divided
            by it before quantization.

    Returns:
        Tensor of shape ``input.shape[:-1] + (d,)`` with dtype
        ``torch.float8_e4m3fn``.
    """
    original_shape = input.shape
    # Specialize the kernel on the last-dim size so Helion compiles a
    # shape-specific variant.
    two_d = hl.specialize(original_shape[-1])
    d = two_d // 2
    output_shape = original_shape[:-1] + (d,)
    # Collapse leading dims: the kernel tiles a 2D (m, d) problem.
    input_2d = input.view(-1, original_shape[-1])
    m = input_2d.shape[0]
    # TODO(gmagogsfm): Support for more float8 subtypes (e4m3fnuz, e5m2) coming
    out = torch.empty((m, d), device=input.device, dtype=torch.float8_e4m3fn)
    input_part_a = input_2d[:, :d]  # gate half (passed through SiLU)
    input_part_b = input_2d[:, d:]  # up half (elementwise multiplier)
    assert scale.numel() == 1, "Scale must be a scalar Tensor"
    for tile_m, tile_n in hl.tile([m, d]):
        a_vals = input_part_a[tile_m, tile_n]
        silu_result = torch.nn.functional.silu(a_vals)
        b_vals = input_part_b[tile_m, tile_n]
        result = silu_result * b_vals
        # Do the quantization math in fp32 before the FP8 downcast.
        result_f32 = result.to(torch.float32)
        scale_val = hl.load(scale, [0])
        # Multiply by the reciprocal instead of dividing per element.
        inv_scale = 1.0 / scale_val
        result_scaled = result_f32 * inv_scale
        out[tile_m, tile_n] = result_scaled.to(out.dtype)
    return out.view(output_shape)
@silu_mul_fp8.register_config_picker  # type: ignore[misc]
def pick_silu_mul_fp8_config(
    args: tuple[Any, ...], config_keys: list[str]
) -> str | None:
    """Select the tuned-config key best matching the call arguments.

    Preference order: exact intermediate-size match, then the parseable
    "intermediate_<N>_batchsize_256" key whose size is numerically
    closest, then "default", otherwise None.
    """
    if not config_keys:
        return None
    input_tensor, _scale = args
    intermediate_size = input_tensor.shape[-1] // 2
    # TODO(gmagogsfm): Rerun autotuning to capture config for
    # other batch sizes.
    exact_key = f"intermediate_{intermediate_size}_batchsize_256"
    if exact_key in config_keys:
        return exact_key

    def _distance(key: str) -> int | None:
        # Parse "intermediate_<N>_batchsize_256"-style keys; anything
        # else is not a candidate for closest-match selection.
        if not (key.startswith("intermediate_") and "_batchsize_256" in key):
            return None
        try:
            return abs(int(key.split("_")[1]) - intermediate_size)
        except (ValueError, IndexError):
            return None

    candidates = [
        (dist, key) for key in config_keys if (dist := _distance(key)) is not None
    ]
    if candidates:
        # Ties break lexicographically on the key, matching tuple order.
        _, best_key = min(candidates)
        logger.debug(
            "No exact config for intermediate_size=%d, using closest match: %s",
            intermediate_size,
            best_key,
        )
        return best_key
    return "default" if "default" in config_keys else None
def silu_mul_fp8_baseline(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Reference implementation backed by vLLM's CUDA silu_and_mul_quant op.

    Allocates an FP8 e4m3fn output with the last dimension halved and
    lets the C++ op fill it in place.
    """
    half = input.shape[-1] // 2
    out = torch.empty(
        (*input.shape[:-1], half), dtype=torch.float8_e4m3fn, device=input.device
    )
    torch.ops._C.silu_and_mul_quant(out, input, scale)
    return out