diff --git a/tests/kernels/helion/test_register.py b/tests/kernels/helion/test_register.py new file mode 100644 index 000000000..c6cd3f77d --- /dev/null +++ b/tests/kernels/helion/test_register.py @@ -0,0 +1,547 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Unit tests for Helion kernel registration. + +Tests ConfiguredHelionKernel, HelionKernelWrapper, and PresetConfigSearch +including config picker registration, custom autotuner integration, and +PyTorch op registration. +""" + +from unittest.mock import Mock, patch + +import pytest +import torch + +from vllm.utils.import_utils import has_helion + +if not has_helion(): + pytest.skip( + "Helion is not installed. Install with: pip install vllm[helion]", + allow_module_level=True, + ) + +import helion + +from vllm.kernels.helion.config_manager import ConfigManager +from vllm.kernels.helion.register import ( + ConfiguredHelionKernel, + HelionKernelWrapper, + validate_helion_settings, +) + + +@pytest.fixture +def sample_configs(): + """Create real Helion config objects for testing.""" + return { + "hiddensize_4096_batchsize_32": helion.Config( + block_sizes=[128], + num_warps=4, + num_stages=3, + ), + "hiddensize_4096_batchsize_64": helion.Config( + block_sizes=[256], + num_warps=8, + num_stages=4, + ), + "hiddensize_4096_batchsize_128": helion.Config( + block_sizes=[512], + num_warps=16, + num_stages=2, + ), + "default": helion.Config( + block_sizes=[64], + num_warps=2, + num_stages=2, + ), + } + + +@pytest.fixture +def sample_kernel(): + """Create a simple test kernel function.""" + + def test_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Simple test kernel that adds two tensors.""" + return x + y + + return test_kernel + + +@pytest.fixture +def config_manager_with_test_configs(sample_configs): + """Set up ConfigManager with test configs for nvidia_h200 platform.""" + mock_config_manager = Mock(spec=ConfigManager) + mock_config_manager.get_platform_configs = Mock(return_value=sample_configs) + return mock_config_manager + + +@pytest.fixture +def configured_kernel(sample_kernel, sample_configs, config_manager_with_test_configs): + """Create a ConfiguredHelionKernel for testing.""" + + def test_config_picker(args, config_keys): + """Simple config picker that returns default.""" + return "default" + + with ( + patch( + "vllm.kernels.helion.config_manager.ConfigManager.get_instance", + return_value=config_manager_with_test_configs, + ), + patch( + "vllm.kernels.helion.utils.get_canonical_gpu_name", + return_value="nvidia_h200", + ), + patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel, + ): + # Mock just the helion.kernel decorator to avoid actual kernel compilation + mock_decorated = Mock() + mock_kernel.return_value = Mock(return_value=mock_decorated) + + return ConfiguredHelionKernel( + op_name="test_kernel", + config_picker=test_config_picker, + raw_kernel_func=sample_kernel, + helion_settings=None, + ) + + +class TestValidateHelionSettings: + """Test suite for validate_helion_settings utility function.""" + + def test_accepts_none_settings(self): + """Test that None settings are accepted without error.""" + validate_helion_settings(None, "test_kernel") # Should not raise + + def test_accepts_valid_settings(self): + """Test that valid settings without conflicts are accepted.""" + settings = helion.Settings() + settings.static_shapes = False + settings.print_output_code = True + validate_helion_settings(settings, "test_kernel") # Should not raise + + def test_rejects_autotuner_fn(self): + """Test that settings with custom autotuner_fn raise ValueError.""" + settings = helion.Settings() + settings.autotuner_fn = lambda *args: None # Set custom autotuner function + + with pytest.raises(ValueError, match="uses a custom autotuner"): + validate_helion_settings(settings, "test_kernel") + + def test_warns_on_static_shapes_true(self): + """Test that static_shapes=True emits a warning.""" + settings = helion.Settings() + settings.static_shapes = True + + with patch("vllm.kernels.helion.register.logger") as mock_logger: + validate_helion_settings(settings, "test_kernel") + mock_logger.warning.assert_called_once() + assert "static_shapes=True" in mock_logger.warning.call_args[0][0] + + +def create_configured_kernel_with_configs( + op_name, + config_picker, + kernel_func, + configs, + platform="nvidia_h200", + helion_settings=None, +): + """Helper to create ConfiguredHelionKernel with real config objects.""" + mock_config_manager = Mock(spec=ConfigManager) + mock_config_manager.get_platform_configs = Mock(return_value=configs) + + with ( + patch( + "vllm.kernels.helion.config_manager.ConfigManager.get_instance", + return_value=mock_config_manager, + ), + patch( + "vllm.kernels.helion.utils.get_canonical_gpu_name", + return_value=platform, + ), + patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel, + ): + mock_decorated = Mock() + mock_kernel.return_value = Mock(return_value=mock_decorated) + + return ConfiguredHelionKernel( + op_name=op_name, + config_picker=config_picker, + raw_kernel_func=kernel_func, + helion_settings=helion_settings, + ) + + +class TestConfiguredHelionKernel: + """Test suite for ConfiguredHelionKernel.""" + + def test_init_raises_without_picker(self, sample_kernel, sample_configs): + """Test that __init__ raises when no picker registered.""" + configs = {"default": sample_configs["default"]} + mock_config_manager = Mock(spec=ConfigManager) + mock_config_manager.get_platform_configs = Mock(return_value=configs) + + with ( + patch( + "vllm.kernels.helion.config_manager.ConfigManager.get_instance", + return_value=mock_config_manager, + ), + patch( + "vllm.kernels.helion.utils.get_canonical_gpu_name", + return_value="nvidia_h200", + ), + pytest.raises(RuntimeError, match="No config picker registered"), + ): + ConfiguredHelionKernel( + op_name="test_kernel", + config_picker=None, # No picker registered + raw_kernel_func=sample_kernel, + helion_settings=None, + ) + + def test_config_selector_validates_picker_result( + self, sample_kernel, sample_configs + ): + """Test that config selector validates picker returns valid key.""" + + def invalid_picker(args, config_keys): + return "invalid_key" + + kernel = create_configured_kernel_with_configs( + op_name="test_kernel", + config_picker=invalid_picker, + kernel_func=sample_kernel, + configs=sample_configs, + ) + + key_computer = kernel._create_key_computer() + selector = kernel._create_config_selector(key_computer) + + with pytest.raises( + ValueError, match="Config picker returned invalid config key" + ): + selector((torch.randn(32, 4096),)) + + def test_config_selector_handles_none_from_picker( + self, sample_kernel, sample_configs + ): + """Test that config selector falls back to 'default' on None.""" + + def none_picker(args, config_keys): + return None + + kernel = create_configured_kernel_with_configs( + op_name="test_kernel", + config_picker=none_picker, + kernel_func=sample_kernel, + configs=sample_configs, + ) + + key_computer = kernel._create_key_computer() + selector = kernel._create_config_selector(key_computer) + + result = selector((torch.randn(32, 4096),)) + assert result is kernel.configs["default"] + + def test_create_decorated_kernel_passes_helion_settings( + self, sample_kernel, sample_configs + ): + """Test that _create_decorated_kernel passes helion_settings.""" + + def default_picker(args, config_keys): + return "default" + + settings = helion.Settings() + settings.print_output_code = True + # Note: helion.Settings() defaults static_shapes to True + + mock_config_manager = Mock(spec=ConfigManager) + mock_config_manager.get_platform_configs = Mock(return_value=sample_configs) + + with ( + patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel, + patch( + "vllm.kernels.helion.config_manager.ConfigManager.get_instance", + return_value=mock_config_manager, + ), + patch( + "vllm.kernels.helion.utils.get_canonical_gpu_name", + return_value="nvidia_h200", + ), + ): + mock_decorated = Mock() + mock_kernel.return_value = Mock(return_value=mock_decorated) + + ConfiguredHelionKernel( + op_name="test_kernel", + config_picker=default_picker, + raw_kernel_func=sample_kernel, + helion_settings=settings, + ) + + call_kwargs = mock_kernel.call_args[1] + assert "print_output_code" in call_kwargs + assert call_kwargs["print_output_code"] is True + # helion.Settings() defaults to static_shapes=True, so it should remain True + assert call_kwargs["static_shapes"] is True + + def test_create_decorated_kernel_preserves_static_shapes_true( + self, sample_kernel, sample_configs + ): + """Test that explicit static_shapes=True is preserved.""" + + def default_picker(args, config_keys): + return "default" + + settings = helion.Settings() + settings.static_shapes = True + + mock_config_manager = Mock(spec=ConfigManager) + mock_config_manager.get_platform_configs = Mock(return_value=sample_configs) + + with ( + patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel, + patch( + "vllm.kernels.helion.config_manager.ConfigManager.get_instance", + return_value=mock_config_manager, + ), + patch( + "vllm.kernels.helion.utils.get_canonical_gpu_name", + return_value="nvidia_h200", + ), + ): + mock_decorated = Mock() + mock_kernel.return_value = Mock(return_value=mock_decorated) + + ConfiguredHelionKernel( + op_name="test_kernel", + config_picker=default_picker, + raw_kernel_func=sample_kernel, + helion_settings=settings, + ) + + call_kwargs = mock_kernel.call_args[1] + assert call_kwargs["static_shapes"] is True + + def test_key_and_config_selector_use_same_logic( + self, sample_kernel, sample_configs + ): + """Test that key and config_selector produce identical results.""" + + def tracking_picker(args, config_keys): + x = args[0] + batch_size = x.shape[0] + if batch_size <= 32: + return "hiddensize_4096_batchsize_32" + elif batch_size <= 64: + return "hiddensize_4096_batchsize_64" + return "hiddensize_4096_batchsize_128" + + mock_config_manager = Mock(spec=ConfigManager) + mock_config_manager.get_platform_configs = Mock(return_value=sample_configs) + + with ( + patch("vllm.kernels.helion.register.helion.kernel") as mock_helion_kernel, + patch( + "vllm.kernels.helion.config_manager.ConfigManager.get_instance", + return_value=mock_config_manager, + ), + patch( + "vllm.kernels.helion.utils.get_canonical_gpu_name", + return_value="nvidia_h200", + ), + ): + mock_decorated = Mock() + mock_helion_kernel.return_value = Mock(return_value=mock_decorated) + + kernel = ConfiguredHelionKernel( + op_name="test_kernel", + config_picker=tracking_picker, + raw_kernel_func=sample_kernel, + helion_settings=None, + ) + + call_kwargs = mock_helion_kernel.call_args[1] + key_fn = call_kwargs["key"] + autotuner_fn = call_kwargs["autotuner_fn"] + + tensor = torch.randn(50, 4096) # batch=50, should select batchsize_64 + + # key receives unpacked args, autotuner receives args as tuple + key_result = key_fn(tensor) + autotuner = autotuner_fn(None, (tensor,)) + config = autotuner.autotune() + + assert key_result == "hiddensize_4096_batchsize_64" + assert config is kernel.configs["hiddensize_4096_batchsize_64"] + + +class TestHelionKernelWrapper: + """Test suite for HelionKernelWrapper.""" + + def test_get_configured_op_validates_configs_available(self, sample_kernel): + """Test get_configured_op validates configs are available.""" + + def fake_impl(*args, **kwargs): + return torch.zeros_like(args[0]) + + wrapper = HelionKernelWrapper( + raw_kernel_func=sample_kernel, + op_name="test_kernel", + fake_impl=fake_impl, + ) + + def default_picker(args, config_keys): + return "default" + + wrapper._config_picker = default_picker + + mock_config_manager = Mock(spec=ConfigManager) + mock_config_manager.get_platform_configs = Mock( + return_value={} + ) # Empty configs + + with ( + patch( + "vllm.kernels.helion.config_manager.ConfigManager.get_instance", + return_value=mock_config_manager, + ), + patch( + "vllm.kernels.helion.utils.get_canonical_gpu_name", + return_value="nvidia_h200", + ), + pytest.raises(ValueError, match="No configs available"), + ): + wrapper.get_configured_op() + + def test_get_configured_op_validates_config_picker( + self, sample_kernel, sample_configs + ): + """Test get_configured_op validates config picker.""" + + def fake_impl(*args, **kwargs): + return torch.zeros_like(args[0]) + + wrapper = HelionKernelWrapper( + raw_kernel_func=sample_kernel, + op_name="test_kernel", + fake_impl=fake_impl, + ) + # Don't set config picker - should raise assertion error + + mock_config_manager = Mock(spec=ConfigManager) + mock_config_manager.get_platform_configs = Mock(return_value=sample_configs) + + with ( + patch( + "vllm.kernels.helion.config_manager.ConfigManager.get_instance", + return_value=mock_config_manager, + ), + patch( + "vllm.kernels.helion.utils.get_canonical_gpu_name", + return_value="nvidia_h200", + ), + pytest.raises(AssertionError, match="No config picker registered"), + ): + wrapper.get_configured_op() + + def test_get_configured_op_returns_cached_op(self, sample_kernel, sample_configs): + """Test get_configured_op returns cached op when already registered.""" + + def fake_impl(*args, **kwargs): + return torch.zeros_like(args[0]) + + def default_picker(args, config_keys): + return "default" + + wrapper = HelionKernelWrapper( + raw_kernel_func=sample_kernel, + op_name="test_kernel", + fake_impl=fake_impl, + ) + wrapper._config_picker = default_picker + + mock_config_manager = Mock(spec=ConfigManager) + mock_config_manager.get_platform_configs = Mock(return_value=sample_configs) + + existing_op = Mock() + mock_namespace = Mock() + mock_namespace.test_kernel = existing_op + + with ( + patch( + "vllm.kernels.helion.config_manager.ConfigManager.get_instance", + return_value=mock_config_manager, + ), + patch( + "vllm.kernels.helion.utils.get_canonical_gpu_name", + return_value="nvidia_h200", + ), + patch.object(torch.ops, "vllm_helion", mock_namespace), + patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel, + ): + mock_decorated = Mock() + mock_kernel.return_value = Mock(return_value=mock_decorated) + result = wrapper.get_configured_op() + assert result is existing_op + + def test_get_configured_op_registers_new_op(self, sample_kernel, sample_configs): + """Test get_configured_op creates and registers new op.""" + + def fake_impl(*args, **kwargs): + return torch.zeros_like(args[0]) + + def default_picker(args, config_keys): + return "default" + + wrapper = HelionKernelWrapper( + raw_kernel_func=sample_kernel, + op_name="test_kernel", + fake_impl=fake_impl, + ) + wrapper._config_picker = default_picker + + mock_config_manager = Mock(spec=ConfigManager) + mock_config_manager.get_platform_configs = Mock(return_value=sample_configs) + + new_op = Mock() + registered_ops: dict[str, Mock] = {} + + class MockNamespace: + def __getattr__(self, name): + if name in registered_ops: + return registered_ops[name] + raise AttributeError(name) + + mock_namespace = MockNamespace() + + def register_side_effect(op_name, op_func, **kwargs): + registered_ops[op_name] = new_op + + with ( + patch( + "vllm.kernels.helion.config_manager.ConfigManager.get_instance", + return_value=mock_config_manager, + ), + patch( + "vllm.kernels.helion.utils.get_canonical_gpu_name", + return_value="nvidia_h200", + ), + patch.object(torch.ops, "vllm_helion", mock_namespace), + patch( + "vllm.kernels.helion.register.direct_register_custom_op", + side_effect=register_side_effect, + ) as mock_register, + patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel, + ): + mock_decorated = Mock() + mock_kernel.return_value = Mock(return_value=mock_decorated) + result = wrapper.get_configured_op() + + mock_register.assert_called_once() + assert result is new_op + # Check that op_func is the decorated kernel, not ConfiguredHelionKernel + assert mock_register.call_args[1]["op_func"] is mock_decorated diff --git a/tests/kernels/helion/test_utils.py b/tests/kernels/helion/test_utils.py new file mode 100644 index 000000000..807aa4606 --- /dev/null +++ b/tests/kernels/helion/test_utils.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for Helion utility functions.""" + +import pytest + +from vllm.kernels.helion.utils import canonicalize_gpu_name + + +@pytest.mark.parametrize( + "driver_reported_name,expected", + [ + ("NVIDIA H200", "nvidia_h200"), + ("NVIDIA A100-SXM4-80GB", "nvidia_a100_sxm4_80gb"), + ("NVIDIA H100 80GB HBM3", "nvidia_h100_80gb_hbm3"), + ("NVIDIA GeForce RTX 4090", "nvidia_geforce_rtx_4090"), + ("AMD Instinct MI300X", "amd_instinct_mi300x"), + ("Tesla V100-SXM2-32GB", "tesla_v100_sxm2_32gb"), + ], +) +def test_canonicalize_gpu_name(driver_reported_name, expected): + """Test GPU name canonicalization.""" + assert canonicalize_gpu_name(driver_reported_name) == expected + + +@pytest.mark.parametrize("invalid_name", ["", " ", "\t", "\n"]) +def test_canonicalize_gpu_name_rejects_empty(invalid_name): + """Test that empty or whitespace-only names are rejected.""" + with pytest.raises(ValueError, match="cannot be empty"): + canonicalize_gpu_name(invalid_name) diff --git a/vllm/kernels/helion/__init__.py b/vllm/kernels/helion/__init__.py index 68385e5eb..932f7979c 100644 --- a/vllm/kernels/helion/__init__.py +++ b/vllm/kernels/helion/__init__.py @@ -6,8 +6,22 @@ from vllm.kernels.helion.config_manager import ( ConfigManager, ConfigSet, ) +from vllm.kernels.helion.register import ( + ConfiguredHelionKernel, + HelionKernelWrapper, + vllm_helion_lib, +) +from vllm.kernels.helion.utils import canonicalize_gpu_name, get_canonical_gpu_name __all__ = [ + # Config management "ConfigManager", "ConfigSet", + # Kernel registration + "ConfiguredHelionKernel", + "HelionKernelWrapper", + "vllm_helion_lib", + # Utilities + "canonicalize_gpu_name", + "get_canonical_gpu_name", ] diff --git a/vllm/kernels/helion/register.py b/vllm/kernels/helion/register.py new file mode 100644 index 000000000..045b2054b --- /dev/null +++ b/vllm/kernels/helion/register.py @@ -0,0 +1,274 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +vLLM Helion kernel registration with pre-tuned config selection. + +This module leverages Helion's internal config selection infrastructure to use +pre-tuned configs instead of runtime autotuning. + +How Helion Normally Works +------------------------- +For each kernel invocation, Helion: +1. Computes a cache key from input arguments +2. Looks up the key in its internal compilation cache +3. On cache miss, runs autotuning to find the best config +4. Compiles and caches the kernel with that config + +How We Override It +------------------ +We override two Helion hooks to use pre-tuned configs: + +1. **key**: We provide a key function (derived from config_picker) that + computes cache keys matching our pre-tuned config keys. This ensures Helion's + internal cache uses keys that correspond to configs we've prepared. + +2. **autotuner_fn**: We provide PresetConfigSearch which, instead of autotuning, + simply returns the pre-tuned config for the computed key. On cache miss, + Helion calls our autotuner which returns the author-prepared config. + +Both hooks use the same config_picker logic to ensure the cache key computed +by key matches the config returned by the autotuner. + +Key Classes +----------- +- HelionKernelWrapper: Wraps raw kernel + config_picker, creates configured ops +- ConfiguredHelionKernel: Platform-specific kernel registered as PyTorch custom op +- PresetConfigSearch: Custom autotuner that returns pre-tuned configs +""" + +from collections.abc import Callable +from typing import Any + +import torch +from torch.library import Library + +from vllm.logger import init_logger +from vllm.utils.import_utils import has_helion +from vllm.utils.torch_utils import direct_register_custom_op + +if not has_helion(): + raise ImportError( + "register module requires helion to be installed. " + "Install it with: pip install helion" + ) + +import helion +from helion.autotuner.base_search import BaseAutotuner +from helion.runtime.config import Config +from helion.runtime.settings import default_autotuner_fn + +logger = init_logger(__name__) + +vllm_helion_lib = Library("vllm_helion", "FRAGMENT") # noqa + + +def validate_helion_settings( + helion_settings: "helion.Settings | None", op_name: str +) -> None: + """Validate that helion_settings doesn't contain conflicting options.""" + if helion_settings is None: + return + + settings_dict = helion_settings.to_dict() + + if ( + "autotuner_fn" in settings_dict + and settings_dict["autotuner_fn"] is not None + and settings_dict["autotuner_fn"] is not default_autotuner_fn + ): + raise ValueError( + f"HelionKernelWrapper for '{op_name}' uses a custom autotuner via " + f"config picker. Remove 'autotuner_fn' from helion_settings and use " + f"@{op_name}.register_config_picker instead." + ) + + # Warn if static_shapes is explicitly set to True since most vLLM ops need + # dynamic shapes for variable batch sizes and sequence lengths + if settings_dict.get("static_shapes") is True: + logger.warning( + "Kernel '%s' has static_shapes=True in helion_settings. " + "Most vLLM ops require dynamic shapes for variable batch sizes " + "and sequence lengths. Consider removing this setting.", + op_name, + ) + + +class PresetConfigSearch(BaseAutotuner): + """Custom autotuner that uses a preset config selector instead of autotuning.""" + + def __init__( + self, + args: tuple[Any, ...], + config_selector: Callable[[tuple[Any, ...]], Config], + ): + self.args = args + self.config_selector = config_selector + + def autotune(self, *, skip_cache: bool = False) -> Config: + return self.config_selector(self.args) + + +class ConfiguredHelionKernel: + """A configured Helion kernel bound to a specific platform.""" + + def __init__( + self, + op_name: str, + config_picker: Callable[[tuple[Any, ...], list[str]], str | None], + raw_kernel_func: Callable, + helion_settings: "helion.Settings | None" = None, + ): + self.op_name = op_name + self.config_picker = config_picker + self.raw_kernel_func = raw_kernel_func + self.helion_settings = helion_settings + self._decorated_kernel = self._create_decorated_kernel() + + def __call__(self, *args, **kwargs): + return self._decorated_kernel(*args, **kwargs) + + def _create_key_computer(self): + """ + Create a key computer function derived from the config picker. + + The returned function receives kernel arguments unpacked (*args) to match + Helion's key signature (called as self._key_fn(*args)). + """ + if self.config_picker is None: + raise RuntimeError( + f"No config picker registered for kernel '{self.op_name}'. " + f"Use @{self.op_name}.register_config_picker to register one." + ) + + def key_computer(*args): + config_keys = list(self.configs.keys()) + selected_key = self.config_picker(args, config_keys) + if selected_key: + return selected_key + return "default" if "default" in self.configs else None + + return key_computer + + def _create_config_selector(self, key_computer): + def config_selector(args): + # args is a tuple; key_computer expects unpacked args + selected_config_key = key_computer(*args) + + if selected_config_key is None: + raise ValueError( + f"Config picker returned None for kernel '{self.op_name}' " + f"with available config keys: {list(self.configs.keys())}" + ) + + if selected_config_key not in self.configs: + raise ValueError( + f"Config picker returned invalid config key " + f"'{selected_config_key}' for kernel '{self.op_name}'. " + f"Available keys: {list(self.configs.keys())}" + ) + + return self.configs[selected_config_key] + + return config_selector + + def _load_platform_configs(self) -> None: + from vllm.kernels.helion.config_manager import ConfigManager + from vllm.kernels.helion.utils import get_canonical_gpu_name + + self.platform = get_canonical_gpu_name() + config_manager = ConfigManager.get_instance() + self.configs = config_manager.get_platform_configs(self.op_name, self.platform) + + if not self.configs: + raise ValueError( + f"No configs available for kernel '{self.op_name}' " + f"on platform '{self.platform}'" + ) + + def _create_decorated_kernel(self) -> Callable[..., Any]: + self._load_platform_configs() + + key_computer = self._create_key_computer() + config_selector = self._create_config_selector(key_computer) + + kernel_kwargs = {} + if self.helion_settings: + kernel_kwargs.update(self.helion_settings.to_dict()) + + # Set static_shapes=False by default if user didn't explicitly set it to True + # This is needed for dynamic batch sizes and sequence lengths in vLLM + if kernel_kwargs.get("static_shapes") is not True: + kernel_kwargs["static_shapes"] = False + + kernel_kwargs["autotuner_fn"] = lambda _, args: PresetConfigSearch( + args, config_selector + ) + kernel_kwargs["key"] = key_computer + + logger.debug( + "Creating decorated kernel %s with custom autotuner on platform %s", + self.op_name, + self.platform, + ) + return helion.kernel(**kernel_kwargs)(self.raw_kernel_func) + + +class HelionKernelWrapper: + """Wrapper for Helion kernels that creates config-specific PyTorch custom ops.""" + + def __init__( + self, + raw_kernel_func: Callable, + op_name: str, + fake_impl: Callable, + helion_settings: "helion.Settings | None" = None, + ): + # Validate helion_settings doesn't conflict with our custom autotuner + validate_helion_settings(helion_settings, op_name) + + self.raw_kernel_func = raw_kernel_func + self.op_name = op_name + self._fake_impl = fake_impl + self.helion_settings = helion_settings + self._config_picker: ( + Callable[[tuple[Any, ...], list[str]], str | None] | None + ) = None + + def __call__(self, *args, **kwargs): + configured_op = self.get_configured_op() + return configured_op(*args, **kwargs) + + def register_config_picker( + self, picker_func: Callable[[tuple[Any, ...], list[str]], str | None] + ) -> Callable[[tuple[Any, ...], list[str]], str | None]: + self._config_picker = picker_func + return picker_func + + def get_configured_op(self) -> Any: + assert self._config_picker is not None, ( + f"No config picker registered for kernel '{self.op_name}'. " + f"Use @{self.op_name}.register_config_picker to register one." + ) + + if hasattr(torch.ops.vllm_helion, self.op_name): + logger.debug("Op vllm_helion::%s already registered", self.op_name) + return getattr(torch.ops.vllm_helion, self.op_name) + + configured_kernel = ConfiguredHelionKernel( + op_name=self.op_name, + config_picker=self._config_picker, + raw_kernel_func=self.raw_kernel_func, + helion_settings=self.helion_settings, + ) + + logger.info("Registering op: vllm_helion::%s", self.op_name) + direct_register_custom_op( + op_name=self.op_name, + op_func=configured_kernel._decorated_kernel, # Register decorated kernel + # TODO(gmagogsfm): Implement automatic mutation/aliasing detection + # for Helion kernels. + mutates_args=None, + fake_impl=self._fake_impl, + target_lib=vllm_helion_lib, + ) + return getattr(torch.ops.vllm_helion, self.op_name) diff --git a/vllm/kernels/helion/utils.py b/vllm/kernels/helion/utils.py new file mode 100644 index 000000000..65e327a82 --- /dev/null +++ b/vllm/kernels/helion/utils.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Utility functions for Helion kernel management.""" + +import torch + + +def get_gpu_name(device_id: int | None = None) -> str: + if device_id is None: + device_id = torch.cuda.current_device() + props = torch.cuda.get_device_properties(device_id) + return props.name + + +def canonicalize_gpu_name(name: str) -> str: + """ + Canonicalize GPU name for use as a platform identifier. + + Converts to lowercase and replaces spaces and hyphens with underscores. + e.g., "NVIDIA A100-SXM4-80GB" -> "nvidia_a100_sxm4_80gb" + + Raises ValueError if name is empty. + """ + if not name or not name.strip(): + raise ValueError("GPU name cannot be empty") + name = name.lower() + name = name.replace(" ", "_") + name = name.replace("-", "_") + return name + + +def get_canonical_gpu_name(device_id: int | None = None) -> str: + return canonicalize_gpu_name(get_gpu_name(device_id))