[Kernel] [Helion] [3/N] Helion kernel registry (#33203)

Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
This commit is contained in:
Yanan Cao
2026-01-30 23:38:46 -08:00
committed by GitHub
parent 1618e25492
commit d5c41db35b
3 changed files with 319 additions and 3 deletions

View File

@@ -27,6 +27,9 @@ from vllm.kernels.helion.config_manager import ConfigManager
from vllm.kernels.helion.register import (
ConfiguredHelionKernel,
HelionKernelWrapper,
get_kernel_by_name,
get_registered_kernels,
register_kernel,
validate_helion_settings,
)
@@ -545,3 +548,191 @@ class TestHelionKernelWrapper:
assert result is new_op
# Check that op_func is the decorated kernel, not ConfiguredHelionKernel
assert mock_register.call_args[1]["op_func"] is mock_decorated
class TestKernelRegistry:
"""Test suite for kernel registry functionality."""
def setup_method(self):
"""Clear the registry before each test."""
from vllm.kernels.helion.register import _REGISTERED_KERNELS
_REGISTERED_KERNELS.clear()
def test_get_registered_kernels_returns_copy(self):
"""Test get_registered_kernels returns copy of registry."""
result1 = get_registered_kernels()
result2 = get_registered_kernels()
# Should be separate objects
assert result1 is not result2
# Should have same content
assert result1 == result2
def test_get_kernel_by_name_returns_kernel(self):
"""Test get_kernel_by_name returns registered kernel."""
wrapper = HelionKernelWrapper(
raw_kernel_func=Mock(),
op_name="test_kernel",
fake_impl=Mock(),
)
from vllm.kernels.helion.register import _REGISTERED_KERNELS
_REGISTERED_KERNELS["test_kernel"] = wrapper
result = get_kernel_by_name("test_kernel")
assert result is wrapper
def test_get_kernel_by_name_returns_none_for_missing(self):
"""Test get_kernel_by_name returns None for missing kernel."""
result = get_kernel_by_name("nonexistent")
assert result is None
def test_register_kernel_auto_generates_fake_impl(self):
"""Test register_kernel auto-generates fake_impl when not provided."""
with patch("vllm.kernels.helion.register.infer_fake_impl") as mock_infer:
mock_fake = Mock()
mock_infer.return_value = mock_fake
def original_kernel(x):
return x
wrapper = register_kernel(original_kernel)
mock_infer.assert_called_once_with(original_kernel, None)
assert wrapper._fake_impl is mock_fake
def test_register_kernel_creates_wrapper(self):
"""Test register_kernel creates HelionKernelWrapper."""
def test_kernel(x):
return x
result = register_kernel("test_name")(test_kernel)
assert isinstance(result, HelionKernelWrapper)
assert result.op_name == "test_name"
assert result.raw_kernel_func is test_kernel
def test_register_kernel_auto_detects_name(self):
"""Test register_kernel uses function name when no name provided."""
@register_kernel
def my_test_kernel(x):
return x
assert my_test_kernel.op_name == "my_test_kernel"
def test_register_kernel_registers_in_global_registry(self):
"""Test register_kernel adds wrapper to global registry."""
@register_kernel
def test_kernel(x):
return x
registered_kernels = get_registered_kernels()
assert "test_kernel" in registered_kernels
assert registered_kernels["test_kernel"] is test_kernel
def test_register_kernel_passes_helion_settings(self):
"""Test register_kernel passes helion_settings to wrapper."""
mock_settings = Mock()
mock_settings.to_dict.return_value = {"debug": True}
@register_kernel("test_name", helion_settings=mock_settings)
def test_kernel(x):
return x
assert test_kernel.helion_settings is mock_settings
def test_register_kernel_supports_decorator_syntax(self):
"""Test register_kernel works with decorator arguments."""
mock_fake = Mock()
wrapper = register_kernel("custom_name", fake_impl=mock_fake)
def test_kernel(x):
return x
result = wrapper(test_kernel)
assert result.op_name == "custom_name"
assert result._fake_impl is mock_fake
def test_register_kernel_bare_decorator(self):
"""Test register_kernel works as bare decorator."""
@register_kernel
def test_kernel(x):
return x
assert isinstance(test_kernel, HelionKernelWrapper)
assert test_kernel.op_name == "test_kernel"
def test_registered_wrapper_can_register_config_picker(self):
"""Test that registered wrapper can register config picker."""
@register_kernel
def test_kernel(x):
return x
def my_picker(args, config_keys):
return "default"
result = test_kernel.register_config_picker(my_picker)
assert result is my_picker
assert test_kernel._config_picker is my_picker
def test_register_kernel_raises_on_duplicate_registration(self):
"""Test register_kernel raises error on duplicate names."""
@register_kernel("duplicate_name")
def kernel1(x):
return x
with pytest.raises(ValueError, match="already registered"):
@register_kernel("duplicate_name")
def kernel2(x):
return x
def test_register_kernel_rejects_autotuner_fn_in_settings(self):
"""Test register_kernel rejects conflicting autotuner_fn."""
mock_settings = Mock()
mock_settings.to_dict.return_value = {"autotuner_fn": Mock()}
with pytest.raises(ValueError, match="uses a custom autotuner"):
@register_kernel("test", helion_settings=mock_settings)
def test_kernel(x):
return x
def test_register_kernel_warns_with_static_shapes_true(self):
"""Test register_kernel warns when static_shapes=True."""
mock_settings = Mock()
mock_settings.to_dict.return_value = {"static_shapes": True}
with patch("vllm.kernels.helion.register.logger") as mock_logger:
@register_kernel("test", helion_settings=mock_settings)
def test_kernel(x):
return x
mock_logger.warning.assert_called_once()
assert "static_shapes=True" in mock_logger.warning.call_args[0][0]
def test_register_kernel_no_warning_with_static_shapes_false(self):
"""Test register_kernel doesn't warn with static_shapes=False."""
mock_settings = Mock()
mock_settings.to_dict.return_value = {"static_shapes": False}
with patch("vllm.kernels.helion.register.logger") as mock_logger:
@register_kernel("test", helion_settings=mock_settings)
def test_kernel(x):
return x
# Should not call warning
mock_logger.warning.assert_not_called()

View File

@@ -9,6 +9,9 @@ from vllm.kernels.helion.config_manager import (
from vllm.kernels.helion.register import (
ConfiguredHelionKernel,
HelionKernelWrapper,
get_kernel_by_name,
get_registered_kernels,
register_kernel,
vllm_helion_lib,
)
from vllm.kernels.helion.utils import canonicalize_gpu_name, get_canonical_gpu_name
@@ -20,6 +23,9 @@ __all__ = [
# Kernel registration
"ConfiguredHelionKernel",
"HelionKernelWrapper",
"get_kernel_by_name",
"get_registered_kernels",
"register_kernel",
"vllm_helion_lib",
# Utilities
"canonicalize_gpu_name",

View File

@@ -37,7 +37,7 @@ Key Classes
"""
from collections.abc import Callable
from typing import Any
from typing import Any, cast, overload
import torch
from torch.library import Library
@@ -114,7 +114,7 @@ class ConfiguredHelionKernel:
def __init__(
self,
op_name: str,
config_picker: Callable[[tuple[Any, ...], list[str]], str | None],
config_picker: Callable[[tuple[Any, ...], list[str]], str | None] | None,
raw_kernel_func: Callable,
helion_settings: "helion.Settings | None" = None,
):
@@ -140,9 +140,16 @@ class ConfiguredHelionKernel:
f"Use @{self.op_name}.register_config_picker to register one."
)
# After None check, config_picker is guaranteed to be non-None
assert self.config_picker is not None
def key_computer(*args):
config_keys = list(self.configs.keys())
selected_key = self.config_picker(args, config_keys)
# Cast is safe because we checked for None above
config_picker = cast(
Callable[[tuple[Any, ...], list[str]], str | None], self.config_picker
)
selected_key = config_picker(args, config_keys)
if selected_key:
return selected_key
return "default" if "default" in self.configs else None
@@ -272,3 +279,115 @@ class HelionKernelWrapper:
target_lib=vllm_helion_lib,
)
return getattr(torch.ops.vllm_helion, self.op_name)
# Global registry for tracking all registered HelionKernelWrapper instances
_REGISTERED_KERNELS: dict[str, HelionKernelWrapper] = {}
def get_registered_kernels() -> dict[str, HelionKernelWrapper]:
return _REGISTERED_KERNELS.copy()
def get_kernel_by_name(kernel_name: str) -> HelionKernelWrapper | None:
return _REGISTERED_KERNELS.get(kernel_name)
def infer_fake_impl(
kernel_func: Callable,
helion_settings: "helion.Settings | None" = None,
) -> Callable:
def helion_fake_kernel(*args, **kwargs):
kernel_kwargs = {}
if helion_settings:
kernel_kwargs.update(helion_settings.to_dict())
temp_decorated_kernel = helion.kernel(**kernel_kwargs)(kernel_func)
# Bind with args to get config_spec, then get a valid default config
bound = temp_decorated_kernel.bind(args)
default_config = bound.config_spec.default_config()
compiled_runner = bound.compile_config(default_config)
return compiled_runner(*args, **kwargs, _launcher=lambda *a, **kw: None)
return helion_fake_kernel
# Overloads are necessary for proper mypy type inference.
# Without overloads, the union return type HelionKernelWrapper | Callable[...]
# causes mypy to complain about missing attributes when tests do:
# wrapper = register_kernel(func) # Should return HelionKernelWrapper
# wrapper._fake_impl # mypy error: "Callable has no attribute _fake_impl"
# The overloads tell mypy the exact return type based on the argument pattern.
@overload
def register_kernel(
op_name_or_func: Callable,
*,
fake_impl: Callable | None = None,
helion_settings: "helion.Settings | None" = None,
) -> HelionKernelWrapper: ...
@overload
def register_kernel(
op_name_or_func: str | None = None,
*,
fake_impl: Callable | None = None,
helion_settings: "helion.Settings | None" = None,
) -> Callable[[Callable], HelionKernelWrapper]: ...
def register_kernel(
op_name_or_func: str | Callable | None = None,
*,
fake_impl: Callable | None = None,
helion_settings: "helion.Settings | None" = None,
) -> HelionKernelWrapper | Callable[[Callable], HelionKernelWrapper]:
"""
Decorator to register a Helion kernel function as a HelionKernelWrapper.
Wraps the raw kernel function in a HelionKernelWrapper and registers it
in the global kernel registry. Auto-generates fake_impl if not provided.
"""
def decorator(kernel_func: Callable) -> HelionKernelWrapper:
op_name = op_name_or_func if isinstance(op_name_or_func, str) else None
final_op_name = op_name if op_name else kernel_func.__name__
if final_op_name in _REGISTERED_KERNELS:
raise ValueError(
f"Helion kernel '{final_op_name}' is already registered. "
f"Use a different op_name or check for duplicate registrations."
)
final_fake_impl = fake_impl
if final_fake_impl is None:
final_fake_impl = infer_fake_impl(kernel_func, helion_settings)
logger.debug(
"Auto-generated fake_impl for Helion kernel '%s'",
kernel_func.__name__,
)
kernel_wrapper = HelionKernelWrapper(
raw_kernel_func=kernel_func,
op_name=final_op_name,
fake_impl=final_fake_impl,
helion_settings=helion_settings,
)
_REGISTERED_KERNELS[final_op_name] = kernel_wrapper
logger.info(
"Registered Helion kernel '%s' as HelionKernelWrapper",
kernel_func.__name__,
)
return kernel_wrapper
if callable(op_name_or_func) and not isinstance(op_name_or_func, str):
# Bare decorator usage: @register_kernel
return decorator(op_name_or_func)
else:
# Decorator with arguments: @register_kernel(...)
return decorator