[Kernel] [Helion] [3/N] Helion kernel registry (#33203)

Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
This commit is contained in:
Yanan Cao
2026-01-30 23:38:46 -08:00
committed by GitHub
parent 1618e25492
commit d5c41db35b
3 changed files with 319 additions and 3 deletions

View File

@@ -27,6 +27,9 @@ from vllm.kernels.helion.config_manager import ConfigManager
from vllm.kernels.helion.register import (
ConfiguredHelionKernel,
HelionKernelWrapper,
get_kernel_by_name,
get_registered_kernels,
register_kernel,
validate_helion_settings,
)
@@ -545,3 +548,191 @@ class TestHelionKernelWrapper:
assert result is new_op
# Check that op_func is the decorated kernel, not ConfiguredHelionKernel
assert mock_register.call_args[1]["op_func"] is mock_decorated
class TestKernelRegistry:
"""Test suite for kernel registry functionality."""
def setup_method(self):
"""Clear the registry before each test."""
from vllm.kernels.helion.register import _REGISTERED_KERNELS
_REGISTERED_KERNELS.clear()
def test_get_registered_kernels_returns_copy(self):
"""Test get_registered_kernels returns copy of registry."""
result1 = get_registered_kernels()
result2 = get_registered_kernels()
# Should be separate objects
assert result1 is not result2
# Should have same content
assert result1 == result2
def test_get_kernel_by_name_returns_kernel(self):
"""Test get_kernel_by_name returns registered kernel."""
wrapper = HelionKernelWrapper(
raw_kernel_func=Mock(),
op_name="test_kernel",
fake_impl=Mock(),
)
from vllm.kernels.helion.register import _REGISTERED_KERNELS
_REGISTERED_KERNELS["test_kernel"] = wrapper
result = get_kernel_by_name("test_kernel")
assert result is wrapper
def test_get_kernel_by_name_returns_none_for_missing(self):
"""Test get_kernel_by_name returns None for missing kernel."""
result = get_kernel_by_name("nonexistent")
assert result is None
def test_register_kernel_auto_generates_fake_impl(self):
"""Test register_kernel auto-generates fake_impl when not provided."""
with patch("vllm.kernels.helion.register.infer_fake_impl") as mock_infer:
mock_fake = Mock()
mock_infer.return_value = mock_fake
def original_kernel(x):
return x
wrapper = register_kernel(original_kernel)
mock_infer.assert_called_once_with(original_kernel, None)
assert wrapper._fake_impl is mock_fake
def test_register_kernel_creates_wrapper(self):
"""Test register_kernel creates HelionKernelWrapper."""
def test_kernel(x):
return x
result = register_kernel("test_name")(test_kernel)
assert isinstance(result, HelionKernelWrapper)
assert result.op_name == "test_name"
assert result.raw_kernel_func is test_kernel
def test_register_kernel_auto_detects_name(self):
"""Test register_kernel uses function name when no name provided."""
@register_kernel
def my_test_kernel(x):
return x
assert my_test_kernel.op_name == "my_test_kernel"
def test_register_kernel_registers_in_global_registry(self):
"""Test register_kernel adds wrapper to global registry."""
@register_kernel
def test_kernel(x):
return x
registered_kernels = get_registered_kernels()
assert "test_kernel" in registered_kernels
assert registered_kernels["test_kernel"] is test_kernel
def test_register_kernel_passes_helion_settings(self):
"""Test register_kernel passes helion_settings to wrapper."""
mock_settings = Mock()
mock_settings.to_dict.return_value = {"debug": True}
@register_kernel("test_name", helion_settings=mock_settings)
def test_kernel(x):
return x
assert test_kernel.helion_settings is mock_settings
def test_register_kernel_supports_decorator_syntax(self):
"""Test register_kernel works with decorator arguments."""
mock_fake = Mock()
wrapper = register_kernel("custom_name", fake_impl=mock_fake)
def test_kernel(x):
return x
result = wrapper(test_kernel)
assert result.op_name == "custom_name"
assert result._fake_impl is mock_fake
def test_register_kernel_bare_decorator(self):
"""Test register_kernel works as bare decorator."""
@register_kernel
def test_kernel(x):
return x
assert isinstance(test_kernel, HelionKernelWrapper)
assert test_kernel.op_name == "test_kernel"
def test_registered_wrapper_can_register_config_picker(self):
"""Test that registered wrapper can register config picker."""
@register_kernel
def test_kernel(x):
return x
def my_picker(args, config_keys):
return "default"
result = test_kernel.register_config_picker(my_picker)
assert result is my_picker
assert test_kernel._config_picker is my_picker
def test_register_kernel_raises_on_duplicate_registration(self):
"""Test register_kernel raises error on duplicate names."""
@register_kernel("duplicate_name")
def kernel1(x):
return x
with pytest.raises(ValueError, match="already registered"):
@register_kernel("duplicate_name")
def kernel2(x):
return x
def test_register_kernel_rejects_autotuner_fn_in_settings(self):
"""Test register_kernel rejects conflicting autotuner_fn."""
mock_settings = Mock()
mock_settings.to_dict.return_value = {"autotuner_fn": Mock()}
with pytest.raises(ValueError, match="uses a custom autotuner"):
@register_kernel("test", helion_settings=mock_settings)
def test_kernel(x):
return x
def test_register_kernel_warns_with_static_shapes_true(self):
"""Test register_kernel warns when static_shapes=True."""
mock_settings = Mock()
mock_settings.to_dict.return_value = {"static_shapes": True}
with patch("vllm.kernels.helion.register.logger") as mock_logger:
@register_kernel("test", helion_settings=mock_settings)
def test_kernel(x):
return x
mock_logger.warning.assert_called_once()
assert "static_shapes=True" in mock_logger.warning.call_args[0][0]
def test_register_kernel_no_warning_with_static_shapes_false(self):
"""Test register_kernel doesn't warn with static_shapes=False."""
mock_settings = Mock()
mock_settings.to_dict.return_value = {"static_shapes": False}
with patch("vllm.kernels.helion.register.logger") as mock_logger:
@register_kernel("test", helion_settings=mock_settings)
def test_kernel(x):
return x
# Should not call warning
mock_logger.warning.assert_not_called()