[Kernel] [Helion] [3/N] Helion kernel registry (#33203)

Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
This commit is contained in:
Yanan Cao
2026-01-30 23:38:46 -08:00
committed by GitHub
parent 1618e25492
commit d5c41db35b
3 changed files with 319 additions and 3 deletions

View File

@@ -27,6 +27,9 @@ from vllm.kernels.helion.config_manager import ConfigManager
from vllm.kernels.helion.register import ( from vllm.kernels.helion.register import (
ConfiguredHelionKernel, ConfiguredHelionKernel,
HelionKernelWrapper, HelionKernelWrapper,
get_kernel_by_name,
get_registered_kernels,
register_kernel,
validate_helion_settings, validate_helion_settings,
) )
@@ -545,3 +548,191 @@ class TestHelionKernelWrapper:
assert result is new_op assert result is new_op
# Check that op_func is the decorated kernel, not ConfiguredHelionKernel # Check that op_func is the decorated kernel, not ConfiguredHelionKernel
assert mock_register.call_args[1]["op_func"] is mock_decorated assert mock_register.call_args[1]["op_func"] is mock_decorated
class TestKernelRegistry:
"""Test suite for kernel registry functionality."""
def setup_method(self):
"""Clear the registry before each test."""
from vllm.kernels.helion.register import _REGISTERED_KERNELS
_REGISTERED_KERNELS.clear()
def test_get_registered_kernels_returns_copy(self):
"""Test get_registered_kernels returns copy of registry."""
result1 = get_registered_kernels()
result2 = get_registered_kernels()
# Should be separate objects
assert result1 is not result2
# Should have same content
assert result1 == result2
def test_get_kernel_by_name_returns_kernel(self):
"""Test get_kernel_by_name returns registered kernel."""
wrapper = HelionKernelWrapper(
raw_kernel_func=Mock(),
op_name="test_kernel",
fake_impl=Mock(),
)
from vllm.kernels.helion.register import _REGISTERED_KERNELS
_REGISTERED_KERNELS["test_kernel"] = wrapper
result = get_kernel_by_name("test_kernel")
assert result is wrapper
def test_get_kernel_by_name_returns_none_for_missing(self):
"""Test get_kernel_by_name returns None for missing kernel."""
result = get_kernel_by_name("nonexistent")
assert result is None
def test_register_kernel_auto_generates_fake_impl(self):
"""Test register_kernel auto-generates fake_impl when not provided."""
with patch("vllm.kernels.helion.register.infer_fake_impl") as mock_infer:
mock_fake = Mock()
mock_infer.return_value = mock_fake
def original_kernel(x):
return x
wrapper = register_kernel(original_kernel)
mock_infer.assert_called_once_with(original_kernel, None)
assert wrapper._fake_impl is mock_fake
def test_register_kernel_creates_wrapper(self):
"""Test register_kernel creates HelionKernelWrapper."""
def test_kernel(x):
return x
result = register_kernel("test_name")(test_kernel)
assert isinstance(result, HelionKernelWrapper)
assert result.op_name == "test_name"
assert result.raw_kernel_func is test_kernel
def test_register_kernel_auto_detects_name(self):
"""Test register_kernel uses function name when no name provided."""
@register_kernel
def my_test_kernel(x):
return x
assert my_test_kernel.op_name == "my_test_kernel"
def test_register_kernel_registers_in_global_registry(self):
"""Test register_kernel adds wrapper to global registry."""
@register_kernel
def test_kernel(x):
return x
registered_kernels = get_registered_kernels()
assert "test_kernel" in registered_kernels
assert registered_kernels["test_kernel"] is test_kernel
def test_register_kernel_passes_helion_settings(self):
"""Test register_kernel passes helion_settings to wrapper."""
mock_settings = Mock()
mock_settings.to_dict.return_value = {"debug": True}
@register_kernel("test_name", helion_settings=mock_settings)
def test_kernel(x):
return x
assert test_kernel.helion_settings is mock_settings
def test_register_kernel_supports_decorator_syntax(self):
"""Test register_kernel works with decorator arguments."""
mock_fake = Mock()
wrapper = register_kernel("custom_name", fake_impl=mock_fake)
def test_kernel(x):
return x
result = wrapper(test_kernel)
assert result.op_name == "custom_name"
assert result._fake_impl is mock_fake
def test_register_kernel_bare_decorator(self):
"""Test register_kernel works as bare decorator."""
@register_kernel
def test_kernel(x):
return x
assert isinstance(test_kernel, HelionKernelWrapper)
assert test_kernel.op_name == "test_kernel"
def test_registered_wrapper_can_register_config_picker(self):
"""Test that registered wrapper can register config picker."""
@register_kernel
def test_kernel(x):
return x
def my_picker(args, config_keys):
return "default"
result = test_kernel.register_config_picker(my_picker)
assert result is my_picker
assert test_kernel._config_picker is my_picker
def test_register_kernel_raises_on_duplicate_registration(self):
"""Test register_kernel raises error on duplicate names."""
@register_kernel("duplicate_name")
def kernel1(x):
return x
with pytest.raises(ValueError, match="already registered"):
@register_kernel("duplicate_name")
def kernel2(x):
return x
def test_register_kernel_rejects_autotuner_fn_in_settings(self):
"""Test register_kernel rejects conflicting autotuner_fn."""
mock_settings = Mock()
mock_settings.to_dict.return_value = {"autotuner_fn": Mock()}
with pytest.raises(ValueError, match="uses a custom autotuner"):
@register_kernel("test", helion_settings=mock_settings)
def test_kernel(x):
return x
def test_register_kernel_warns_with_static_shapes_true(self):
"""Test register_kernel warns when static_shapes=True."""
mock_settings = Mock()
mock_settings.to_dict.return_value = {"static_shapes": True}
with patch("vllm.kernels.helion.register.logger") as mock_logger:
@register_kernel("test", helion_settings=mock_settings)
def test_kernel(x):
return x
mock_logger.warning.assert_called_once()
assert "static_shapes=True" in mock_logger.warning.call_args[0][0]
def test_register_kernel_no_warning_with_static_shapes_false(self):
"""Test register_kernel doesn't warn with static_shapes=False."""
mock_settings = Mock()
mock_settings.to_dict.return_value = {"static_shapes": False}
with patch("vllm.kernels.helion.register.logger") as mock_logger:
@register_kernel("test", helion_settings=mock_settings)
def test_kernel(x):
return x
# Should not call warning
mock_logger.warning.assert_not_called()

View File

@@ -9,6 +9,9 @@ from vllm.kernels.helion.config_manager import (
from vllm.kernels.helion.register import ( from vllm.kernels.helion.register import (
ConfiguredHelionKernel, ConfiguredHelionKernel,
HelionKernelWrapper, HelionKernelWrapper,
get_kernel_by_name,
get_registered_kernels,
register_kernel,
vllm_helion_lib, vllm_helion_lib,
) )
from vllm.kernels.helion.utils import canonicalize_gpu_name, get_canonical_gpu_name from vllm.kernels.helion.utils import canonicalize_gpu_name, get_canonical_gpu_name
@@ -20,6 +23,9 @@ __all__ = [
# Kernel registration # Kernel registration
"ConfiguredHelionKernel", "ConfiguredHelionKernel",
"HelionKernelWrapper", "HelionKernelWrapper",
"get_kernel_by_name",
"get_registered_kernels",
"register_kernel",
"vllm_helion_lib", "vllm_helion_lib",
# Utilities # Utilities
"canonicalize_gpu_name", "canonicalize_gpu_name",

View File

@@ -37,7 +37,7 @@ Key Classes
""" """
from collections.abc import Callable from collections.abc import Callable
from typing import Any from typing import Any, cast, overload
import torch import torch
from torch.library import Library from torch.library import Library
@@ -114,7 +114,7 @@ class ConfiguredHelionKernel:
def __init__( def __init__(
self, self,
op_name: str, op_name: str,
config_picker: Callable[[tuple[Any, ...], list[str]], str | None], config_picker: Callable[[tuple[Any, ...], list[str]], str | None] | None,
raw_kernel_func: Callable, raw_kernel_func: Callable,
helion_settings: "helion.Settings | None" = None, helion_settings: "helion.Settings | None" = None,
): ):
@@ -140,9 +140,16 @@ class ConfiguredHelionKernel:
f"Use @{self.op_name}.register_config_picker to register one." f"Use @{self.op_name}.register_config_picker to register one."
) )
# After None check, config_picker is guaranteed to be non-None
assert self.config_picker is not None
def key_computer(*args): def key_computer(*args):
config_keys = list(self.configs.keys()) config_keys = list(self.configs.keys())
selected_key = self.config_picker(args, config_keys) # Cast is safe because we checked for None above
config_picker = cast(
Callable[[tuple[Any, ...], list[str]], str | None], self.config_picker
)
selected_key = config_picker(args, config_keys)
if selected_key: if selected_key:
return selected_key return selected_key
return "default" if "default" in self.configs else None return "default" if "default" in self.configs else None
@@ -272,3 +279,115 @@ class HelionKernelWrapper:
target_lib=vllm_helion_lib, target_lib=vllm_helion_lib,
) )
return getattr(torch.ops.vllm_helion, self.op_name) return getattr(torch.ops.vllm_helion, self.op_name)
# Global registry for tracking all registered HelionKernelWrapper instances
_REGISTERED_KERNELS: dict[str, HelionKernelWrapper] = {}
def get_registered_kernels() -> dict[str, HelionKernelWrapper]:
return _REGISTERED_KERNELS.copy()
def get_kernel_by_name(kernel_name: str) -> HelionKernelWrapper | None:
return _REGISTERED_KERNELS.get(kernel_name)
def infer_fake_impl(
kernel_func: Callable,
helion_settings: "helion.Settings | None" = None,
) -> Callable:
def helion_fake_kernel(*args, **kwargs):
kernel_kwargs = {}
if helion_settings:
kernel_kwargs.update(helion_settings.to_dict())
temp_decorated_kernel = helion.kernel(**kernel_kwargs)(kernel_func)
# Bind with args to get config_spec, then get a valid default config
bound = temp_decorated_kernel.bind(args)
default_config = bound.config_spec.default_config()
compiled_runner = bound.compile_config(default_config)
return compiled_runner(*args, **kwargs, _launcher=lambda *a, **kw: None)
return helion_fake_kernel
# Overloads are necessary for proper mypy type inference.
# Without overloads, the union return type HelionKernelWrapper | Callable[...]
# causes mypy to complain about missing attributes when tests do:
# wrapper = register_kernel(func) # Should return HelionKernelWrapper
# wrapper._fake_impl # mypy error: "Callable has no attribute _fake_impl"
# The overloads tell mypy the exact return type based on the argument pattern.
@overload
def register_kernel(
op_name_or_func: Callable,
*,
fake_impl: Callable | None = None,
helion_settings: "helion.Settings | None" = None,
) -> HelionKernelWrapper: ...
@overload
def register_kernel(
op_name_or_func: str | None = None,
*,
fake_impl: Callable | None = None,
helion_settings: "helion.Settings | None" = None,
) -> Callable[[Callable], HelionKernelWrapper]: ...
def register_kernel(
op_name_or_func: str | Callable | None = None,
*,
fake_impl: Callable | None = None,
helion_settings: "helion.Settings | None" = None,
) -> HelionKernelWrapper | Callable[[Callable], HelionKernelWrapper]:
"""
Decorator to register a Helion kernel function as a HelionKernelWrapper.
Wraps the raw kernel function in a HelionKernelWrapper and registers it
in the global kernel registry. Auto-generates fake_impl if not provided.
"""
def decorator(kernel_func: Callable) -> HelionKernelWrapper:
op_name = op_name_or_func if isinstance(op_name_or_func, str) else None
final_op_name = op_name if op_name else kernel_func.__name__
if final_op_name in _REGISTERED_KERNELS:
raise ValueError(
f"Helion kernel '{final_op_name}' is already registered. "
f"Use a different op_name or check for duplicate registrations."
)
final_fake_impl = fake_impl
if final_fake_impl is None:
final_fake_impl = infer_fake_impl(kernel_func, helion_settings)
logger.debug(
"Auto-generated fake_impl for Helion kernel '%s'",
kernel_func.__name__,
)
kernel_wrapper = HelionKernelWrapper(
raw_kernel_func=kernel_func,
op_name=final_op_name,
fake_impl=final_fake_impl,
helion_settings=helion_settings,
)
_REGISTERED_KERNELS[final_op_name] = kernel_wrapper
logger.info(
"Registered Helion kernel '%s' as HelionKernelWrapper",
kernel_func.__name__,
)
return kernel_wrapper
if callable(op_name_or_func) and not isinstance(op_name_or_func, str):
# Bare decorator usage: @register_kernel
return decorator(op_name_or_func)
else:
# Decorator with arguments: @register_kernel(...)
return decorator