From d5c41db35b33d348bc8b913eab6d11a20a64f168 Mon Sep 17 00:00:00 2001 From: Yanan Cao Date: Fri, 30 Jan 2026 23:38:46 -0800 Subject: [PATCH] [Kernel] [Helion] [3/N] Helion kernel registry (#33203) Signed-off-by: Yanan Cao --- tests/kernels/helion/test_register.py | 191 ++++++++++++++++++++++++++ vllm/kernels/helion/__init__.py | 6 + vllm/kernels/helion/register.py | 125 ++++++++++++++++- 3 files changed, 319 insertions(+), 3 deletions(-) diff --git a/tests/kernels/helion/test_register.py b/tests/kernels/helion/test_register.py index c6cd3f77d..faac2765c 100644 --- a/tests/kernels/helion/test_register.py +++ b/tests/kernels/helion/test_register.py @@ -27,6 +27,9 @@ from vllm.kernels.helion.config_manager import ConfigManager from vllm.kernels.helion.register import ( ConfiguredHelionKernel, HelionKernelWrapper, + get_kernel_by_name, + get_registered_kernels, + register_kernel, validate_helion_settings, ) @@ -545,3 +548,191 @@ class TestHelionKernelWrapper: assert result is new_op # Check that op_func is the decorated kernel, not ConfiguredHelionKernel assert mock_register.call_args[1]["op_func"] is mock_decorated + + +class TestKernelRegistry: + """Test suite for kernel registry functionality.""" + + def setup_method(self): + """Clear the registry before each test.""" + from vllm.kernels.helion.register import _REGISTERED_KERNELS + + _REGISTERED_KERNELS.clear() + + def test_get_registered_kernels_returns_copy(self): + """Test get_registered_kernels returns copy of registry.""" + result1 = get_registered_kernels() + result2 = get_registered_kernels() + + # Should be separate objects + assert result1 is not result2 + # Should have same content + assert result1 == result2 + + def test_get_kernel_by_name_returns_kernel(self): + """Test get_kernel_by_name returns registered kernel.""" + wrapper = HelionKernelWrapper( + raw_kernel_func=Mock(), + op_name="test_kernel", + fake_impl=Mock(), + ) + + from vllm.kernels.helion.register import _REGISTERED_KERNELS + + _REGISTERED_KERNELS["test_kernel"] = wrapper + + result = get_kernel_by_name("test_kernel") + assert result is wrapper + + def test_get_kernel_by_name_returns_none_for_missing(self): + """Test get_kernel_by_name returns None for missing kernel.""" + result = get_kernel_by_name("nonexistent") + assert result is None + + def test_register_kernel_auto_generates_fake_impl(self): + """Test register_kernel auto-generates fake_impl when not provided.""" + with patch("vllm.kernels.helion.register.infer_fake_impl") as mock_infer: + mock_fake = Mock() + mock_infer.return_value = mock_fake + + def original_kernel(x): + return x + + wrapper = register_kernel(original_kernel) + + mock_infer.assert_called_once_with(original_kernel, None) + assert wrapper._fake_impl is mock_fake + + def test_register_kernel_creates_wrapper(self): + """Test register_kernel creates HelionKernelWrapper.""" + + def test_kernel(x): + return x + + result = register_kernel("test_name")(test_kernel) + + assert isinstance(result, HelionKernelWrapper) + assert result.op_name == "test_name" + assert result.raw_kernel_func is test_kernel + + def test_register_kernel_auto_detects_name(self): + """Test register_kernel uses function name when no name provided.""" + + @register_kernel + def my_test_kernel(x): + return x + + assert my_test_kernel.op_name == "my_test_kernel" + + def test_register_kernel_registers_in_global_registry(self): + """Test register_kernel adds wrapper to global registry.""" + + @register_kernel + def test_kernel(x): + return x + + registered_kernels = get_registered_kernels() + assert "test_kernel" in registered_kernels + assert registered_kernels["test_kernel"] is test_kernel + + def test_register_kernel_passes_helion_settings(self): + """Test register_kernel passes helion_settings to wrapper.""" + mock_settings = Mock() + mock_settings.to_dict.return_value = {"debug": True} + + @register_kernel("test_name", helion_settings=mock_settings) + def test_kernel(x): + return x + + assert test_kernel.helion_settings is mock_settings + + def test_register_kernel_supports_decorator_syntax(self): + """Test register_kernel works with decorator arguments.""" + mock_fake = Mock() + + wrapper = register_kernel("custom_name", fake_impl=mock_fake) + + def test_kernel(x): + return x + + result = wrapper(test_kernel) + + assert result.op_name == "custom_name" + assert result._fake_impl is mock_fake + + def test_register_kernel_bare_decorator(self): + """Test register_kernel works as bare decorator.""" + + @register_kernel + def test_kernel(x): + return x + + assert isinstance(test_kernel, HelionKernelWrapper) + assert test_kernel.op_name == "test_kernel" + + def test_registered_wrapper_can_register_config_picker(self): + """Test that registered wrapper can register config picker.""" + + @register_kernel + def test_kernel(x): + return x + + def my_picker(args, config_keys): + return "default" + + result = test_kernel.register_config_picker(my_picker) + + assert result is my_picker + assert test_kernel._config_picker is my_picker + + def test_register_kernel_raises_on_duplicate_registration(self): + """Test register_kernel raises error on duplicate names.""" + + @register_kernel("duplicate_name") + def kernel1(x): + return x + + with pytest.raises(ValueError, match="already registered"): + + @register_kernel("duplicate_name") + def kernel2(x): + return x + + def test_register_kernel_rejects_autotuner_fn_in_settings(self): + """Test register_kernel rejects conflicting autotuner_fn.""" + mock_settings = Mock() + mock_settings.to_dict.return_value = {"autotuner_fn": Mock()} + + with pytest.raises(ValueError, match="uses a custom autotuner"): + + @register_kernel("test", helion_settings=mock_settings) + def test_kernel(x): + return x + + def test_register_kernel_warns_with_static_shapes_true(self): + """Test register_kernel warns when static_shapes=True.""" + mock_settings = Mock() + mock_settings.to_dict.return_value = {"static_shapes": True} + + with patch("vllm.kernels.helion.register.logger") as mock_logger: + + @register_kernel("test", helion_settings=mock_settings) + def test_kernel(x): + return x + + mock_logger.warning.assert_called_once() + assert "static_shapes=True" in mock_logger.warning.call_args[0][0] + + def test_register_kernel_no_warning_with_static_shapes_false(self): + """Test register_kernel doesn't warn with static_shapes=False.""" + mock_settings = Mock() + mock_settings.to_dict.return_value = {"static_shapes": False} + + with patch("vllm.kernels.helion.register.logger") as mock_logger: + + @register_kernel("test", helion_settings=mock_settings) + def test_kernel(x): + return x + + # Should not call warning + mock_logger.warning.assert_not_called() diff --git a/vllm/kernels/helion/__init__.py b/vllm/kernels/helion/__init__.py index 932f7979c..dfbf28b8d 100644 --- a/vllm/kernels/helion/__init__.py +++ b/vllm/kernels/helion/__init__.py @@ -9,6 +9,9 @@ from vllm.kernels.helion.config_manager import ( from vllm.kernels.helion.register import ( ConfiguredHelionKernel, HelionKernelWrapper, + get_kernel_by_name, + get_registered_kernels, + register_kernel, vllm_helion_lib, ) from vllm.kernels.helion.utils import canonicalize_gpu_name, get_canonical_gpu_name @@ -20,6 +23,9 @@ __all__ = [ # Kernel registration "ConfiguredHelionKernel", "HelionKernelWrapper", + "get_kernel_by_name", + "get_registered_kernels", + "register_kernel", "vllm_helion_lib", # Utilities "canonicalize_gpu_name", diff --git a/vllm/kernels/helion/register.py b/vllm/kernels/helion/register.py index 045b2054b..b90110724 100644 --- a/vllm/kernels/helion/register.py +++ b/vllm/kernels/helion/register.py @@ -37,7 +37,7 @@ Key Classes """ from collections.abc import Callable -from typing import Any +from typing import Any, cast, overload import torch from torch.library import Library @@ -114,7 +114,7 @@ class ConfiguredHelionKernel: def __init__( self, op_name: str, - config_picker: Callable[[tuple[Any, ...], list[str]], str | None], + config_picker: Callable[[tuple[Any, ...], list[str]], str | None] | None, raw_kernel_func: Callable, helion_settings: "helion.Settings | None" = None, ): @@ -140,9 +140,16 @@ class ConfiguredHelionKernel: f"Use @{self.op_name}.register_config_picker to register one." ) + # After None check, config_picker is guaranteed to be non-None + assert self.config_picker is not None + def key_computer(*args): config_keys = list(self.configs.keys()) - selected_key = self.config_picker(args, config_keys) + # Cast is safe because we checked for None above + config_picker = cast( + Callable[[tuple[Any, ...], list[str]], str | None], self.config_picker + ) + selected_key = config_picker(args, config_keys) if selected_key: return selected_key return "default" if "default" in self.configs else None @@ -272,3 +279,115 @@ class HelionKernelWrapper: target_lib=vllm_helion_lib, ) return getattr(torch.ops.vllm_helion, self.op_name) + + +# Global registry for tracking all registered HelionKernelWrapper instances +_REGISTERED_KERNELS: dict[str, HelionKernelWrapper] = {} + + +def get_registered_kernels() -> dict[str, HelionKernelWrapper]: + return _REGISTERED_KERNELS.copy() + + +def get_kernel_by_name(kernel_name: str) -> HelionKernelWrapper | None: + return _REGISTERED_KERNELS.get(kernel_name) + + +def infer_fake_impl( + kernel_func: Callable, + helion_settings: "helion.Settings | None" = None, +) -> Callable: + def helion_fake_kernel(*args, **kwargs): + kernel_kwargs = {} + if helion_settings: + kernel_kwargs.update(helion_settings.to_dict()) + + temp_decorated_kernel = helion.kernel(**kernel_kwargs)(kernel_func) + + # Bind with args to get config_spec, then get a valid default config + bound = temp_decorated_kernel.bind(args) + default_config = bound.config_spec.default_config() + compiled_runner = bound.compile_config(default_config) + + return compiled_runner(*args, **kwargs, _launcher=lambda *a, **kw: None) + + return helion_fake_kernel + + +# Overloads are necessary for proper mypy type inference. +# Without overloads, the union return type HelionKernelWrapper | Callable[...] +# causes mypy to complain about missing attributes when tests do: +# wrapper = register_kernel(func) # Should return HelionKernelWrapper +# wrapper._fake_impl # mypy error: "Callable has no attribute _fake_impl" +# The overloads tell mypy the exact return type based on the argument pattern. +@overload +def register_kernel( + op_name_or_func: Callable, + *, + fake_impl: Callable | None = None, + helion_settings: "helion.Settings | None" = None, +) -> HelionKernelWrapper: ... + + +@overload +def register_kernel( + op_name_or_func: str | None = None, + *, + fake_impl: Callable | None = None, + helion_settings: "helion.Settings | None" = None, +) -> Callable[[Callable], HelionKernelWrapper]: ... + + +def register_kernel( + op_name_or_func: str | Callable | None = None, + *, + fake_impl: Callable | None = None, + helion_settings: "helion.Settings | None" = None, +) -> HelionKernelWrapper | Callable[[Callable], HelionKernelWrapper]: + """ + Decorator to register a Helion kernel function as a HelionKernelWrapper. + + Wraps the raw kernel function in a HelionKernelWrapper and registers it + in the global kernel registry. Auto-generates fake_impl if not provided. + """ + + def decorator(kernel_func: Callable) -> HelionKernelWrapper: + op_name = op_name_or_func if isinstance(op_name_or_func, str) else None + final_op_name = op_name if op_name else kernel_func.__name__ + + if final_op_name in _REGISTERED_KERNELS: + raise ValueError( + f"Helion kernel '{final_op_name}' is already registered. " + f"Use a different op_name or check for duplicate registrations." + ) + + final_fake_impl = fake_impl + if final_fake_impl is None: + final_fake_impl = infer_fake_impl(kernel_func, helion_settings) + logger.debug( + "Auto-generated fake_impl for Helion kernel '%s'", + kernel_func.__name__, + ) + + kernel_wrapper = HelionKernelWrapper( + raw_kernel_func=kernel_func, + op_name=final_op_name, + fake_impl=final_fake_impl, + helion_settings=helion_settings, + ) + + _REGISTERED_KERNELS[final_op_name] = kernel_wrapper + + logger.info( + "Registered Helion kernel '%s' as HelionKernelWrapper", + kernel_func.__name__, + ) + + return kernel_wrapper + + if callable(op_name_or_func) and not isinstance(op_name_or_func, str): + # Bare decorator usage: @register_kernel + return decorator(op_name_or_func) + else: + # Decorator with arguments: @register_kernel(...) + return decorator