[Kernel] [Helion] [3/N] Helion kernel registry (#33203)
Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
This commit is contained in:
@@ -27,6 +27,9 @@ from vllm.kernels.helion.config_manager import ConfigManager
|
||||
from vllm.kernels.helion.register import (
|
||||
ConfiguredHelionKernel,
|
||||
HelionKernelWrapper,
|
||||
get_kernel_by_name,
|
||||
get_registered_kernels,
|
||||
register_kernel,
|
||||
validate_helion_settings,
|
||||
)
|
||||
|
||||
@@ -545,3 +548,191 @@ class TestHelionKernelWrapper:
|
||||
assert result is new_op
|
||||
# Check that op_func is the decorated kernel, not ConfiguredHelionKernel
|
||||
assert mock_register.call_args[1]["op_func"] is mock_decorated
|
||||
|
||||
|
||||
class TestKernelRegistry:
|
||||
"""Test suite for kernel registry functionality."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Clear the registry before each test."""
|
||||
from vllm.kernels.helion.register import _REGISTERED_KERNELS
|
||||
|
||||
_REGISTERED_KERNELS.clear()
|
||||
|
||||
def test_get_registered_kernels_returns_copy(self):
|
||||
"""Test get_registered_kernels returns copy of registry."""
|
||||
result1 = get_registered_kernels()
|
||||
result2 = get_registered_kernels()
|
||||
|
||||
# Should be separate objects
|
||||
assert result1 is not result2
|
||||
# Should have same content
|
||||
assert result1 == result2
|
||||
|
||||
def test_get_kernel_by_name_returns_kernel(self):
|
||||
"""Test get_kernel_by_name returns registered kernel."""
|
||||
wrapper = HelionKernelWrapper(
|
||||
raw_kernel_func=Mock(),
|
||||
op_name="test_kernel",
|
||||
fake_impl=Mock(),
|
||||
)
|
||||
|
||||
from vllm.kernels.helion.register import _REGISTERED_KERNELS
|
||||
|
||||
_REGISTERED_KERNELS["test_kernel"] = wrapper
|
||||
|
||||
result = get_kernel_by_name("test_kernel")
|
||||
assert result is wrapper
|
||||
|
||||
def test_get_kernel_by_name_returns_none_for_missing(self):
|
||||
"""Test get_kernel_by_name returns None for missing kernel."""
|
||||
result = get_kernel_by_name("nonexistent")
|
||||
assert result is None
|
||||
|
||||
def test_register_kernel_auto_generates_fake_impl(self):
|
||||
"""Test register_kernel auto-generates fake_impl when not provided."""
|
||||
with patch("vllm.kernels.helion.register.infer_fake_impl") as mock_infer:
|
||||
mock_fake = Mock()
|
||||
mock_infer.return_value = mock_fake
|
||||
|
||||
def original_kernel(x):
|
||||
return x
|
||||
|
||||
wrapper = register_kernel(original_kernel)
|
||||
|
||||
mock_infer.assert_called_once_with(original_kernel, None)
|
||||
assert wrapper._fake_impl is mock_fake
|
||||
|
||||
def test_register_kernel_creates_wrapper(self):
|
||||
"""Test register_kernel creates HelionKernelWrapper."""
|
||||
|
||||
def test_kernel(x):
|
||||
return x
|
||||
|
||||
result = register_kernel("test_name")(test_kernel)
|
||||
|
||||
assert isinstance(result, HelionKernelWrapper)
|
||||
assert result.op_name == "test_name"
|
||||
assert result.raw_kernel_func is test_kernel
|
||||
|
||||
def test_register_kernel_auto_detects_name(self):
|
||||
"""Test register_kernel uses function name when no name provided."""
|
||||
|
||||
@register_kernel
|
||||
def my_test_kernel(x):
|
||||
return x
|
||||
|
||||
assert my_test_kernel.op_name == "my_test_kernel"
|
||||
|
||||
def test_register_kernel_registers_in_global_registry(self):
|
||||
"""Test register_kernel adds wrapper to global registry."""
|
||||
|
||||
@register_kernel
|
||||
def test_kernel(x):
|
||||
return x
|
||||
|
||||
registered_kernels = get_registered_kernels()
|
||||
assert "test_kernel" in registered_kernels
|
||||
assert registered_kernels["test_kernel"] is test_kernel
|
||||
|
||||
def test_register_kernel_passes_helion_settings(self):
|
||||
"""Test register_kernel passes helion_settings to wrapper."""
|
||||
mock_settings = Mock()
|
||||
mock_settings.to_dict.return_value = {"debug": True}
|
||||
|
||||
@register_kernel("test_name", helion_settings=mock_settings)
|
||||
def test_kernel(x):
|
||||
return x
|
||||
|
||||
assert test_kernel.helion_settings is mock_settings
|
||||
|
||||
def test_register_kernel_supports_decorator_syntax(self):
|
||||
"""Test register_kernel works with decorator arguments."""
|
||||
mock_fake = Mock()
|
||||
|
||||
wrapper = register_kernel("custom_name", fake_impl=mock_fake)
|
||||
|
||||
def test_kernel(x):
|
||||
return x
|
||||
|
||||
result = wrapper(test_kernel)
|
||||
|
||||
assert result.op_name == "custom_name"
|
||||
assert result._fake_impl is mock_fake
|
||||
|
||||
def test_register_kernel_bare_decorator(self):
|
||||
"""Test register_kernel works as bare decorator."""
|
||||
|
||||
@register_kernel
|
||||
def test_kernel(x):
|
||||
return x
|
||||
|
||||
assert isinstance(test_kernel, HelionKernelWrapper)
|
||||
assert test_kernel.op_name == "test_kernel"
|
||||
|
||||
def test_registered_wrapper_can_register_config_picker(self):
|
||||
"""Test that registered wrapper can register config picker."""
|
||||
|
||||
@register_kernel
|
||||
def test_kernel(x):
|
||||
return x
|
||||
|
||||
def my_picker(args, config_keys):
|
||||
return "default"
|
||||
|
||||
result = test_kernel.register_config_picker(my_picker)
|
||||
|
||||
assert result is my_picker
|
||||
assert test_kernel._config_picker is my_picker
|
||||
|
||||
def test_register_kernel_raises_on_duplicate_registration(self):
|
||||
"""Test register_kernel raises error on duplicate names."""
|
||||
|
||||
@register_kernel("duplicate_name")
|
||||
def kernel1(x):
|
||||
return x
|
||||
|
||||
with pytest.raises(ValueError, match="already registered"):
|
||||
|
||||
@register_kernel("duplicate_name")
|
||||
def kernel2(x):
|
||||
return x
|
||||
|
||||
def test_register_kernel_rejects_autotuner_fn_in_settings(self):
|
||||
"""Test register_kernel rejects conflicting autotuner_fn."""
|
||||
mock_settings = Mock()
|
||||
mock_settings.to_dict.return_value = {"autotuner_fn": Mock()}
|
||||
|
||||
with pytest.raises(ValueError, match="uses a custom autotuner"):
|
||||
|
||||
@register_kernel("test", helion_settings=mock_settings)
|
||||
def test_kernel(x):
|
||||
return x
|
||||
|
||||
def test_register_kernel_warns_with_static_shapes_true(self):
|
||||
"""Test register_kernel warns when static_shapes=True."""
|
||||
mock_settings = Mock()
|
||||
mock_settings.to_dict.return_value = {"static_shapes": True}
|
||||
|
||||
with patch("vllm.kernels.helion.register.logger") as mock_logger:
|
||||
|
||||
@register_kernel("test", helion_settings=mock_settings)
|
||||
def test_kernel(x):
|
||||
return x
|
||||
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert "static_shapes=True" in mock_logger.warning.call_args[0][0]
|
||||
|
||||
def test_register_kernel_no_warning_with_static_shapes_false(self):
|
||||
"""Test register_kernel doesn't warn with static_shapes=False."""
|
||||
mock_settings = Mock()
|
||||
mock_settings.to_dict.return_value = {"static_shapes": False}
|
||||
|
||||
with patch("vllm.kernels.helion.register.logger") as mock_logger:
|
||||
|
||||
@register_kernel("test", helion_settings=mock_settings)
|
||||
def test_kernel(x):
|
||||
return x
|
||||
|
||||
# Should not call warning
|
||||
mock_logger.warning.assert_not_called()
|
||||
|
||||
Reference in New Issue
Block a user