[Kernel] [Helion] [3/N] Helion kernel registry (#33203)
Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
This commit is contained in:
@@ -27,6 +27,9 @@ from vllm.kernels.helion.config_manager import ConfigManager
|
|||||||
from vllm.kernels.helion.register import (
|
from vllm.kernels.helion.register import (
|
||||||
ConfiguredHelionKernel,
|
ConfiguredHelionKernel,
|
||||||
HelionKernelWrapper,
|
HelionKernelWrapper,
|
||||||
|
get_kernel_by_name,
|
||||||
|
get_registered_kernels,
|
||||||
|
register_kernel,
|
||||||
validate_helion_settings,
|
validate_helion_settings,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -545,3 +548,191 @@ class TestHelionKernelWrapper:
|
|||||||
assert result is new_op
|
assert result is new_op
|
||||||
# Check that op_func is the decorated kernel, not ConfiguredHelionKernel
|
# Check that op_func is the decorated kernel, not ConfiguredHelionKernel
|
||||||
assert mock_register.call_args[1]["op_func"] is mock_decorated
|
assert mock_register.call_args[1]["op_func"] is mock_decorated
|
||||||
|
|
||||||
|
|
||||||
|
class TestKernelRegistry:
|
||||||
|
"""Test suite for kernel registry functionality."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Clear the registry before each test."""
|
||||||
|
from vllm.kernels.helion.register import _REGISTERED_KERNELS
|
||||||
|
|
||||||
|
_REGISTERED_KERNELS.clear()
|
||||||
|
|
||||||
|
def test_get_registered_kernels_returns_copy(self):
|
||||||
|
"""Test get_registered_kernels returns copy of registry."""
|
||||||
|
result1 = get_registered_kernels()
|
||||||
|
result2 = get_registered_kernels()
|
||||||
|
|
||||||
|
# Should be separate objects
|
||||||
|
assert result1 is not result2
|
||||||
|
# Should have same content
|
||||||
|
assert result1 == result2
|
||||||
|
|
||||||
|
def test_get_kernel_by_name_returns_kernel(self):
|
||||||
|
"""Test get_kernel_by_name returns registered kernel."""
|
||||||
|
wrapper = HelionKernelWrapper(
|
||||||
|
raw_kernel_func=Mock(),
|
||||||
|
op_name="test_kernel",
|
||||||
|
fake_impl=Mock(),
|
||||||
|
)
|
||||||
|
|
||||||
|
from vllm.kernels.helion.register import _REGISTERED_KERNELS
|
||||||
|
|
||||||
|
_REGISTERED_KERNELS["test_kernel"] = wrapper
|
||||||
|
|
||||||
|
result = get_kernel_by_name("test_kernel")
|
||||||
|
assert result is wrapper
|
||||||
|
|
||||||
|
def test_get_kernel_by_name_returns_none_for_missing(self):
|
||||||
|
"""Test get_kernel_by_name returns None for missing kernel."""
|
||||||
|
result = get_kernel_by_name("nonexistent")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_register_kernel_auto_generates_fake_impl(self):
|
||||||
|
"""Test register_kernel auto-generates fake_impl when not provided."""
|
||||||
|
with patch("vllm.kernels.helion.register.infer_fake_impl") as mock_infer:
|
||||||
|
mock_fake = Mock()
|
||||||
|
mock_infer.return_value = mock_fake
|
||||||
|
|
||||||
|
def original_kernel(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
wrapper = register_kernel(original_kernel)
|
||||||
|
|
||||||
|
mock_infer.assert_called_once_with(original_kernel, None)
|
||||||
|
assert wrapper._fake_impl is mock_fake
|
||||||
|
|
||||||
|
def test_register_kernel_creates_wrapper(self):
|
||||||
|
"""Test register_kernel creates HelionKernelWrapper."""
|
||||||
|
|
||||||
|
def test_kernel(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
result = register_kernel("test_name")(test_kernel)
|
||||||
|
|
||||||
|
assert isinstance(result, HelionKernelWrapper)
|
||||||
|
assert result.op_name == "test_name"
|
||||||
|
assert result.raw_kernel_func is test_kernel
|
||||||
|
|
||||||
|
def test_register_kernel_auto_detects_name(self):
|
||||||
|
"""Test register_kernel uses function name when no name provided."""
|
||||||
|
|
||||||
|
@register_kernel
|
||||||
|
def my_test_kernel(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
assert my_test_kernel.op_name == "my_test_kernel"
|
||||||
|
|
||||||
|
def test_register_kernel_registers_in_global_registry(self):
|
||||||
|
"""Test register_kernel adds wrapper to global registry."""
|
||||||
|
|
||||||
|
@register_kernel
|
||||||
|
def test_kernel(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
registered_kernels = get_registered_kernels()
|
||||||
|
assert "test_kernel" in registered_kernels
|
||||||
|
assert registered_kernels["test_kernel"] is test_kernel
|
||||||
|
|
||||||
|
def test_register_kernel_passes_helion_settings(self):
|
||||||
|
"""Test register_kernel passes helion_settings to wrapper."""
|
||||||
|
mock_settings = Mock()
|
||||||
|
mock_settings.to_dict.return_value = {"debug": True}
|
||||||
|
|
||||||
|
@register_kernel("test_name", helion_settings=mock_settings)
|
||||||
|
def test_kernel(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
assert test_kernel.helion_settings is mock_settings
|
||||||
|
|
||||||
|
def test_register_kernel_supports_decorator_syntax(self):
|
||||||
|
"""Test register_kernel works with decorator arguments."""
|
||||||
|
mock_fake = Mock()
|
||||||
|
|
||||||
|
wrapper = register_kernel("custom_name", fake_impl=mock_fake)
|
||||||
|
|
||||||
|
def test_kernel(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
result = wrapper(test_kernel)
|
||||||
|
|
||||||
|
assert result.op_name == "custom_name"
|
||||||
|
assert result._fake_impl is mock_fake
|
||||||
|
|
||||||
|
def test_register_kernel_bare_decorator(self):
|
||||||
|
"""Test register_kernel works as bare decorator."""
|
||||||
|
|
||||||
|
@register_kernel
|
||||||
|
def test_kernel(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
assert isinstance(test_kernel, HelionKernelWrapper)
|
||||||
|
assert test_kernel.op_name == "test_kernel"
|
||||||
|
|
||||||
|
def test_registered_wrapper_can_register_config_picker(self):
|
||||||
|
"""Test that registered wrapper can register config picker."""
|
||||||
|
|
||||||
|
@register_kernel
|
||||||
|
def test_kernel(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def my_picker(args, config_keys):
|
||||||
|
return "default"
|
||||||
|
|
||||||
|
result = test_kernel.register_config_picker(my_picker)
|
||||||
|
|
||||||
|
assert result is my_picker
|
||||||
|
assert test_kernel._config_picker is my_picker
|
||||||
|
|
||||||
|
def test_register_kernel_raises_on_duplicate_registration(self):
|
||||||
|
"""Test register_kernel raises error on duplicate names."""
|
||||||
|
|
||||||
|
@register_kernel("duplicate_name")
|
||||||
|
def kernel1(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="already registered"):
|
||||||
|
|
||||||
|
@register_kernel("duplicate_name")
|
||||||
|
def kernel2(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def test_register_kernel_rejects_autotuner_fn_in_settings(self):
|
||||||
|
"""Test register_kernel rejects conflicting autotuner_fn."""
|
||||||
|
mock_settings = Mock()
|
||||||
|
mock_settings.to_dict.return_value = {"autotuner_fn": Mock()}
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="uses a custom autotuner"):
|
||||||
|
|
||||||
|
@register_kernel("test", helion_settings=mock_settings)
|
||||||
|
def test_kernel(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def test_register_kernel_warns_with_static_shapes_true(self):
|
||||||
|
"""Test register_kernel warns when static_shapes=True."""
|
||||||
|
mock_settings = Mock()
|
||||||
|
mock_settings.to_dict.return_value = {"static_shapes": True}
|
||||||
|
|
||||||
|
with patch("vllm.kernels.helion.register.logger") as mock_logger:
|
||||||
|
|
||||||
|
@register_kernel("test", helion_settings=mock_settings)
|
||||||
|
def test_kernel(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
mock_logger.warning.assert_called_once()
|
||||||
|
assert "static_shapes=True" in mock_logger.warning.call_args[0][0]
|
||||||
|
|
||||||
|
def test_register_kernel_no_warning_with_static_shapes_false(self):
|
||||||
|
"""Test register_kernel doesn't warn with static_shapes=False."""
|
||||||
|
mock_settings = Mock()
|
||||||
|
mock_settings.to_dict.return_value = {"static_shapes": False}
|
||||||
|
|
||||||
|
with patch("vllm.kernels.helion.register.logger") as mock_logger:
|
||||||
|
|
||||||
|
@register_kernel("test", helion_settings=mock_settings)
|
||||||
|
def test_kernel(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
# Should not call warning
|
||||||
|
mock_logger.warning.assert_not_called()
|
||||||
|
|||||||
@@ -9,6 +9,9 @@ from vllm.kernels.helion.config_manager import (
|
|||||||
from vllm.kernels.helion.register import (
|
from vllm.kernels.helion.register import (
|
||||||
ConfiguredHelionKernel,
|
ConfiguredHelionKernel,
|
||||||
HelionKernelWrapper,
|
HelionKernelWrapper,
|
||||||
|
get_kernel_by_name,
|
||||||
|
get_registered_kernels,
|
||||||
|
register_kernel,
|
||||||
vllm_helion_lib,
|
vllm_helion_lib,
|
||||||
)
|
)
|
||||||
from vllm.kernels.helion.utils import canonicalize_gpu_name, get_canonical_gpu_name
|
from vllm.kernels.helion.utils import canonicalize_gpu_name, get_canonical_gpu_name
|
||||||
@@ -20,6 +23,9 @@ __all__ = [
|
|||||||
# Kernel registration
|
# Kernel registration
|
||||||
"ConfiguredHelionKernel",
|
"ConfiguredHelionKernel",
|
||||||
"HelionKernelWrapper",
|
"HelionKernelWrapper",
|
||||||
|
"get_kernel_by_name",
|
||||||
|
"get_registered_kernels",
|
||||||
|
"register_kernel",
|
||||||
"vllm_helion_lib",
|
"vllm_helion_lib",
|
||||||
# Utilities
|
# Utilities
|
||||||
"canonicalize_gpu_name",
|
"canonicalize_gpu_name",
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ Key Classes
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import Any, cast, overload
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.library import Library
|
from torch.library import Library
|
||||||
@@ -114,7 +114,7 @@ class ConfiguredHelionKernel:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
op_name: str,
|
op_name: str,
|
||||||
config_picker: Callable[[tuple[Any, ...], list[str]], str | None],
|
config_picker: Callable[[tuple[Any, ...], list[str]], str | None] | None,
|
||||||
raw_kernel_func: Callable,
|
raw_kernel_func: Callable,
|
||||||
helion_settings: "helion.Settings | None" = None,
|
helion_settings: "helion.Settings | None" = None,
|
||||||
):
|
):
|
||||||
@@ -140,9 +140,16 @@ class ConfiguredHelionKernel:
|
|||||||
f"Use @{self.op_name}.register_config_picker to register one."
|
f"Use @{self.op_name}.register_config_picker to register one."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# After None check, config_picker is guaranteed to be non-None
|
||||||
|
assert self.config_picker is not None
|
||||||
|
|
||||||
def key_computer(*args):
|
def key_computer(*args):
|
||||||
config_keys = list(self.configs.keys())
|
config_keys = list(self.configs.keys())
|
||||||
selected_key = self.config_picker(args, config_keys)
|
# Cast is safe because we checked for None above
|
||||||
|
config_picker = cast(
|
||||||
|
Callable[[tuple[Any, ...], list[str]], str | None], self.config_picker
|
||||||
|
)
|
||||||
|
selected_key = config_picker(args, config_keys)
|
||||||
if selected_key:
|
if selected_key:
|
||||||
return selected_key
|
return selected_key
|
||||||
return "default" if "default" in self.configs else None
|
return "default" if "default" in self.configs else None
|
||||||
@@ -272,3 +279,115 @@ class HelionKernelWrapper:
|
|||||||
target_lib=vllm_helion_lib,
|
target_lib=vllm_helion_lib,
|
||||||
)
|
)
|
||||||
return getattr(torch.ops.vllm_helion, self.op_name)
|
return getattr(torch.ops.vllm_helion, self.op_name)
|
||||||
|
|
||||||
|
|
||||||
|
# Global registry for tracking all registered HelionKernelWrapper instances
|
||||||
|
_REGISTERED_KERNELS: dict[str, HelionKernelWrapper] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_registered_kernels() -> dict[str, HelionKernelWrapper]:
|
||||||
|
return _REGISTERED_KERNELS.copy()
|
||||||
|
|
||||||
|
|
||||||
|
def get_kernel_by_name(kernel_name: str) -> HelionKernelWrapper | None:
|
||||||
|
return _REGISTERED_KERNELS.get(kernel_name)
|
||||||
|
|
||||||
|
|
||||||
|
def infer_fake_impl(
|
||||||
|
kernel_func: Callable,
|
||||||
|
helion_settings: "helion.Settings | None" = None,
|
||||||
|
) -> Callable:
|
||||||
|
def helion_fake_kernel(*args, **kwargs):
|
||||||
|
kernel_kwargs = {}
|
||||||
|
if helion_settings:
|
||||||
|
kernel_kwargs.update(helion_settings.to_dict())
|
||||||
|
|
||||||
|
temp_decorated_kernel = helion.kernel(**kernel_kwargs)(kernel_func)
|
||||||
|
|
||||||
|
# Bind with args to get config_spec, then get a valid default config
|
||||||
|
bound = temp_decorated_kernel.bind(args)
|
||||||
|
default_config = bound.config_spec.default_config()
|
||||||
|
compiled_runner = bound.compile_config(default_config)
|
||||||
|
|
||||||
|
return compiled_runner(*args, **kwargs, _launcher=lambda *a, **kw: None)
|
||||||
|
|
||||||
|
return helion_fake_kernel
|
||||||
|
|
||||||
|
|
||||||
|
# Overloads are necessary for proper mypy type inference.
|
||||||
|
# Without overloads, the union return type HelionKernelWrapper | Callable[...]
|
||||||
|
# causes mypy to complain about missing attributes when tests do:
|
||||||
|
# wrapper = register_kernel(func) # Should return HelionKernelWrapper
|
||||||
|
# wrapper._fake_impl # mypy error: "Callable has no attribute _fake_impl"
|
||||||
|
# The overloads tell mypy the exact return type based on the argument pattern.
|
||||||
|
@overload
|
||||||
|
def register_kernel(
|
||||||
|
op_name_or_func: Callable,
|
||||||
|
*,
|
||||||
|
fake_impl: Callable | None = None,
|
||||||
|
helion_settings: "helion.Settings | None" = None,
|
||||||
|
) -> HelionKernelWrapper: ...
|
||||||
|
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def register_kernel(
|
||||||
|
op_name_or_func: str | None = None,
|
||||||
|
*,
|
||||||
|
fake_impl: Callable | None = None,
|
||||||
|
helion_settings: "helion.Settings | None" = None,
|
||||||
|
) -> Callable[[Callable], HelionKernelWrapper]: ...
|
||||||
|
|
||||||
|
|
||||||
|
def register_kernel(
|
||||||
|
op_name_or_func: str | Callable | None = None,
|
||||||
|
*,
|
||||||
|
fake_impl: Callable | None = None,
|
||||||
|
helion_settings: "helion.Settings | None" = None,
|
||||||
|
) -> HelionKernelWrapper | Callable[[Callable], HelionKernelWrapper]:
|
||||||
|
"""
|
||||||
|
Decorator to register a Helion kernel function as a HelionKernelWrapper.
|
||||||
|
|
||||||
|
Wraps the raw kernel function in a HelionKernelWrapper and registers it
|
||||||
|
in the global kernel registry. Auto-generates fake_impl if not provided.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(kernel_func: Callable) -> HelionKernelWrapper:
|
||||||
|
op_name = op_name_or_func if isinstance(op_name_or_func, str) else None
|
||||||
|
final_op_name = op_name if op_name else kernel_func.__name__
|
||||||
|
|
||||||
|
if final_op_name in _REGISTERED_KERNELS:
|
||||||
|
raise ValueError(
|
||||||
|
f"Helion kernel '{final_op_name}' is already registered. "
|
||||||
|
f"Use a different op_name or check for duplicate registrations."
|
||||||
|
)
|
||||||
|
|
||||||
|
final_fake_impl = fake_impl
|
||||||
|
if final_fake_impl is None:
|
||||||
|
final_fake_impl = infer_fake_impl(kernel_func, helion_settings)
|
||||||
|
logger.debug(
|
||||||
|
"Auto-generated fake_impl for Helion kernel '%s'",
|
||||||
|
kernel_func.__name__,
|
||||||
|
)
|
||||||
|
|
||||||
|
kernel_wrapper = HelionKernelWrapper(
|
||||||
|
raw_kernel_func=kernel_func,
|
||||||
|
op_name=final_op_name,
|
||||||
|
fake_impl=final_fake_impl,
|
||||||
|
helion_settings=helion_settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
_REGISTERED_KERNELS[final_op_name] = kernel_wrapper
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Registered Helion kernel '%s' as HelionKernelWrapper",
|
||||||
|
kernel_func.__name__,
|
||||||
|
)
|
||||||
|
|
||||||
|
return kernel_wrapper
|
||||||
|
|
||||||
|
if callable(op_name_or_func) and not isinstance(op_name_or_func, str):
|
||||||
|
# Bare decorator usage: @register_kernel
|
||||||
|
return decorator(op_name_or_func)
|
||||||
|
else:
|
||||||
|
# Decorator with arguments: @register_kernel(...)
|
||||||
|
return decorator
|
||||||
|
|||||||
Reference in New Issue
Block a user