diff --git a/tests/kernels/helion/test_config_manager.py b/tests/kernels/helion/test_config_manager.py new file mode 100644 index 000000000..d95909c92 --- /dev/null +++ b/tests/kernels/helion/test_config_manager.py @@ -0,0 +1,361 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Unit tests for Helion ConfigManager and ConfigSet. + +Tests the simplified configuration management system for Helion custom kernels. +""" + +import json +import tempfile +from pathlib import Path + +import pytest + +from vllm.utils.import_utils import has_helion + +# Skip entire module if helion is not available +if not has_helion(): + pytest.skip( + "Helion is not installed. Install with: pip install vllm[helion]", + allow_module_level=True, + ) + +import helion + +from vllm.kernels.helion.config_manager import ( + ConfigManager, + ConfigSet, +) + + +@pytest.fixture(autouse=True) +def reset_config_manager_singleton(): + """Reset ConfigManager singleton before each test.""" + ConfigManager.reset_instance() + yield + ConfigManager.reset_instance() + + +class TestConfigSet: + """Test suite for ConfigSet class.""" + + def test_config_set_creation(self): + """Test creating an empty ConfigSet.""" + config_set = ConfigSet("test_kernel") + + assert config_set.kernel_name == "test_kernel" + assert config_set.get_platforms() == [] + + def test_config_set_from_dict(self): + """Test creating ConfigSet from dictionary data.""" + # Use realistic config data that helion.Config can handle + config_data = { + "block_sizes": [32, 16], + "num_warps": 4, + "num_stages": 3, + "pid_type": "persistent_interleaved", + } + data = {"h100": {"batch_32_hidden_4096": config_data}} + + config_set = ConfigSet.from_dict("test_kernel", data) + + assert config_set.kernel_name == "test_kernel" + assert config_set.get_platforms() == ["h100"] + + # Verify the config was created correctly + config = config_set.get_config("h100", "batch_32_hidden_4096") + assert isinstance(config, helion.Config) + assert config.block_sizes == [32, 16] + assert config.num_warps == 4 + assert config.num_stages == 3 + assert config.pid_type == "persistent_interleaved" + + def test_config_set_get_config_keyerror(self): + """Test that accessing non-existent configs raises informative KeyErrors.""" + config_set = ConfigSet("test_kernel") + + with pytest.raises(KeyError, match="platform 'h100' not found"): + config_set.get_config("h100", "batch_32_hidden_4096") + + # Use realistic config data + config_data = {"num_warps": 8, "num_stages": 4} + data = {"h100": {"batch_64_hidden_2048": config_data}} + config_set = ConfigSet.from_dict("test_kernel", data) + + with pytest.raises( + KeyError, match="config_key 'batch_32_hidden_4096' not found" + ): + config_set.get_config("h100", "batch_32_hidden_4096") + + def test_config_set_get_platforms(self): + """Test get_platforms method.""" + # Use realistic config data + config1 = {"num_warps": 4, "num_stages": 3} + config2 = {"num_warps": 8, "num_stages": 5} + + data = { + "h100": {"batch_32_hidden_4096": config1}, + "a100": {"batch_16_hidden_2048": config2}, + } + config_set = ConfigSet.from_dict("test_kernel", data) + + platforms = config_set.get_platforms() + assert platforms == ["a100", "h100"] # Should be sorted + + def test_config_set_get_config_keys(self): + """Test get_config_keys method.""" + # Use realistic config data + config1 = {"num_warps": 4, "num_stages": 3} + config2 = {"num_warps": 8, "num_stages": 5} + + data = { + "h100": { + "batch_32_hidden_4096": config1, + "batch_64_hidden_2048": config2, + } + } + config_set = ConfigSet.from_dict("test_kernel", data) + + config_keys = config_set.get_config_keys("h100") + assert config_keys == ["batch_32_hidden_4096", "batch_64_hidden_2048"] + + assert config_set.get_config_keys("v100") == [] + + def test_config_set_to_dict(self): + """Test converting ConfigSet to dictionary.""" + # Use realistic config data + original_config = { + "block_sizes": [64, 32], + "num_warps": 16, + "num_stages": 4, + "pid_type": "persistent_blocked", + } + original_data = {"h100": {"batch_32_hidden_4096": original_config}} + + config_set = ConfigSet.from_dict("test_kernel", original_data) + result_data = config_set.to_dict() + + # The result should match the original (Config roundtrip should work) + assert result_data == original_data + + +class TestConfigManager: + """Test suite for ConfigManager class.""" + + def test_config_manager_creation_default_base_dir(self): + """Test creating ConfigManager with default base directory.""" + manager = ConfigManager() + assert manager._base_dir.name == "configs" + + def test_config_manager_creation_custom_base_dir(self): + """Test creating ConfigManager with custom base directory.""" + custom_dir = "/tmp/custom_configs" + manager = ConfigManager(base_dir=custom_dir) + + # Paths are resolved, so compare with resolved path + assert manager._base_dir == Path(custom_dir).resolve() + + def test_get_config_file_path(self): + """Test getting config file path for a kernel.""" + manager = ConfigManager(base_dir="/tmp") + + file_path = manager.get_config_file_path("silu_mul_fp8") + + expected_path = Path("/tmp/silu_mul_fp8.json") + assert file_path == expected_path + + def test_ensure_base_dir_exists(self): + """Test ensuring base directory exists.""" + with tempfile.TemporaryDirectory() as temp_dir: + base_dir = Path(temp_dir) / "non_existent" / "configs" + manager = ConfigManager(base_dir=base_dir) + assert not base_dir.exists() + + returned_path = manager.ensure_base_dir_exists() + + assert base_dir.exists() + assert base_dir.is_dir() + assert returned_path == base_dir + + def test_load_config_set_file_not_exists(self): + """Test loading config set when file doesn't exist.""" + with tempfile.TemporaryDirectory() as temp_dir: + manager = ConfigManager(base_dir=temp_dir) + config_set = manager.load_config_set("non_existent_kernel") + + assert isinstance(config_set, ConfigSet) + assert config_set.kernel_name == "non_existent_kernel" + assert config_set.get_platforms() == [] + + def test_load_config_set_valid_file(self): + """Test loading config set from valid file.""" + with tempfile.TemporaryDirectory() as temp_dir: + # Use realistic config data + kernel_config = { + "block_sizes": [128, 64], + "num_warps": 8, + "num_stages": 6, + "pid_type": "persistent_interleaved", + } + config_data = {"h100": {"batch_32_hidden_4096": kernel_config}} + config_file = Path(temp_dir) / "test_kernel.json" + with open(config_file, "w") as f: + json.dump(config_data, f) + + manager = ConfigManager(base_dir=temp_dir) + config_set = manager.load_config_set("test_kernel") + + assert isinstance(config_set, ConfigSet) + assert config_set.kernel_name == "test_kernel" + assert config_set.get_platforms() == ["h100"] + + # Verify the config was loaded correctly + config = config_set.get_config("h100", "batch_32_hidden_4096") + assert isinstance(config, helion.Config) + assert config.block_sizes == [128, 64] + assert config.num_warps == 8 + + def test_load_config_set_invalid_json(self): + """Test loading config set from file with invalid JSON.""" + with tempfile.TemporaryDirectory() as temp_dir: + config_file = Path(temp_dir) / "test_kernel.json" + with open(config_file, "w") as f: + f.write("invalid json content {") + + manager = ConfigManager(base_dir=temp_dir) + config_set = manager.load_config_set("test_kernel") + + assert isinstance(config_set, ConfigSet) + assert config_set.kernel_name == "test_kernel" + assert config_set.get_platforms() == [] + + def test_save_config_set(self): + """Test saving ConfigSet to file.""" + with tempfile.TemporaryDirectory() as temp_dir: + # Use realistic config data + kernel_config = { + "block_sizes": [256, 128], + "num_warps": 16, + "num_stages": 8, + "pid_type": "persistent_blocked", + } + data = {"h100": {"batch_32_hidden_4096": kernel_config}} + config_set = ConfigSet.from_dict("test_kernel", data) + + manager = ConfigManager(base_dir=temp_dir) + saved_path = manager.save_config_set(config_set) + + expected_path = Path(temp_dir) / "test_kernel.json" + assert saved_path == expected_path + assert saved_path.exists() + + with open(saved_path) as f: + loaded_data = json.load(f) + assert loaded_data == data + + def test_save_config_set_creates_directory(self): + """Test that save_config_set creates parent directories if needed.""" + with tempfile.TemporaryDirectory() as temp_dir: + nested_dir = Path(temp_dir) / "nested" / "configs" + config_set = ConfigSet("test_kernel") + + manager = ConfigManager(base_dir=nested_dir) + saved_path = manager.save_config_set(config_set) + + assert nested_dir.exists() + assert nested_dir.is_dir() + assert saved_path.exists() + + def test_get_platform_configs(self): + """Test getting all configs for a specific platform.""" + with tempfile.TemporaryDirectory() as temp_dir: + # Use realistic config data + config_1 = {"num_warps": 4, "num_stages": 3, "block_sizes": [64, 32]} + config_2 = {"num_warps": 8, "num_stages": 5, "block_sizes": [128, 64]} + default_config = { + "num_warps": 16, + "num_stages": 7, + "block_sizes": [256, 128], + } + config_3 = {"num_warps": 2, "num_stages": 2, "block_sizes": [32, 16]} + + config_data = { + "h100": { + "batch_32_hidden_4096": config_1, + "batch_64_hidden_2048": config_2, + "default": default_config, + }, + "a100": {"batch_16_hidden_1024": config_3}, + } + config_file = Path(temp_dir) / "test_kernel.json" + with open(config_file, "w") as f: + json.dump(config_data, f) + + manager = ConfigManager(base_dir=temp_dir) + + h100_configs = manager.get_platform_configs("test_kernel", "h100") + assert len(h100_configs) == 3 + assert "batch_32_hidden_4096" in h100_configs + assert "batch_64_hidden_2048" in h100_configs + assert "default" in h100_configs + for config in h100_configs.values(): + assert isinstance(config, helion.Config) + + # Verify specific config details + assert h100_configs["batch_32_hidden_4096"].num_warps == 4 + assert h100_configs["default"].num_stages == 7 + + a100_configs = manager.get_platform_configs("test_kernel", "a100") + assert len(a100_configs) == 1 + assert "batch_16_hidden_1024" in a100_configs + assert isinstance(a100_configs["batch_16_hidden_1024"], helion.Config) + assert a100_configs["batch_16_hidden_1024"].num_warps == 2 + + nonexistent_configs = manager.get_platform_configs("test_kernel", "v100") + assert len(nonexistent_configs) == 0 + + def test_singleton_returns_same_instance(self): + """Test that ConfigManager returns the same instance on repeated calls.""" + manager1 = ConfigManager(base_dir="/tmp/test_singleton") + manager2 = ConfigManager(base_dir="/tmp/test_singleton") + + assert manager1 is manager2 + + def test_singleton_with_default_base_dir(self): + """Test singleton behavior with default base directory.""" + manager1 = ConfigManager() + manager2 = ConfigManager() + + assert manager1 is manager2 + assert manager1._base_dir == manager2._base_dir + + def test_singleton_error_on_different_base_dir(self): + """Test that ConfigManager raises error when created with different base_dir.""" + ConfigManager(base_dir="/tmp/first_dir") + + with pytest.raises(ValueError, match="singleton already exists"): + ConfigManager(base_dir="/tmp/different_dir") + + def test_reset_instance_allows_new_base_dir(self): + """Test that reset_instance allows creating with a new base_dir.""" + manager1 = ConfigManager(base_dir="/tmp/first_dir") + assert manager1._base_dir == Path("/tmp/first_dir").resolve() + + ConfigManager.reset_instance() + + manager2 = ConfigManager(base_dir="/tmp/second_dir") + assert manager2._base_dir == Path("/tmp/second_dir").resolve() + assert manager1 is not manager2 + + def test_get_instance_returns_existing(self): + """Test that get_instance returns the existing singleton.""" + manager1 = ConfigManager(base_dir="/tmp/test_get_instance") + manager2 = ConfigManager.get_instance() + + assert manager1 is manager2 + + def test_get_instance_raises_if_not_initialized(self): + """Test that get_instance raises RuntimeError if no instance exists.""" + with pytest.raises(RuntimeError, match="has not been created"): + ConfigManager.get_instance() diff --git a/vllm/kernels/__init__.py b/vllm/kernels/__init__.py new file mode 100644 index 000000000..3d0c9805e --- /dev/null +++ b/vllm/kernels/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Kernel implementations for vLLM.""" diff --git a/vllm/kernels/helion/__init__.py b/vllm/kernels/helion/__init__.py new file mode 100644 index 000000000..68385e5eb --- /dev/null +++ b/vllm/kernels/helion/__init__.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Helion integration for vLLM.""" + +from vllm.kernels.helion.config_manager import ( + ConfigManager, + ConfigSet, +) + +__all__ = [ + "ConfigManager", + "ConfigSet", +] diff --git a/vllm/kernels/helion/config_manager.py b/vllm/kernels/helion/config_manager.py new file mode 100644 index 000000000..59d5bf430 --- /dev/null +++ b/vllm/kernels/helion/config_manager.py @@ -0,0 +1,231 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Configuration management for Helion kernels. + +This module provides centralized configuration file management for Helion custom +operations, including naming conventions, directory resolution, and file I/O. + +Config File Structure +--------------------- +Each kernel has a single JSON config file: {kernel_name}.json + +The file uses a simplified 2-layer hierarchical structure: +{ + "h100": { # GPU platform + "default": { ... }, # Fallback configuration + "batch_32_hidden_4096": { ... }, + "batch_64_hidden_8192": { ... } + }, + "a100": { + "default": { ... }, + "batch_16_hidden_2048": { ... } + } +} + +Example file: silu_mul_fp8.json + +Config keys should be structured strings that encode the relevant +parameters (e.g., "batch_32_hidden_4096", "seq_512_heads_16", "fp8_batch_64", etc.). + +Classes +------- +- ConfigSet: In-memory collection of configs for a kernel with lookup/query APIs. +- ConfigManager: File-level operations for config persistence. +""" + +import json +from pathlib import Path +from typing import Any + +from vllm.logger import init_logger +from vllm.utils.import_utils import has_helion + +if not has_helion(): + raise ImportError( + "ConfigManager requires helion to be installed. " + "Install it with: pip install helion" + ) + +import helion + +logger = init_logger(__name__) + + +class ConfigSet: + """In-memory collection of Helion configs with lookup/query capabilities.""" + + # Type alias for nested config structure: + # platform -> config_key -> helion.Config + _ConfigDict = dict[str, dict[str, "helion.Config"]] + + def __init__(self, kernel_name: str): + self._kernel_name = kernel_name + self._configs: ConfigSet._ConfigDict = {} + + @property + def kernel_name(self) -> str: + return self._kernel_name + + def get_config(self, platform: str, config_key: str) -> helion.Config: + platform_dict = self._configs.get(platform) + if platform_dict is None: + avail_platforms = self.get_platforms() + raise KeyError( + f"Config not found for kernel '{self._kernel_name}': " + f"platform '{platform}' not found. " + f"Available platforms: {avail_platforms or '(none)'}" + ) + + config = platform_dict.get(config_key) + if config is None: + avail_keys = self.get_config_keys(platform) + raise KeyError( + f"Config not found for kernel '{self._kernel_name}': " + f"config_key '{config_key}' not found for platform '{platform}'. " + f"Available config_keys: {avail_keys or '(none)'}" + ) + + return config + + def get_platforms(self) -> list[str]: + return sorted(self._configs.keys()) + + def get_config_keys(self, platform: str) -> list[str]: + platform_dict = self._configs.get(platform.lower()) + if platform_dict is None: + return [] + return sorted(platform_dict.keys()) + + def to_dict(self) -> dict[str, Any]: + result: dict[str, Any] = {} + + for platform, config_keys_dict in self._configs.items(): + result[platform] = {} + + for config_key, config in config_keys_dict.items(): + # Convert helion.Config to dict using to_json() + json.loads() + import json + + result[platform][config_key] = json.loads(config.to_json()) + + return result + + @classmethod + def from_dict(cls, kernel_name: str, data: dict[str, Any]) -> "ConfigSet": + config_set = cls(kernel_name) + count = 0 + + for platform, platform_data in data.items(): + if platform not in config_set._configs: + config_set._configs[platform] = {} + + for config_key, config_data in platform_data.items(): + config = helion.Config(**config_data) + config_set._configs[platform][config_key] = config + count += 1 + + if count > 0: + logger.debug( + "Loaded %d configs for kernel '%s'", + count, + kernel_name, + ) + + return config_set + + +class ConfigManager: + """File-level configuration management for Helion kernels (global singleton).""" + + _instance: "ConfigManager | None" = None + _instance_base_dir: Path | None = None + + def __new__(cls, base_dir: str | Path | None = None) -> "ConfigManager": + resolved_base_dir = cls._resolve_base_dir(base_dir) + + if cls._instance is not None: + # Instance already exists - check for base_dir mismatch + if cls._instance_base_dir != resolved_base_dir: + raise ValueError( + f"ConfigManager singleton already exists with base_dir " + f"'{cls._instance_base_dir}', cannot create with different " + f"base_dir '{resolved_base_dir}'" + ) + return cls._instance + + # Create new instance + instance = super().__new__(cls) + cls._instance = instance + cls._instance_base_dir = resolved_base_dir + return instance + + def __init__(self, base_dir: str | Path | None = None): + # Only initialize if not already initialized + if hasattr(self, "_base_dir"): + return + + self._base_dir = self._resolve_base_dir(base_dir) + logger.debug("ConfigManager initialized with base_dir: %s", self._base_dir) + + @staticmethod + def _resolve_base_dir(base_dir: str | Path | None) -> Path: + if base_dir is not None: + return Path(base_dir).resolve() + return (Path(__file__).parent / "configs").resolve() + + @classmethod + def get_instance(cls) -> "ConfigManager": + if cls._instance is None: + raise RuntimeError( + "ConfigManager instance has not been created. " + "Call ConfigManager(base_dir=...) first to initialize." + ) + return cls._instance + + @classmethod + def reset_instance(cls) -> None: + """For testing purposes only.""" + cls._instance = None + cls._instance_base_dir = None + + def get_config_file_path(self, kernel_name: str) -> Path: + return self._base_dir / f"{kernel_name}.json" + + def ensure_base_dir_exists(self) -> Path: + self._base_dir.mkdir(parents=True, exist_ok=True) + return self._base_dir + + def load_config_set(self, kernel_name: str) -> ConfigSet: + config_path = self.get_config_file_path(kernel_name) + if not config_path.exists(): + return ConfigSet.from_dict(kernel_name, {}) + + try: + with open(config_path) as f: + data = json.load(f) + return ConfigSet.from_dict(kernel_name, data) + except (json.JSONDecodeError, OSError) as e: + logger.error("Failed to load config file %s: %s", config_path, e) + return ConfigSet.from_dict(kernel_name, {}) + + def get_platform_configs( + self, kernel_name: str, platform: str + ) -> dict[str, helion.Config]: + config_set = self.load_config_set(kernel_name) + config_keys = config_set.get_config_keys(platform) + + return { + config_key: config_set.get_config(platform, config_key) + for config_key in config_keys + } + + def save_config_set(self, config_set: ConfigSet) -> Path: + config_path = self.get_config_file_path(config_set.kernel_name) + config_path.parent.mkdir(parents=True, exist_ok=True) + + with open(config_path, "w") as f: + json.dump(config_set.to_dict(), f, indent=2) + + logger.info("Saved config to: %s", config_path) + return config_path