[Kernel] [Helion] [4/N] Add silu_mul_fp8 Helion kernel (#33373)
Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
This commit is contained in:
@@ -554,10 +554,18 @@ class TestKernelRegistry:
|
||||
"""Test suite for kernel registry functionality."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear the registry before each test."""
|
||||
"""Save and clear the registry before each test."""
|
||||
from vllm.kernels.helion.register import _REGISTERED_KERNELS
|
||||
|
||||
self._saved_registry = dict(_REGISTERED_KERNELS)
|
||||
_REGISTERED_KERNELS.clear()
|
||||
|
||||
def teardown_method(self):
|
||||
"""Restore the registry after each test."""
|
||||
from vllm.kernels.helion.register import _REGISTERED_KERNELS
|
||||
|
||||
_REGISTERED_KERNELS.clear()
|
||||
_REGISTERED_KERNELS.update(self._saved_registry)
|
||||
|
||||
def test_get_registered_kernels_returns_copy(self):
|
||||
"""Test get_registered_kernels returns copy of registry."""
|
||||
|
||||
331
tests/kernels/helion/test_silu_mul_fp8.py
Normal file
331
tests/kernels/helion/test_silu_mul_fp8.py
Normal file
@@ -0,0 +1,331 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.utils.import_utils import has_helion
|
||||
|
||||
if not has_helion():
|
||||
pytest.skip(
|
||||
"Helion is not installed. Install with: pip install vllm[helion]",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
from vllm.kernels.helion.config_manager import ConfigManager
|
||||
from vllm.kernels.helion.ops.silu_mul_fp8 import (
|
||||
pick_silu_mul_fp8_config,
|
||||
silu_mul_fp8,
|
||||
silu_mul_fp8_baseline,
|
||||
)
|
||||
|
||||
|
||||
def skip_if_platform_unsupported():
|
||||
try:
|
||||
from vllm.kernels.helion.utils import get_canonical_gpu_name
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA not available")
|
||||
|
||||
platform = get_canonical_gpu_name()
|
||||
|
||||
try:
|
||||
config_manager = ConfigManager.get_instance()
|
||||
except RuntimeError:
|
||||
config_manager = ConfigManager()
|
||||
|
||||
configs = config_manager.get_platform_configs("silu_mul_fp8", platform)
|
||||
if len(configs) == 0:
|
||||
pytest.skip("Current GPU platform not supported for silu_mul_fp8 kernel")
|
||||
|
||||
except (ImportError, RuntimeError, KeyError):
|
||||
pytest.skip("Error detecting platform support for silu_mul_fp8 kernel")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_config_manager_singleton():
|
||||
ConfigManager.reset_instance()
|
||||
ConfigManager()
|
||||
yield
|
||||
ConfigManager.reset_instance()
|
||||
|
||||
|
||||
class TestSiluMulFp8ConfigPicker:
|
||||
def test_config_picker_exact_match(self):
|
||||
config_keys = [
|
||||
"intermediate_2048_batchsize_256",
|
||||
"intermediate_4096_batchsize_256",
|
||||
]
|
||||
|
||||
input_tensor = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda")
|
||||
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
|
||||
args = (input_tensor, scale)
|
||||
|
||||
selected_key = pick_silu_mul_fp8_config(args, config_keys)
|
||||
assert selected_key == "intermediate_2048_batchsize_256"
|
||||
|
||||
def test_config_picker_closest_match(self):
|
||||
config_keys = [
|
||||
"intermediate_2048_batchsize_256",
|
||||
"intermediate_4096_batchsize_256",
|
||||
]
|
||||
# Use 7000 (intermediate_size=3500) which is closer to 4096 than 2048
|
||||
input_tensor = torch.randn(32, 7000, dtype=torch.bfloat16, device="cuda")
|
||||
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
|
||||
args = (input_tensor, scale)
|
||||
|
||||
selected_key = pick_silu_mul_fp8_config(args, config_keys)
|
||||
assert selected_key == "intermediate_4096_batchsize_256"
|
||||
|
||||
def test_config_picker_fallback_to_default(self):
|
||||
config_keys = ["default", "some_other_key"]
|
||||
|
||||
input_tensor = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda")
|
||||
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
|
||||
args = (input_tensor, scale)
|
||||
|
||||
selected_key = pick_silu_mul_fp8_config(args, config_keys)
|
||||
assert selected_key == "default"
|
||||
|
||||
def test_config_picker_no_configs(self):
|
||||
config_keys: list[str] = []
|
||||
|
||||
input_tensor = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda")
|
||||
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
|
||||
args = (input_tensor, scale)
|
||||
|
||||
selected_key = pick_silu_mul_fp8_config(args, config_keys)
|
||||
assert selected_key is None
|
||||
|
||||
@pytest.mark.parametrize("intermediate_size", [2048, 4096, 5120])
|
||||
def test_config_picker_different_sizes(self, intermediate_size):
|
||||
config_keys = [
|
||||
"intermediate_2048_batchsize_256",
|
||||
"intermediate_4096_batchsize_256",
|
||||
"intermediate_5120_batchsize_256",
|
||||
]
|
||||
|
||||
input_tensor = torch.randn(
|
||||
32, 2 * intermediate_size, dtype=torch.bfloat16, device="cuda"
|
||||
)
|
||||
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
|
||||
args = (input_tensor, scale)
|
||||
|
||||
selected_key = pick_silu_mul_fp8_config(args, config_keys)
|
||||
expected_key = f"intermediate_{intermediate_size}_batchsize_256"
|
||||
assert selected_key == expected_key
|
||||
|
||||
|
||||
class TestSiluMulFp8Correctness:
|
||||
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
|
||||
@pytest.mark.parametrize("intermediate_size", [2048, 3000, 3500, 4096, 5000])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
def test_silu_mul_fp8_correctness(self, batch_size, intermediate_size, dtype):
|
||||
skip_if_platform_unsupported()
|
||||
|
||||
input_size = 2 * intermediate_size
|
||||
input_tensor = torch.randn(batch_size, input_size, dtype=dtype, device="cuda")
|
||||
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
|
||||
|
||||
reference_output = silu_mul_fp8_baseline(input_tensor, scale)
|
||||
helion_output = silu_mul_fp8(input_tensor, scale)
|
||||
|
||||
assert helion_output.shape == reference_output.shape
|
||||
assert helion_output.dtype == torch.float8_e4m3fn
|
||||
assert reference_output.dtype == torch.float8_e4m3fn
|
||||
|
||||
ref_f32 = reference_output.to(torch.float32)
|
||||
helion_f32 = helion_output.to(torch.float32)
|
||||
# FP8 E4M3 has limited precision. Values near quantization boundaries
|
||||
# can round differently due to intermediate precision differences.
|
||||
torch.testing.assert_close(
|
||||
helion_f32,
|
||||
ref_f32,
|
||||
atol=0.05,
|
||||
rtol=0.05,
|
||||
msg=f"Mismatch at batch={batch_size}, size={intermediate_size}",
|
||||
)
|
||||
|
||||
def test_silu_mul_fp8_shape_inference(self):
|
||||
skip_if_platform_unsupported()
|
||||
batch_size, input_size = 32, 8192
|
||||
intermediate_size = input_size // 2
|
||||
|
||||
input_tensor = torch.randn(
|
||||
batch_size, input_size, dtype=torch.bfloat16, device="cuda"
|
||||
)
|
||||
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
|
||||
|
||||
output = silu_mul_fp8(input_tensor, scale)
|
||||
|
||||
expected_shape = (batch_size, intermediate_size)
|
||||
assert output.shape == expected_shape
|
||||
assert output.dtype == torch.float8_e4m3fn
|
||||
|
||||
def test_silu_mul_fp8_scale_variations(self):
|
||||
skip_if_platform_unsupported()
|
||||
batch_size, input_size = 16, 4096
|
||||
|
||||
input_tensor = torch.randn(
|
||||
batch_size, input_size, dtype=torch.bfloat16, device="cuda"
|
||||
)
|
||||
|
||||
scales = [0.1, 0.5, 1.0, 2.0, 10.0]
|
||||
|
||||
for scale_val in scales:
|
||||
scale = torch.tensor([scale_val], dtype=torch.float32, device="cuda")
|
||||
|
||||
reference_output = silu_mul_fp8_baseline(input_tensor, scale)
|
||||
helion_output = silu_mul_fp8(input_tensor, scale)
|
||||
ref_f32 = reference_output.to(torch.float32)
|
||||
helion_f32 = helion_output.to(torch.float32)
|
||||
|
||||
torch.testing.assert_close(
|
||||
helion_f32,
|
||||
ref_f32,
|
||||
atol=0.05,
|
||||
rtol=0.05,
|
||||
msg=f"Mismatch for scale={scale_val}",
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"shape",
|
||||
[
|
||||
(1, 4096),
|
||||
(16, 4096),
|
||||
(128, 4096),
|
||||
(1024, 4096),
|
||||
(1, 8192),
|
||||
(16, 8192),
|
||||
(128, 8192),
|
||||
],
|
||||
)
|
||||
def test_silu_mul_fp8_various_shapes(self, shape):
|
||||
skip_if_platform_unsupported()
|
||||
|
||||
input_tensor = torch.randn(*shape, dtype=torch.bfloat16, device="cuda")
|
||||
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
|
||||
|
||||
reference_output = silu_mul_fp8_baseline(input_tensor, scale)
|
||||
helion_output = silu_mul_fp8(input_tensor, scale)
|
||||
|
||||
assert helion_output.shape == reference_output.shape
|
||||
|
||||
ref_f32 = reference_output.to(torch.float32)
|
||||
helion_f32 = helion_output.to(torch.float32)
|
||||
|
||||
torch.testing.assert_close(
|
||||
helion_f32, ref_f32, atol=0.05, rtol=0.05, msg=f"Mismatch for shape={shape}"
|
||||
)
|
||||
|
||||
|
||||
def silu_mul_fp8_pytorch(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
||||
"""Pure PyTorch reference using F.silu.
|
||||
|
||||
This matches vLLM's SiluAndMul.forward_native exactly:
|
||||
F.silu(x[..., :d]) * x[..., d:]
|
||||
"""
|
||||
d = input.shape[-1] // 2
|
||||
result = F.silu(input[..., :d]) * input[..., d:]
|
||||
return (result.to(torch.float32) / scale).to(torch.float8_e4m3fn)
|
||||
|
||||
|
||||
class TestSiluMulFp8PytorchReference:
|
||||
"""Tests comparing Helion kernel against pure PyTorch implementation.
|
||||
|
||||
Uses tighter tolerance since both use PyTorch's FP8 conversion
|
||||
(same rounding mode), unlike the vLLM C++ baseline which uses
|
||||
NVIDIA's hardware FP8 conversion with different rounding.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 256])
|
||||
@pytest.mark.parametrize("intermediate_size", [1024, 2048, 4096])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
def test_silu_mul_fp8_vs_pytorch(self, batch_size, intermediate_size, dtype):
|
||||
skip_if_platform_unsupported()
|
||||
|
||||
input_tensor = torch.randn(
|
||||
batch_size, 2 * intermediate_size, dtype=dtype, device="cuda"
|
||||
)
|
||||
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
|
||||
|
||||
pytorch_output = silu_mul_fp8_pytorch(input_tensor, scale)
|
||||
helion_output = silu_mul_fp8(input_tensor, scale)
|
||||
|
||||
assert helion_output.shape == pytorch_output.shape
|
||||
assert helion_output.dtype == torch.float8_e4m3fn
|
||||
|
||||
pytorch_f32 = pytorch_output.to(torch.float32)
|
||||
helion_f32 = helion_output.to(torch.float32)
|
||||
|
||||
# Tolerance accounts for FP8 quantization boundary effects
|
||||
torch.testing.assert_close(
|
||||
helion_f32,
|
||||
pytorch_f32,
|
||||
atol=0.05,
|
||||
rtol=0.05,
|
||||
msg=(
|
||||
f"Mismatch at batch={batch_size}, size={intermediate_size}, "
|
||||
f"dtype={dtype}"
|
||||
),
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"shape",
|
||||
[
|
||||
(1, 2, 4096), # 3D input
|
||||
(2, 4, 2048), # 3D input
|
||||
(1, 1, 1, 8192), # 4D input
|
||||
],
|
||||
)
|
||||
def test_silu_mul_fp8_multidim_vs_pytorch(self, shape):
|
||||
skip_if_platform_unsupported()
|
||||
|
||||
input_tensor = torch.randn(*shape, dtype=torch.bfloat16, device="cuda")
|
||||
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
|
||||
|
||||
pytorch_output = silu_mul_fp8_pytorch(input_tensor, scale)
|
||||
helion_output = silu_mul_fp8(input_tensor, scale)
|
||||
|
||||
assert helion_output.shape == pytorch_output.shape
|
||||
|
||||
pytorch_f32 = pytorch_output.to(torch.float32)
|
||||
helion_f32 = helion_output.to(torch.float32)
|
||||
|
||||
torch.testing.assert_close(
|
||||
helion_f32,
|
||||
pytorch_f32,
|
||||
atol=0.05,
|
||||
rtol=0.05,
|
||||
msg=f"Mismatch for shape={shape}",
|
||||
)
|
||||
|
||||
|
||||
class TestSiluMulFp8Integration:
|
||||
def test_kernel_registration_integration(self):
|
||||
from vllm.kernels.helion.register import get_registered_kernels
|
||||
|
||||
registered_kernels = get_registered_kernels()
|
||||
assert "silu_mul_fp8" in registered_kernels
|
||||
|
||||
kernel_wrapper = registered_kernels["silu_mul_fp8"]
|
||||
assert kernel_wrapper.op_name == "silu_mul_fp8"
|
||||
assert kernel_wrapper._config_picker is not None
|
||||
|
||||
def test_fake_impl_functionality(self):
|
||||
skip_if_platform_unsupported()
|
||||
from vllm.kernels.helion.register import get_registered_kernels
|
||||
|
||||
input_tensor = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda")
|
||||
scale = torch.tensor([0.5], dtype=torch.float32, device="cuda")
|
||||
registered_kernels = get_registered_kernels()
|
||||
kernel_wrapper = registered_kernels["silu_mul_fp8"]
|
||||
fake_impl = kernel_wrapper._fake_impl
|
||||
|
||||
fake_output = fake_impl(input_tensor, scale)
|
||||
|
||||
expected_shape = (32, 2048)
|
||||
assert fake_output.shape == expected_shape
|
||||
assert fake_output.dtype == torch.float8_e4m3fn
|
||||
assert fake_output.device == input_tensor.device
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Helion integration for vLLM."""
|
||||
|
||||
import vllm.kernels.helion.ops # noqa: F401 Auto-register all Helion ops
|
||||
from vllm.kernels.helion.config_manager import (
|
||||
ConfigManager,
|
||||
ConfigSet,
|
||||
|
||||
@@ -104,9 +104,6 @@ class ConfigSet:
|
||||
result[platform] = {}
|
||||
|
||||
for config_key, config in config_keys_dict.items():
|
||||
# Convert helion.Config to dict using to_json() + json.loads()
|
||||
import json
|
||||
|
||||
result[platform][config_key] = json.loads(config.to_json())
|
||||
|
||||
return result
|
||||
|
||||
550
vllm/kernels/helion/configs/silu_mul_fp8.json
Normal file
550
vllm/kernels/helion/configs/silu_mul_fp8.json
Normal file
@@ -0,0 +1,550 @@
|
||||
{
|
||||
"nvidia_h200": {
|
||||
"intermediate_2048_batchsize_256": {
|
||||
"block_sizes": [
|
||||
64,
|
||||
128
|
||||
],
|
||||
"loop_orders": [
|
||||
[
|
||||
0,
|
||||
1
|
||||
]
|
||||
],
|
||||
"flatten_loops": [
|
||||
true
|
||||
],
|
||||
"l2_groupings": [
|
||||
2
|
||||
],
|
||||
"range_unroll_factors": [
|
||||
0
|
||||
],
|
||||
"range_num_stages": [
|
||||
0
|
||||
],
|
||||
"range_multi_buffers": [
|
||||
null
|
||||
],
|
||||
"range_flattens": [
|
||||
null
|
||||
],
|
||||
"load_eviction_policies": [
|
||||
"",
|
||||
"",
|
||||
""
|
||||
],
|
||||
"num_warps": 32,
|
||||
"num_stages": 1,
|
||||
"indexing": [
|
||||
"pointer",
|
||||
"tensor_descriptor",
|
||||
"pointer",
|
||||
"pointer"
|
||||
],
|
||||
"pid_type": "flat",
|
||||
"range_warp_specializes": []
|
||||
},
|
||||
"intermediate_4096_batchsize_256": {
|
||||
"block_sizes": [
|
||||
16,
|
||||
64
|
||||
],
|
||||
"loop_orders": [
|
||||
[
|
||||
0,
|
||||
1
|
||||
]
|
||||
],
|
||||
"flatten_loops": [
|
||||
false
|
||||
],
|
||||
"l2_groupings": [
|
||||
1
|
||||
],
|
||||
"range_unroll_factors": [
|
||||
0
|
||||
],
|
||||
"range_num_stages": [
|
||||
0
|
||||
],
|
||||
"range_multi_buffers": [
|
||||
null
|
||||
],
|
||||
"range_flattens": [
|
||||
null
|
||||
],
|
||||
"load_eviction_policies": [
|
||||
"",
|
||||
"",
|
||||
""
|
||||
],
|
||||
"num_warps": 2,
|
||||
"num_stages": 1,
|
||||
"indexing": [
|
||||
"pointer",
|
||||
"pointer",
|
||||
"pointer",
|
||||
"pointer"
|
||||
],
|
||||
"pid_type": "flat",
|
||||
"range_warp_specializes": []
|
||||
},
|
||||
"default": {
|
||||
"block_sizes": [
|
||||
1,
|
||||
512
|
||||
],
|
||||
"loop_orders": [
|
||||
[
|
||||
1,
|
||||
0
|
||||
]
|
||||
],
|
||||
"flatten_loops": [
|
||||
false
|
||||
],
|
||||
"l2_groupings": [
|
||||
4
|
||||
],
|
||||
"range_unroll_factors": [
|
||||
0
|
||||
],
|
||||
"range_num_stages": [
|
||||
0
|
||||
],
|
||||
"range_multi_buffers": [
|
||||
null
|
||||
],
|
||||
"range_flattens": [
|
||||
null
|
||||
],
|
||||
"load_eviction_policies": [
|
||||
"",
|
||||
"first",
|
||||
""
|
||||
],
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"indexing": [
|
||||
"tensor_descriptor",
|
||||
"tensor_descriptor",
|
||||
"tensor_descriptor",
|
||||
"pointer"
|
||||
],
|
||||
"pid_type": "flat",
|
||||
"range_warp_specializes": []
|
||||
}
|
||||
},
|
||||
"nvidia_h100_pcie": {
|
||||
"intermediate_2048_batchsize_256": {
|
||||
"block_sizes": [
|
||||
1,
|
||||
512
|
||||
],
|
||||
"loop_orders": [
|
||||
[
|
||||
1,
|
||||
0
|
||||
]
|
||||
],
|
||||
"flatten_loops": [
|
||||
false
|
||||
],
|
||||
"l2_groupings": [
|
||||
4
|
||||
],
|
||||
"range_unroll_factors": [
|
||||
0
|
||||
],
|
||||
"range_num_stages": [
|
||||
0
|
||||
],
|
||||
"range_multi_buffers": [
|
||||
null
|
||||
],
|
||||
"range_flattens": [
|
||||
null
|
||||
],
|
||||
"load_eviction_policies": [
|
||||
"",
|
||||
"first",
|
||||
""
|
||||
],
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"indexing": [
|
||||
"tensor_descriptor",
|
||||
"tensor_descriptor",
|
||||
"tensor_descriptor",
|
||||
"pointer"
|
||||
],
|
||||
"pid_type": "flat",
|
||||
"range_warp_specializes": []
|
||||
},
|
||||
"intermediate_4096_batchsize_256": {
|
||||
"block_sizes": [
|
||||
256,
|
||||
128
|
||||
],
|
||||
"loop_orders": [
|
||||
[
|
||||
0,
|
||||
1
|
||||
]
|
||||
],
|
||||
"flatten_loops": [
|
||||
true
|
||||
],
|
||||
"l2_groupings": [
|
||||
1
|
||||
],
|
||||
"range_unroll_factors": [
|
||||
2
|
||||
],
|
||||
"range_num_stages": [
|
||||
3
|
||||
],
|
||||
"range_multi_buffers": [
|
||||
false
|
||||
],
|
||||
"range_flattens": [
|
||||
true
|
||||
],
|
||||
"load_eviction_policies": [
|
||||
"last",
|
||||
"last",
|
||||
""
|
||||
],
|
||||
"num_warps": 32,
|
||||
"num_stages": 3,
|
||||
"indexing": [
|
||||
"tensor_descriptor",
|
||||
"tensor_descriptor",
|
||||
"tensor_descriptor",
|
||||
"pointer"
|
||||
],
|
||||
"pid_type": "persistent_blocked",
|
||||
"range_warp_specializes": []
|
||||
},
|
||||
"default": {
|
||||
"block_sizes": [
|
||||
1,
|
||||
512
|
||||
],
|
||||
"loop_orders": [
|
||||
[
|
||||
1,
|
||||
0
|
||||
]
|
||||
],
|
||||
"flatten_loops": [
|
||||
false
|
||||
],
|
||||
"l2_groupings": [
|
||||
4
|
||||
],
|
||||
"range_unroll_factors": [
|
||||
0
|
||||
],
|
||||
"range_num_stages": [
|
||||
0
|
||||
],
|
||||
"range_multi_buffers": [
|
||||
null
|
||||
],
|
||||
"range_flattens": [
|
||||
null
|
||||
],
|
||||
"load_eviction_policies": [
|
||||
"",
|
||||
"first",
|
||||
""
|
||||
],
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"indexing": [
|
||||
"tensor_descriptor",
|
||||
"tensor_descriptor",
|
||||
"tensor_descriptor",
|
||||
"pointer"
|
||||
],
|
||||
"pid_type": "flat",
|
||||
"range_warp_specializes": []
|
||||
}
|
||||
},
|
||||
"nvidia_h100_sxm5": {
|
||||
"intermediate_2048_batchsize_256": {
|
||||
"block_sizes": [
|
||||
1,
|
||||
512
|
||||
],
|
||||
"loop_orders": [
|
||||
[
|
||||
1,
|
||||
0
|
||||
]
|
||||
],
|
||||
"flatten_loops": [
|
||||
false
|
||||
],
|
||||
"l2_groupings": [
|
||||
4
|
||||
],
|
||||
"range_unroll_factors": [
|
||||
0
|
||||
],
|
||||
"range_num_stages": [
|
||||
0
|
||||
],
|
||||
"range_multi_buffers": [
|
||||
null
|
||||
],
|
||||
"range_flattens": [
|
||||
null
|
||||
],
|
||||
"load_eviction_policies": [
|
||||
"",
|
||||
"first",
|
||||
""
|
||||
],
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"indexing": [
|
||||
"tensor_descriptor",
|
||||
"tensor_descriptor",
|
||||
"tensor_descriptor",
|
||||
"pointer"
|
||||
],
|
||||
"pid_type": "flat",
|
||||
"range_warp_specializes": []
|
||||
},
|
||||
"intermediate_4096_batchsize_256": {
|
||||
"block_sizes": [
|
||||
256,
|
||||
128
|
||||
],
|
||||
"loop_orders": [
|
||||
[
|
||||
0,
|
||||
1
|
||||
]
|
||||
],
|
||||
"flatten_loops": [
|
||||
true
|
||||
],
|
||||
"l2_groupings": [
|
||||
1
|
||||
],
|
||||
"range_unroll_factors": [
|
||||
2
|
||||
],
|
||||
"range_num_stages": [
|
||||
3
|
||||
],
|
||||
"range_multi_buffers": [
|
||||
false
|
||||
],
|
||||
"range_flattens": [
|
||||
true
|
||||
],
|
||||
"load_eviction_policies": [
|
||||
"last",
|
||||
"last",
|
||||
""
|
||||
],
|
||||
"num_warps": 32,
|
||||
"num_stages": 3,
|
||||
"indexing": [
|
||||
"tensor_descriptor",
|
||||
"tensor_descriptor",
|
||||
"tensor_descriptor",
|
||||
"pointer"
|
||||
],
|
||||
"pid_type": "persistent_blocked",
|
||||
"range_warp_specializes": []
|
||||
},
|
||||
"default": {
|
||||
"block_sizes": [
|
||||
1,
|
||||
512
|
||||
],
|
||||
"loop_orders": [
|
||||
[
|
||||
1,
|
||||
0
|
||||
]
|
||||
],
|
||||
"flatten_loops": [
|
||||
false
|
||||
],
|
||||
"l2_groupings": [
|
||||
4
|
||||
],
|
||||
"range_unroll_factors": [
|
||||
0
|
||||
],
|
||||
"range_num_stages": [
|
||||
0
|
||||
],
|
||||
"range_multi_buffers": [
|
||||
null
|
||||
],
|
||||
"range_flattens": [
|
||||
null
|
||||
],
|
||||
"load_eviction_policies": [
|
||||
"",
|
||||
"first",
|
||||
""
|
||||
],
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"indexing": [
|
||||
"tensor_descriptor",
|
||||
"tensor_descriptor",
|
||||
"tensor_descriptor",
|
||||
"pointer"
|
||||
],
|
||||
"pid_type": "flat",
|
||||
"range_warp_specializes": []
|
||||
}
|
||||
},
|
||||
"nvidia_h100": {
|
||||
"intermediate_2048_batchsize_256": {
|
||||
"block_sizes": [
|
||||
1,
|
||||
512
|
||||
],
|
||||
"loop_orders": [
|
||||
[
|
||||
1,
|
||||
0
|
||||
]
|
||||
],
|
||||
"flatten_loops": [
|
||||
false
|
||||
],
|
||||
"l2_groupings": [
|
||||
4
|
||||
],
|
||||
"range_unroll_factors": [
|
||||
0
|
||||
],
|
||||
"range_num_stages": [
|
||||
0
|
||||
],
|
||||
"range_multi_buffers": [
|
||||
null
|
||||
],
|
||||
"range_flattens": [
|
||||
null
|
||||
],
|
||||
"load_eviction_policies": [
|
||||
"",
|
||||
"first",
|
||||
""
|
||||
],
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"indexing": [
|
||||
"tensor_descriptor",
|
||||
"tensor_descriptor",
|
||||
"tensor_descriptor",
|
||||
"pointer"
|
||||
],
|
||||
"pid_type": "flat",
|
||||
"range_warp_specializes": []
|
||||
},
|
||||
"intermediate_4096_batchsize_256": {
|
||||
"block_sizes": [
|
||||
256,
|
||||
128
|
||||
],
|
||||
"loop_orders": [
|
||||
[
|
||||
0,
|
||||
1
|
||||
]
|
||||
],
|
||||
"flatten_loops": [
|
||||
true
|
||||
],
|
||||
"l2_groupings": [
|
||||
1
|
||||
],
|
||||
"range_unroll_factors": [
|
||||
2
|
||||
],
|
||||
"range_num_stages": [
|
||||
3
|
||||
],
|
||||
"range_multi_buffers": [
|
||||
false
|
||||
],
|
||||
"range_flattens": [
|
||||
true
|
||||
],
|
||||
"load_eviction_policies": [
|
||||
"last",
|
||||
"last",
|
||||
""
|
||||
],
|
||||
"num_warps": 32,
|
||||
"num_stages": 3,
|
||||
"indexing": [
|
||||
"tensor_descriptor",
|
||||
"tensor_descriptor",
|
||||
"tensor_descriptor",
|
||||
"pointer"
|
||||
],
|
||||
"pid_type": "persistent_blocked",
|
||||
"range_warp_specializes": []
|
||||
},
|
||||
"default": {
|
||||
"block_sizes": [
|
||||
1,
|
||||
512
|
||||
],
|
||||
"loop_orders": [
|
||||
[
|
||||
1,
|
||||
0
|
||||
]
|
||||
],
|
||||
"flatten_loops": [
|
||||
false
|
||||
],
|
||||
"l2_groupings": [
|
||||
4
|
||||
],
|
||||
"range_unroll_factors": [
|
||||
0
|
||||
],
|
||||
"range_num_stages": [
|
||||
0
|
||||
],
|
||||
"range_multi_buffers": [
|
||||
null
|
||||
],
|
||||
"range_flattens": [
|
||||
null
|
||||
],
|
||||
"load_eviction_policies": [
|
||||
"",
|
||||
"first",
|
||||
""
|
||||
],
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"indexing": [
|
||||
"tensor_descriptor",
|
||||
"tensor_descriptor",
|
||||
"tensor_descriptor",
|
||||
"pointer"
|
||||
],
|
||||
"pid_type": "flat",
|
||||
"range_warp_specializes": []
|
||||
}
|
||||
}
|
||||
}
|
||||
11
vllm/kernels/helion/ops/__init__.py
Normal file
11
vllm/kernels/helion/ops/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Auto-import all Helion op modules to trigger kernel registration."""
|
||||
|
||||
import importlib
|
||||
import pkgutil
|
||||
|
||||
# Automatically import all submodules so that @register_kernel
|
||||
# decorators execute and register ops with torch.ops.vllm_helion.
|
||||
for _module_info in pkgutil.iter_modules(__path__):
|
||||
importlib.import_module(f"{__name__}.{_module_info.name}")
|
||||
100
vllm/kernels/helion/ops/silu_mul_fp8.py
Normal file
100
vllm/kernels/helion/ops/silu_mul_fp8.py
Normal file
@@ -0,0 +1,100 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.import_utils import has_helion
|
||||
|
||||
if not has_helion():
|
||||
raise ImportError(
|
||||
"silu_mul_fp8 Helion kernel requires helion to be installed. "
|
||||
"Install it with: pip install helion"
|
||||
)
|
||||
|
||||
import helion.language as hl
|
||||
|
||||
from vllm.kernels.helion.register import register_kernel
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@register_kernel # type: ignore[misc]
|
||||
def silu_mul_fp8(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
||||
original_shape = input.shape
|
||||
two_d = hl.specialize(original_shape[-1])
|
||||
d = two_d // 2
|
||||
output_shape = original_shape[:-1] + (d,)
|
||||
|
||||
input_2d = input.view(-1, original_shape[-1])
|
||||
m = input_2d.shape[0]
|
||||
|
||||
# TODO(gmagogsfm): Support for more float8 subtypes (e4m3fnuz, e5m2) coming
|
||||
out = torch.empty((m, d), device=input.device, dtype=torch.float8_e4m3fn)
|
||||
|
||||
input_part_a = input_2d[:, :d]
|
||||
input_part_b = input_2d[:, d:]
|
||||
|
||||
assert scale.numel() == 1, "Scale must be a scalar Tensor"
|
||||
|
||||
for tile_m, tile_n in hl.tile([m, d]):
|
||||
a_vals = input_part_a[tile_m, tile_n]
|
||||
silu_result = torch.nn.functional.silu(a_vals)
|
||||
b_vals = input_part_b[tile_m, tile_n]
|
||||
result = silu_result * b_vals
|
||||
result_f32 = result.to(torch.float32)
|
||||
scale_val = hl.load(scale, [0])
|
||||
inv_scale = 1.0 / scale_val
|
||||
result_scaled = result_f32 * inv_scale
|
||||
out[tile_m, tile_n] = result_scaled.to(out.dtype)
|
||||
|
||||
return out.view(output_shape)
|
||||
|
||||
|
||||
@silu_mul_fp8.register_config_picker # type: ignore[misc]
|
||||
def pick_silu_mul_fp8_config(
|
||||
args: tuple[Any, ...], config_keys: list[str]
|
||||
) -> str | None:
|
||||
if not config_keys:
|
||||
return None
|
||||
|
||||
input_tensor, scale = args
|
||||
intermediate_size = input_tensor.shape[-1] // 2
|
||||
|
||||
# TODO(gmagosfm): Rerun autotuning to capture config for
|
||||
# other batch sizes.
|
||||
target_key = f"intermediate_{intermediate_size}_batchsize_256"
|
||||
if target_key in config_keys:
|
||||
return target_key
|
||||
|
||||
intermediate_sizes = []
|
||||
for key in config_keys:
|
||||
if key.startswith("intermediate_") and "_batchsize_256" in key:
|
||||
try:
|
||||
size_str = key.split("_")[1]
|
||||
size = int(size_str)
|
||||
intermediate_sizes.append((abs(size - intermediate_size), key))
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
if intermediate_sizes:
|
||||
_, best_key = min(intermediate_sizes)
|
||||
logger.debug(
|
||||
"No exact config for intermediate_size=%d, using closest match: %s",
|
||||
intermediate_size,
|
||||
best_key,
|
||||
)
|
||||
return best_key
|
||||
if "default" in config_keys:
|
||||
return "default"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def silu_mul_fp8_baseline(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
|
||||
output_shape = input.shape[:-1] + (input.shape[-1] // 2,)
|
||||
out = torch.empty(output_shape, dtype=torch.float8_e4m3fn, device=input.device)
|
||||
torch.ops._C.silu_and_mul_quant(out, input, scale)
|
||||
return out
|
||||
Reference in New Issue
Block a user