[Kernel] [Helion] [1/N] Add Helion ConfigManager (#32740)

Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
This commit is contained in:
Yanan Cao
2026-01-30 09:19:19 -08:00
committed by GitHub
parent 67239c4c42
commit 6c1f9e4c18
4 changed files with 608 additions and 0 deletions

View File

@@ -0,0 +1,361 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Unit tests for Helion ConfigManager and ConfigSet.
Tests the simplified configuration management system for Helion custom kernels.
"""
import json
import tempfile
from pathlib import Path
import pytest
from vllm.utils.import_utils import has_helion
# Skip entire module if helion is not available
if not has_helion():
pytest.skip(
"Helion is not installed. Install with: pip install vllm[helion]",
allow_module_level=True,
)
import helion
from vllm.kernels.helion.config_manager import (
ConfigManager,
ConfigSet,
)
@pytest.fixture(autouse=True)
def reset_config_manager_singleton():
    """Give every test in this module a clean ConfigManager singleton.

    The singleton is cleared before the test body runs and again afterwards,
    so a ``base_dir`` chosen by one test cannot leak into the next one.
    """
    # Setup: drop any instance left over from a previous test.
    ConfigManager.reset_instance()
    yield
    # Teardown: drop whatever instance the test itself created.
    ConfigManager.reset_instance()
class TestConfigSet:
    """Behavioral tests for the in-memory ``ConfigSet`` container."""

    def test_config_set_creation(self):
        """A freshly constructed ConfigSet carries its name and no platforms."""
        empty_set = ConfigSet("test_kernel")

        assert empty_set.kernel_name == "test_kernel"
        assert empty_set.get_platforms() == []

    def test_config_set_from_dict(self):
        """from_dict builds helion.Config objects out of plain dictionaries."""
        # Field values that helion.Config accepts as keyword arguments.
        kernel_params = {
            "block_sizes": [32, 16],
            "num_warps": 4,
            "num_stages": 3,
            "pid_type": "persistent_interleaved",
        }
        built = ConfigSet.from_dict(
            "test_kernel",
            {"h100": {"batch_32_hidden_4096": kernel_params}},
        )

        assert built.kernel_name == "test_kernel"
        assert built.get_platforms() == ["h100"]

        # The stored entry must be a real Config carrying the same fields.
        loaded = built.get_config("h100", "batch_32_hidden_4096")
        assert isinstance(loaded, helion.Config)
        assert loaded.block_sizes == [32, 16]
        assert loaded.num_warps == 4
        assert loaded.num_stages == 3
        assert loaded.pid_type == "persistent_interleaved"

    def test_config_set_get_config_keyerror(self):
        """Missing platforms and missing config keys raise descriptive KeyErrors."""
        # Case 1: nothing is stored at all, so the platform lookup fails.
        empty_set = ConfigSet("test_kernel")
        with pytest.raises(KeyError, match="platform 'h100' not found"):
            empty_set.get_config("h100", "batch_32_hidden_4096")

        # Case 2: the platform exists but the requested key does not.
        populated = ConfigSet.from_dict(
            "test_kernel",
            {"h100": {"batch_64_hidden_2048": {"num_warps": 8, "num_stages": 4}}},
        )
        with pytest.raises(
            KeyError, match="config_key 'batch_32_hidden_4096' not found"
        ):
            populated.get_config("h100", "batch_32_hidden_4096")

    def test_config_set_get_platforms(self):
        """get_platforms lists the stored platform names in sorted order."""
        per_platform = {
            "h100": {"batch_32_hidden_4096": {"num_warps": 4, "num_stages": 3}},
            "a100": {"batch_16_hidden_2048": {"num_warps": 8, "num_stages": 5}},
        }
        built = ConfigSet.from_dict("test_kernel", per_platform)

        # Inserted as h100, a100; expected back alphabetically.
        assert built.get_platforms() == ["a100", "h100"]

    def test_config_set_get_config_keys(self):
        """get_config_keys lists a platform's keys, or [] for an unknown one."""
        h100_entries = {
            "batch_32_hidden_4096": {"num_warps": 4, "num_stages": 3},
            "batch_64_hidden_2048": {"num_warps": 8, "num_stages": 5},
        }
        built = ConfigSet.from_dict("test_kernel", {"h100": h100_entries})

        assert built.get_config_keys("h100") == [
            "batch_32_hidden_4096",
            "batch_64_hidden_2048",
        ]
        assert built.get_config_keys("v100") == []

    def test_config_set_to_dict(self):
        """to_dict reproduces the dictionary that from_dict was given."""
        source_data = {
            "h100": {
                "batch_32_hidden_4096": {
                    "block_sizes": [64, 32],
                    "num_warps": 16,
                    "num_stages": 4,
                    "pid_type": "persistent_blocked",
                }
            }
        }
        round_tripped = ConfigSet.from_dict("test_kernel", source_data).to_dict()

        # dict -> Config -> dict must be lossless for these fields.
        assert round_tripped == source_data
class TestConfigManager:
    """Test suite for ConfigManager class.

    ConfigManager resolves its ``base_dir`` on construction, so every
    expected path in these tests is resolved as well. Comparing against an
    unresolved path fails on systems where the temp directory (or ``/tmp``)
    is a symlink.
    """

    def test_config_manager_creation_default_base_dir(self):
        """Test creating ConfigManager with default base directory."""
        manager = ConfigManager()
        assert manager._base_dir.name == "configs"

    def test_config_manager_creation_custom_base_dir(self):
        """Test creating ConfigManager with custom base directory."""
        custom_dir = "/tmp/custom_configs"
        manager = ConfigManager(base_dir=custom_dir)
        # Paths are resolved, so compare with resolved path
        assert manager._base_dir == Path(custom_dir).resolve()

    def test_get_config_file_path(self):
        """Test getting config file path for a kernel."""
        manager = ConfigManager(base_dir="/tmp")
        file_path = manager.get_config_file_path("silu_mul_fp8")
        # The manager stores a resolved base_dir; resolve the expectation too.
        expected_path = Path("/tmp").resolve() / "silu_mul_fp8.json"
        assert file_path == expected_path

    def test_ensure_base_dir_exists(self):
        """Test ensuring base directory exists."""
        with tempfile.TemporaryDirectory() as temp_dir:
            base_dir = Path(temp_dir) / "non_existent" / "configs"
            manager = ConfigManager(base_dir=base_dir)
            assert not base_dir.exists()
            returned_path = manager.ensure_base_dir_exists()
            assert base_dir.exists()
            assert base_dir.is_dir()
            # The returned path is the manager's resolved base_dir.
            assert returned_path == base_dir.resolve()

    def test_load_config_set_file_not_exists(self):
        """Test loading config set when file doesn't exist."""
        with tempfile.TemporaryDirectory() as temp_dir:
            manager = ConfigManager(base_dir=temp_dir)
            config_set = manager.load_config_set("non_existent_kernel")
            assert isinstance(config_set, ConfigSet)
            assert config_set.kernel_name == "non_existent_kernel"
            assert config_set.get_platforms() == []

    def test_load_config_set_valid_file(self):
        """Test loading config set from valid file."""
        with tempfile.TemporaryDirectory() as temp_dir:
            # Use realistic config data
            kernel_config = {
                "block_sizes": [128, 64],
                "num_warps": 8,
                "num_stages": 6,
                "pid_type": "persistent_interleaved",
            }
            config_data = {"h100": {"batch_32_hidden_4096": kernel_config}}
            config_file = Path(temp_dir) / "test_kernel.json"
            with open(config_file, "w", encoding="utf-8") as f:
                json.dump(config_data, f)
            manager = ConfigManager(base_dir=temp_dir)
            config_set = manager.load_config_set("test_kernel")
            assert isinstance(config_set, ConfigSet)
            assert config_set.kernel_name == "test_kernel"
            assert config_set.get_platforms() == ["h100"]
            # Verify the config was loaded correctly
            config = config_set.get_config("h100", "batch_32_hidden_4096")
            assert isinstance(config, helion.Config)
            assert config.block_sizes == [128, 64]
            assert config.num_warps == 8

    def test_load_config_set_invalid_json(self):
        """Test loading config set from file with invalid JSON."""
        with tempfile.TemporaryDirectory() as temp_dir:
            config_file = Path(temp_dir) / "test_kernel.json"
            with open(config_file, "w", encoding="utf-8") as f:
                f.write("invalid json content {")
            manager = ConfigManager(base_dir=temp_dir)
            # Invalid JSON is logged and reported as an empty config set.
            config_set = manager.load_config_set("test_kernel")
            assert isinstance(config_set, ConfigSet)
            assert config_set.kernel_name == "test_kernel"
            assert config_set.get_platforms() == []

    def test_save_config_set(self):
        """Test saving ConfigSet to file."""
        with tempfile.TemporaryDirectory() as temp_dir:
            # Use realistic config data
            kernel_config = {
                "block_sizes": [256, 128],
                "num_warps": 16,
                "num_stages": 8,
                "pid_type": "persistent_blocked",
            }
            data = {"h100": {"batch_32_hidden_4096": kernel_config}}
            config_set = ConfigSet.from_dict("test_kernel", data)
            manager = ConfigManager(base_dir=temp_dir)
            saved_path = manager.save_config_set(config_set)
            # save_config_set returns a path under the resolved base_dir.
            expected_path = Path(temp_dir).resolve() / "test_kernel.json"
            assert saved_path == expected_path
            assert saved_path.exists()
            with open(saved_path, encoding="utf-8") as f:
                loaded_data = json.load(f)
            assert loaded_data == data

    def test_save_config_set_creates_directory(self):
        """Test that save_config_set creates parent directories if needed."""
        with tempfile.TemporaryDirectory() as temp_dir:
            nested_dir = Path(temp_dir) / "nested" / "configs"
            config_set = ConfigSet("test_kernel")
            manager = ConfigManager(base_dir=nested_dir)
            saved_path = manager.save_config_set(config_set)
            assert nested_dir.exists()
            assert nested_dir.is_dir()
            assert saved_path.exists()

    def test_get_platform_configs(self):
        """Test getting all configs for a specific platform."""
        with tempfile.TemporaryDirectory() as temp_dir:
            # Use realistic config data
            config_1 = {"num_warps": 4, "num_stages": 3, "block_sizes": [64, 32]}
            config_2 = {"num_warps": 8, "num_stages": 5, "block_sizes": [128, 64]}
            default_config = {
                "num_warps": 16,
                "num_stages": 7,
                "block_sizes": [256, 128],
            }
            config_3 = {"num_warps": 2, "num_stages": 2, "block_sizes": [32, 16]}
            config_data = {
                "h100": {
                    "batch_32_hidden_4096": config_1,
                    "batch_64_hidden_2048": config_2,
                    "default": default_config,
                },
                "a100": {"batch_16_hidden_1024": config_3},
            }
            config_file = Path(temp_dir) / "test_kernel.json"
            with open(config_file, "w", encoding="utf-8") as f:
                json.dump(config_data, f)
            manager = ConfigManager(base_dir=temp_dir)

            h100_configs = manager.get_platform_configs("test_kernel", "h100")
            assert len(h100_configs) == 3
            assert "batch_32_hidden_4096" in h100_configs
            assert "batch_64_hidden_2048" in h100_configs
            assert "default" in h100_configs
            for config in h100_configs.values():
                assert isinstance(config, helion.Config)
            # Verify specific config details
            assert h100_configs["batch_32_hidden_4096"].num_warps == 4
            assert h100_configs["default"].num_stages == 7

            a100_configs = manager.get_platform_configs("test_kernel", "a100")
            assert len(a100_configs) == 1
            assert "batch_16_hidden_1024" in a100_configs
            assert isinstance(a100_configs["batch_16_hidden_1024"], helion.Config)
            assert a100_configs["batch_16_hidden_1024"].num_warps == 2

            # A platform absent from the file yields an empty mapping.
            nonexistent_configs = manager.get_platform_configs("test_kernel", "v100")
            assert len(nonexistent_configs) == 0

    def test_singleton_returns_same_instance(self):
        """Test that ConfigManager returns the same instance on repeated calls."""
        manager1 = ConfigManager(base_dir="/tmp/test_singleton")
        manager2 = ConfigManager(base_dir="/tmp/test_singleton")
        assert manager1 is manager2

    def test_singleton_with_default_base_dir(self):
        """Test singleton behavior with default base directory."""
        manager1 = ConfigManager()
        manager2 = ConfigManager()
        assert manager1 is manager2
        assert manager1._base_dir == manager2._base_dir

    def test_singleton_error_on_different_base_dir(self):
        """Test that ConfigManager raises error when created with different base_dir."""
        ConfigManager(base_dir="/tmp/first_dir")
        with pytest.raises(ValueError, match="singleton already exists"):
            ConfigManager(base_dir="/tmp/different_dir")

    def test_reset_instance_allows_new_base_dir(self):
        """Test that reset_instance allows creating with a new base_dir."""
        manager1 = ConfigManager(base_dir="/tmp/first_dir")
        assert manager1._base_dir == Path("/tmp/first_dir").resolve()
        ConfigManager.reset_instance()
        manager2 = ConfigManager(base_dir="/tmp/second_dir")
        assert manager2._base_dir == Path("/tmp/second_dir").resolve()
        assert manager1 is not manager2

    def test_get_instance_returns_existing(self):
        """Test that get_instance returns the existing singleton."""
        manager1 = ConfigManager(base_dir="/tmp/test_get_instance")
        manager2 = ConfigManager.get_instance()
        assert manager1 is manager2

    def test_get_instance_raises_if_not_initialized(self):
        """Test that get_instance raises RuntimeError if no instance exists."""
        with pytest.raises(RuntimeError, match="has not been created"):
            ConfigManager.get_instance()

3
vllm/kernels/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Kernel implementations for vLLM."""

View File

@@ -0,0 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Helion integration for vLLM."""
from vllm.kernels.helion.config_manager import (
ConfigManager,
ConfigSet,
)
__all__ = [
"ConfigManager",
"ConfigSet",
]

View File

@@ -0,0 +1,231 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Configuration management for Helion kernels.
This module provides centralized configuration file management for Helion custom
operations, including naming conventions, directory resolution, and file I/O.
Config File Structure
---------------------
Each kernel has a single JSON config file: {kernel_name}.json
The file uses a simplified 2-layer hierarchical structure:
{
"h100": { # GPU platform
"default": { ... }, # Fallback configuration
"batch_32_hidden_4096": { ... },
"batch_64_hidden_8192": { ... }
},
"a100": {
"default": { ... },
"batch_16_hidden_2048": { ... }
}
}
Example file: silu_mul_fp8.json
Config keys should be structured strings that encode the relevant
parameters (e.g., "batch_32_hidden_4096", "seq_512_heads_16", "fp8_batch_64", etc.).
Classes
-------
- ConfigSet: In-memory collection of configs for a kernel with lookup/query APIs.
- ConfigManager: File-level operations for config persistence.
"""
import json
from pathlib import Path
from typing import Any
from vllm.logger import init_logger
from vllm.utils.import_utils import has_helion
if not has_helion():
raise ImportError(
"ConfigManager requires helion to be installed. "
"Install it with: pip install helion"
)
import helion
logger = init_logger(__name__)
class ConfigSet:
    """In-memory collection of Helion configs with lookup/query capabilities.

    Configs are held in a two-level mapping:
    ``platform -> config_key -> helion.Config``.

    Platform names are looked up by exact match first and, if that fails, by
    their lower-cased form. ``get_config`` and ``get_config_keys`` share this
    rule so that they always agree on whether a platform exists.
    """

    # Type alias for nested config structure:
    # platform -> config_key -> helion.Config
    _ConfigDict = dict[str, dict[str, "helion.Config"]]

    def __init__(self, kernel_name: str):
        """Create an empty config set for ``kernel_name``."""
        self._kernel_name = kernel_name
        self._configs: ConfigSet._ConfigDict = {}

    @property
    def kernel_name(self) -> str:
        """Name of the kernel these configs belong to."""
        return self._kernel_name

    def _lookup_platform(self, platform: str) -> "dict[str, helion.Config] | None":
        """Return the config mapping stored for ``platform``, or None.

        The exact name wins; the lower-cased name is only a fallback, so a
        platform stored with upper-case letters is still reachable.
        """
        platform_dict = self._configs.get(platform)
        if platform_dict is None:
            platform_dict = self._configs.get(platform.lower())
        return platform_dict

    def get_config(self, platform: str, config_key: str) -> "helion.Config":
        """Return the config stored under ``platform`` / ``config_key``.

        Raises:
            KeyError: If the platform is unknown, or the platform has no entry
                for ``config_key``. The message lists what is available.
        """
        platform_dict = self._lookup_platform(platform)
        if platform_dict is None:
            avail_platforms = self.get_platforms()
            raise KeyError(
                f"Config not found for kernel '{self._kernel_name}': "
                f"platform '{platform}' not found. "
                f"Available platforms: {avail_platforms or '(none)'}"
            )
        config = platform_dict.get(config_key)
        if config is None:
            avail_keys = self.get_config_keys(platform)
            raise KeyError(
                f"Config not found for kernel '{self._kernel_name}': "
                f"config_key '{config_key}' not found for platform '{platform}'. "
                f"Available config_keys: {avail_keys or '(none)'}"
            )
        return config

    def get_platforms(self) -> list[str]:
        """Return the stored platform names, sorted."""
        return sorted(self._configs)

    def get_config_keys(self, platform: str) -> list[str]:
        """Return the sorted config keys for ``platform`` ([] if unknown)."""
        platform_dict = self._lookup_platform(platform)
        if platform_dict is None:
            return []
        return sorted(platform_dict)

    def to_dict(self) -> dict[str, Any]:
        """Return a plain nested dict: platform -> config_key -> config fields."""
        # Convert each helion.Config to a dict using to_json() + json.loads().
        return {
            platform: {
                config_key: json.loads(config.to_json())
                for config_key, config in config_keys_dict.items()
            }
            for platform, config_keys_dict in self._configs.items()
        }

    @classmethod
    def from_dict(cls, kernel_name: str, data: dict[str, Any]) -> "ConfigSet":
        """Build a ConfigSet from a nested platform -> config_key -> fields dict.

        Each innermost dict is passed to ``helion.Config`` as keyword
        arguments.
        """
        config_set = cls(kernel_name)
        count = 0
        for platform, platform_data in data.items():
            platform_configs = config_set._configs.setdefault(platform, {})
            for config_key, config_data in platform_data.items():
                platform_configs[config_key] = helion.Config(**config_data)
                count += 1
        if count > 0:
            logger.debug(
                "Loaded %d configs for kernel '%s'",
                count,
                kernel_name,
            )
        return config_set
class ConfigManager:
"""File-level configuration management for Helion kernels (global singleton)."""
_instance: "ConfigManager | None" = None
_instance_base_dir: Path | None = None
def __new__(cls, base_dir: str | Path | None = None) -> "ConfigManager":
resolved_base_dir = cls._resolve_base_dir(base_dir)
if cls._instance is not None:
# Instance already exists - check for base_dir mismatch
if cls._instance_base_dir != resolved_base_dir:
raise ValueError(
f"ConfigManager singleton already exists with base_dir "
f"'{cls._instance_base_dir}', cannot create with different "
f"base_dir '{resolved_base_dir}'"
)
return cls._instance
# Create new instance
instance = super().__new__(cls)
cls._instance = instance
cls._instance_base_dir = resolved_base_dir
return instance
def __init__(self, base_dir: str | Path | None = None):
# Only initialize if not already initialized
if hasattr(self, "_base_dir"):
return
self._base_dir = self._resolve_base_dir(base_dir)
logger.debug("ConfigManager initialized with base_dir: %s", self._base_dir)
@staticmethod
def _resolve_base_dir(base_dir: str | Path | None) -> Path:
if base_dir is not None:
return Path(base_dir).resolve()
return (Path(__file__).parent / "configs").resolve()
@classmethod
def get_instance(cls) -> "ConfigManager":
if cls._instance is None:
raise RuntimeError(
"ConfigManager instance has not been created. "
"Call ConfigManager(base_dir=...) first to initialize."
)
return cls._instance
@classmethod
def reset_instance(cls) -> None:
"""For testing purposes only."""
cls._instance = None
cls._instance_base_dir = None
def get_config_file_path(self, kernel_name: str) -> Path:
return self._base_dir / f"{kernel_name}.json"
def ensure_base_dir_exists(self) -> Path:
self._base_dir.mkdir(parents=True, exist_ok=True)
return self._base_dir
def load_config_set(self, kernel_name: str) -> ConfigSet:
config_path = self.get_config_file_path(kernel_name)
if not config_path.exists():
return ConfigSet.from_dict(kernel_name, {})
try:
with open(config_path) as f:
data = json.load(f)
return ConfigSet.from_dict(kernel_name, data)
except (json.JSONDecodeError, OSError) as e:
logger.error("Failed to load config file %s: %s", config_path, e)
return ConfigSet.from_dict(kernel_name, {})
def get_platform_configs(
self, kernel_name: str, platform: str
) -> dict[str, helion.Config]:
config_set = self.load_config_set(kernel_name)
config_keys = config_set.get_config_keys(platform)
return {
config_key: config_set.get_config(platform, config_key)
for config_key in config_keys
}
def save_config_set(self, config_set: ConfigSet) -> Path:
config_path = self.get_config_file_path(config_set.kernel_name)
config_path.parent.mkdir(parents=True, exist_ok=True)
with open(config_path, "w") as f:
json.dump(config_set.to_dict(), f, indent=2)
logger.info("Saved config to: %s", config_path)
return config_path