[Kernel] [Helion] [1/N] Add Helion ConfigManager (#32740)
Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
This commit is contained in:
361
tests/kernels/helion/test_config_manager.py
Normal file
361
tests/kernels/helion/test_config_manager.py
Normal file
@@ -0,0 +1,361 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Unit tests for Helion ConfigManager and ConfigSet.
|
||||
|
||||
Tests the simplified configuration management system for Helion custom kernels.
|
||||
"""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.utils.import_utils import has_helion
|
||||
|
||||
# Skip entire module if helion is not available
|
||||
if not has_helion():
|
||||
pytest.skip(
|
||||
"Helion is not installed. Install with: pip install vllm[helion]",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
import helion
|
||||
|
||||
from vllm.kernels.helion.config_manager import (
|
||||
ConfigManager,
|
||||
ConfigSet,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_config_manager_singleton():
|
||||
"""Reset ConfigManager singleton before each test."""
|
||||
ConfigManager.reset_instance()
|
||||
yield
|
||||
ConfigManager.reset_instance()
|
||||
|
||||
|
||||
class TestConfigSet:
|
||||
"""Test suite for ConfigSet class."""
|
||||
|
||||
def test_config_set_creation(self):
|
||||
"""Test creating an empty ConfigSet."""
|
||||
config_set = ConfigSet("test_kernel")
|
||||
|
||||
assert config_set.kernel_name == "test_kernel"
|
||||
assert config_set.get_platforms() == []
|
||||
|
||||
def test_config_set_from_dict(self):
|
||||
"""Test creating ConfigSet from dictionary data."""
|
||||
# Use realistic config data that helion.Config can handle
|
||||
config_data = {
|
||||
"block_sizes": [32, 16],
|
||||
"num_warps": 4,
|
||||
"num_stages": 3,
|
||||
"pid_type": "persistent_interleaved",
|
||||
}
|
||||
data = {"h100": {"batch_32_hidden_4096": config_data}}
|
||||
|
||||
config_set = ConfigSet.from_dict("test_kernel", data)
|
||||
|
||||
assert config_set.kernel_name == "test_kernel"
|
||||
assert config_set.get_platforms() == ["h100"]
|
||||
|
||||
# Verify the config was created correctly
|
||||
config = config_set.get_config("h100", "batch_32_hidden_4096")
|
||||
assert isinstance(config, helion.Config)
|
||||
assert config.block_sizes == [32, 16]
|
||||
assert config.num_warps == 4
|
||||
assert config.num_stages == 3
|
||||
assert config.pid_type == "persistent_interleaved"
|
||||
|
||||
def test_config_set_get_config_keyerror(self):
|
||||
"""Test that accessing non-existent configs raises informative KeyErrors."""
|
||||
config_set = ConfigSet("test_kernel")
|
||||
|
||||
with pytest.raises(KeyError, match="platform 'h100' not found"):
|
||||
config_set.get_config("h100", "batch_32_hidden_4096")
|
||||
|
||||
# Use realistic config data
|
||||
config_data = {"num_warps": 8, "num_stages": 4}
|
||||
data = {"h100": {"batch_64_hidden_2048": config_data}}
|
||||
config_set = ConfigSet.from_dict("test_kernel", data)
|
||||
|
||||
with pytest.raises(
|
||||
KeyError, match="config_key 'batch_32_hidden_4096' not found"
|
||||
):
|
||||
config_set.get_config("h100", "batch_32_hidden_4096")
|
||||
|
||||
def test_config_set_get_platforms(self):
|
||||
"""Test get_platforms method."""
|
||||
# Use realistic config data
|
||||
config1 = {"num_warps": 4, "num_stages": 3}
|
||||
config2 = {"num_warps": 8, "num_stages": 5}
|
||||
|
||||
data = {
|
||||
"h100": {"batch_32_hidden_4096": config1},
|
||||
"a100": {"batch_16_hidden_2048": config2},
|
||||
}
|
||||
config_set = ConfigSet.from_dict("test_kernel", data)
|
||||
|
||||
platforms = config_set.get_platforms()
|
||||
assert platforms == ["a100", "h100"] # Should be sorted
|
||||
|
||||
def test_config_set_get_config_keys(self):
|
||||
"""Test get_config_keys method."""
|
||||
# Use realistic config data
|
||||
config1 = {"num_warps": 4, "num_stages": 3}
|
||||
config2 = {"num_warps": 8, "num_stages": 5}
|
||||
|
||||
data = {
|
||||
"h100": {
|
||||
"batch_32_hidden_4096": config1,
|
||||
"batch_64_hidden_2048": config2,
|
||||
}
|
||||
}
|
||||
config_set = ConfigSet.from_dict("test_kernel", data)
|
||||
|
||||
config_keys = config_set.get_config_keys("h100")
|
||||
assert config_keys == ["batch_32_hidden_4096", "batch_64_hidden_2048"]
|
||||
|
||||
assert config_set.get_config_keys("v100") == []
|
||||
|
||||
def test_config_set_to_dict(self):
|
||||
"""Test converting ConfigSet to dictionary."""
|
||||
# Use realistic config data
|
||||
original_config = {
|
||||
"block_sizes": [64, 32],
|
||||
"num_warps": 16,
|
||||
"num_stages": 4,
|
||||
"pid_type": "persistent_blocked",
|
||||
}
|
||||
original_data = {"h100": {"batch_32_hidden_4096": original_config}}
|
||||
|
||||
config_set = ConfigSet.from_dict("test_kernel", original_data)
|
||||
result_data = config_set.to_dict()
|
||||
|
||||
# The result should match the original (Config roundtrip should work)
|
||||
assert result_data == original_data
|
||||
|
||||
|
||||
class TestConfigManager:
|
||||
"""Test suite for ConfigManager class."""
|
||||
|
||||
def test_config_manager_creation_default_base_dir(self):
|
||||
"""Test creating ConfigManager with default base directory."""
|
||||
manager = ConfigManager()
|
||||
assert manager._base_dir.name == "configs"
|
||||
|
||||
def test_config_manager_creation_custom_base_dir(self):
|
||||
"""Test creating ConfigManager with custom base directory."""
|
||||
custom_dir = "/tmp/custom_configs"
|
||||
manager = ConfigManager(base_dir=custom_dir)
|
||||
|
||||
# Paths are resolved, so compare with resolved path
|
||||
assert manager._base_dir == Path(custom_dir).resolve()
|
||||
|
||||
def test_get_config_file_path(self):
|
||||
"""Test getting config file path for a kernel."""
|
||||
manager = ConfigManager(base_dir="/tmp")
|
||||
|
||||
file_path = manager.get_config_file_path("silu_mul_fp8")
|
||||
|
||||
expected_path = Path("/tmp/silu_mul_fp8.json")
|
||||
assert file_path == expected_path
|
||||
|
||||
def test_ensure_base_dir_exists(self):
|
||||
"""Test ensuring base directory exists."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
base_dir = Path(temp_dir) / "non_existent" / "configs"
|
||||
manager = ConfigManager(base_dir=base_dir)
|
||||
assert not base_dir.exists()
|
||||
|
||||
returned_path = manager.ensure_base_dir_exists()
|
||||
|
||||
assert base_dir.exists()
|
||||
assert base_dir.is_dir()
|
||||
assert returned_path == base_dir
|
||||
|
||||
def test_load_config_set_file_not_exists(self):
|
||||
"""Test loading config set when file doesn't exist."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
manager = ConfigManager(base_dir=temp_dir)
|
||||
config_set = manager.load_config_set("non_existent_kernel")
|
||||
|
||||
assert isinstance(config_set, ConfigSet)
|
||||
assert config_set.kernel_name == "non_existent_kernel"
|
||||
assert config_set.get_platforms() == []
|
||||
|
||||
def test_load_config_set_valid_file(self):
|
||||
"""Test loading config set from valid file."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Use realistic config data
|
||||
kernel_config = {
|
||||
"block_sizes": [128, 64],
|
||||
"num_warps": 8,
|
||||
"num_stages": 6,
|
||||
"pid_type": "persistent_interleaved",
|
||||
}
|
||||
config_data = {"h100": {"batch_32_hidden_4096": kernel_config}}
|
||||
config_file = Path(temp_dir) / "test_kernel.json"
|
||||
with open(config_file, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
|
||||
manager = ConfigManager(base_dir=temp_dir)
|
||||
config_set = manager.load_config_set("test_kernel")
|
||||
|
||||
assert isinstance(config_set, ConfigSet)
|
||||
assert config_set.kernel_name == "test_kernel"
|
||||
assert config_set.get_platforms() == ["h100"]
|
||||
|
||||
# Verify the config was loaded correctly
|
||||
config = config_set.get_config("h100", "batch_32_hidden_4096")
|
||||
assert isinstance(config, helion.Config)
|
||||
assert config.block_sizes == [128, 64]
|
||||
assert config.num_warps == 8
|
||||
|
||||
def test_load_config_set_invalid_json(self):
|
||||
"""Test loading config set from file with invalid JSON."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
config_file = Path(temp_dir) / "test_kernel.json"
|
||||
with open(config_file, "w") as f:
|
||||
f.write("invalid json content {")
|
||||
|
||||
manager = ConfigManager(base_dir=temp_dir)
|
||||
config_set = manager.load_config_set("test_kernel")
|
||||
|
||||
assert isinstance(config_set, ConfigSet)
|
||||
assert config_set.kernel_name == "test_kernel"
|
||||
assert config_set.get_platforms() == []
|
||||
|
||||
def test_save_config_set(self):
|
||||
"""Test saving ConfigSet to file."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Use realistic config data
|
||||
kernel_config = {
|
||||
"block_sizes": [256, 128],
|
||||
"num_warps": 16,
|
||||
"num_stages": 8,
|
||||
"pid_type": "persistent_blocked",
|
||||
}
|
||||
data = {"h100": {"batch_32_hidden_4096": kernel_config}}
|
||||
config_set = ConfigSet.from_dict("test_kernel", data)
|
||||
|
||||
manager = ConfigManager(base_dir=temp_dir)
|
||||
saved_path = manager.save_config_set(config_set)
|
||||
|
||||
expected_path = Path(temp_dir) / "test_kernel.json"
|
||||
assert saved_path == expected_path
|
||||
assert saved_path.exists()
|
||||
|
||||
with open(saved_path) as f:
|
||||
loaded_data = json.load(f)
|
||||
assert loaded_data == data
|
||||
|
||||
def test_save_config_set_creates_directory(self):
|
||||
"""Test that save_config_set creates parent directories if needed."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
nested_dir = Path(temp_dir) / "nested" / "configs"
|
||||
config_set = ConfigSet("test_kernel")
|
||||
|
||||
manager = ConfigManager(base_dir=nested_dir)
|
||||
saved_path = manager.save_config_set(config_set)
|
||||
|
||||
assert nested_dir.exists()
|
||||
assert nested_dir.is_dir()
|
||||
assert saved_path.exists()
|
||||
|
||||
def test_get_platform_configs(self):
|
||||
"""Test getting all configs for a specific platform."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Use realistic config data
|
||||
config_1 = {"num_warps": 4, "num_stages": 3, "block_sizes": [64, 32]}
|
||||
config_2 = {"num_warps": 8, "num_stages": 5, "block_sizes": [128, 64]}
|
||||
default_config = {
|
||||
"num_warps": 16,
|
||||
"num_stages": 7,
|
||||
"block_sizes": [256, 128],
|
||||
}
|
||||
config_3 = {"num_warps": 2, "num_stages": 2, "block_sizes": [32, 16]}
|
||||
|
||||
config_data = {
|
||||
"h100": {
|
||||
"batch_32_hidden_4096": config_1,
|
||||
"batch_64_hidden_2048": config_2,
|
||||
"default": default_config,
|
||||
},
|
||||
"a100": {"batch_16_hidden_1024": config_3},
|
||||
}
|
||||
config_file = Path(temp_dir) / "test_kernel.json"
|
||||
with open(config_file, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
|
||||
manager = ConfigManager(base_dir=temp_dir)
|
||||
|
||||
h100_configs = manager.get_platform_configs("test_kernel", "h100")
|
||||
assert len(h100_configs) == 3
|
||||
assert "batch_32_hidden_4096" in h100_configs
|
||||
assert "batch_64_hidden_2048" in h100_configs
|
||||
assert "default" in h100_configs
|
||||
for config in h100_configs.values():
|
||||
assert isinstance(config, helion.Config)
|
||||
|
||||
# Verify specific config details
|
||||
assert h100_configs["batch_32_hidden_4096"].num_warps == 4
|
||||
assert h100_configs["default"].num_stages == 7
|
||||
|
||||
a100_configs = manager.get_platform_configs("test_kernel", "a100")
|
||||
assert len(a100_configs) == 1
|
||||
assert "batch_16_hidden_1024" in a100_configs
|
||||
assert isinstance(a100_configs["batch_16_hidden_1024"], helion.Config)
|
||||
assert a100_configs["batch_16_hidden_1024"].num_warps == 2
|
||||
|
||||
nonexistent_configs = manager.get_platform_configs("test_kernel", "v100")
|
||||
assert len(nonexistent_configs) == 0
|
||||
|
||||
def test_singleton_returns_same_instance(self):
|
||||
"""Test that ConfigManager returns the same instance on repeated calls."""
|
||||
manager1 = ConfigManager(base_dir="/tmp/test_singleton")
|
||||
manager2 = ConfigManager(base_dir="/tmp/test_singleton")
|
||||
|
||||
assert manager1 is manager2
|
||||
|
||||
def test_singleton_with_default_base_dir(self):
|
||||
"""Test singleton behavior with default base directory."""
|
||||
manager1 = ConfigManager()
|
||||
manager2 = ConfigManager()
|
||||
|
||||
assert manager1 is manager2
|
||||
assert manager1._base_dir == manager2._base_dir
|
||||
|
||||
def test_singleton_error_on_different_base_dir(self):
|
||||
"""Test that ConfigManager raises error when created with different base_dir."""
|
||||
ConfigManager(base_dir="/tmp/first_dir")
|
||||
|
||||
with pytest.raises(ValueError, match="singleton already exists"):
|
||||
ConfigManager(base_dir="/tmp/different_dir")
|
||||
|
||||
def test_reset_instance_allows_new_base_dir(self):
|
||||
"""Test that reset_instance allows creating with a new base_dir."""
|
||||
manager1 = ConfigManager(base_dir="/tmp/first_dir")
|
||||
assert manager1._base_dir == Path("/tmp/first_dir").resolve()
|
||||
|
||||
ConfigManager.reset_instance()
|
||||
|
||||
manager2 = ConfigManager(base_dir="/tmp/second_dir")
|
||||
assert manager2._base_dir == Path("/tmp/second_dir").resolve()
|
||||
assert manager1 is not manager2
|
||||
|
||||
def test_get_instance_returns_existing(self):
|
||||
"""Test that get_instance returns the existing singleton."""
|
||||
manager1 = ConfigManager(base_dir="/tmp/test_get_instance")
|
||||
manager2 = ConfigManager.get_instance()
|
||||
|
||||
assert manager1 is manager2
|
||||
|
||||
def test_get_instance_raises_if_not_initialized(self):
|
||||
"""Test that get_instance raises RuntimeError if no instance exists."""
|
||||
with pytest.raises(RuntimeError, match="has not been created"):
|
||||
ConfigManager.get_instance()
|
||||
3
vllm/kernels/__init__.py
Normal file
3
vllm/kernels/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Kernel implementations for vLLM."""
|
||||
13
vllm/kernels/helion/__init__.py
Normal file
13
vllm/kernels/helion/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Helion integration for vLLM."""
|
||||
|
||||
from vllm.kernels.helion.config_manager import (
|
||||
ConfigManager,
|
||||
ConfigSet,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ConfigManager",
|
||||
"ConfigSet",
|
||||
]
|
||||
231
vllm/kernels/helion/config_manager.py
Normal file
231
vllm/kernels/helion/config_manager.py
Normal file
@@ -0,0 +1,231 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Configuration management for Helion kernels.
|
||||
|
||||
This module provides centralized configuration file management for Helion custom
|
||||
operations, including naming conventions, directory resolution, and file I/O.
|
||||
|
||||
Config File Structure
|
||||
---------------------
|
||||
Each kernel has a single JSON config file: {kernel_name}.json
|
||||
|
||||
The file uses a simplified 2-layer hierarchical structure:
|
||||
{
|
||||
"h100": { # GPU platform
|
||||
"default": { ... }, # Fallback configuration
|
||||
"batch_32_hidden_4096": { ... },
|
||||
"batch_64_hidden_8192": { ... }
|
||||
},
|
||||
"a100": {
|
||||
"default": { ... },
|
||||
"batch_16_hidden_2048": { ... }
|
||||
}
|
||||
}
|
||||
|
||||
Example file: silu_mul_fp8.json
|
||||
|
||||
Config keys should be structured strings that encode the relevant
|
||||
parameters (e.g., "batch_32_hidden_4096", "seq_512_heads_16", "fp8_batch_64", etc.).
|
||||
|
||||
Classes
|
||||
-------
|
||||
- ConfigSet: In-memory collection of configs for a kernel with lookup/query APIs.
|
||||
- ConfigManager: File-level operations for config persistence.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.import_utils import has_helion
|
||||
|
||||
if not has_helion():
|
||||
raise ImportError(
|
||||
"ConfigManager requires helion to be installed. "
|
||||
"Install it with: pip install helion"
|
||||
)
|
||||
|
||||
import helion
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ConfigSet:
|
||||
"""In-memory collection of Helion configs with lookup/query capabilities."""
|
||||
|
||||
# Type alias for nested config structure:
|
||||
# platform -> config_key -> helion.Config
|
||||
_ConfigDict = dict[str, dict[str, "helion.Config"]]
|
||||
|
||||
def __init__(self, kernel_name: str):
|
||||
self._kernel_name = kernel_name
|
||||
self._configs: ConfigSet._ConfigDict = {}
|
||||
|
||||
@property
|
||||
def kernel_name(self) -> str:
|
||||
return self._kernel_name
|
||||
|
||||
def get_config(self, platform: str, config_key: str) -> helion.Config:
|
||||
platform_dict = self._configs.get(platform)
|
||||
if platform_dict is None:
|
||||
avail_platforms = self.get_platforms()
|
||||
raise KeyError(
|
||||
f"Config not found for kernel '{self._kernel_name}': "
|
||||
f"platform '{platform}' not found. "
|
||||
f"Available platforms: {avail_platforms or '(none)'}"
|
||||
)
|
||||
|
||||
config = platform_dict.get(config_key)
|
||||
if config is None:
|
||||
avail_keys = self.get_config_keys(platform)
|
||||
raise KeyError(
|
||||
f"Config not found for kernel '{self._kernel_name}': "
|
||||
f"config_key '{config_key}' not found for platform '{platform}'. "
|
||||
f"Available config_keys: {avail_keys or '(none)'}"
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
def get_platforms(self) -> list[str]:
|
||||
return sorted(self._configs.keys())
|
||||
|
||||
def get_config_keys(self, platform: str) -> list[str]:
|
||||
platform_dict = self._configs.get(platform.lower())
|
||||
if platform_dict is None:
|
||||
return []
|
||||
return sorted(platform_dict.keys())
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
result: dict[str, Any] = {}
|
||||
|
||||
for platform, config_keys_dict in self._configs.items():
|
||||
result[platform] = {}
|
||||
|
||||
for config_key, config in config_keys_dict.items():
|
||||
# Convert helion.Config to dict using to_json() + json.loads()
|
||||
import json
|
||||
|
||||
result[platform][config_key] = json.loads(config.to_json())
|
||||
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, kernel_name: str, data: dict[str, Any]) -> "ConfigSet":
|
||||
config_set = cls(kernel_name)
|
||||
count = 0
|
||||
|
||||
for platform, platform_data in data.items():
|
||||
if platform not in config_set._configs:
|
||||
config_set._configs[platform] = {}
|
||||
|
||||
for config_key, config_data in platform_data.items():
|
||||
config = helion.Config(**config_data)
|
||||
config_set._configs[platform][config_key] = config
|
||||
count += 1
|
||||
|
||||
if count > 0:
|
||||
logger.debug(
|
||||
"Loaded %d configs for kernel '%s'",
|
||||
count,
|
||||
kernel_name,
|
||||
)
|
||||
|
||||
return config_set
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
"""File-level configuration management for Helion kernels (global singleton)."""
|
||||
|
||||
_instance: "ConfigManager | None" = None
|
||||
_instance_base_dir: Path | None = None
|
||||
|
||||
def __new__(cls, base_dir: str | Path | None = None) -> "ConfigManager":
|
||||
resolved_base_dir = cls._resolve_base_dir(base_dir)
|
||||
|
||||
if cls._instance is not None:
|
||||
# Instance already exists - check for base_dir mismatch
|
||||
if cls._instance_base_dir != resolved_base_dir:
|
||||
raise ValueError(
|
||||
f"ConfigManager singleton already exists with base_dir "
|
||||
f"'{cls._instance_base_dir}', cannot create with different "
|
||||
f"base_dir '{resolved_base_dir}'"
|
||||
)
|
||||
return cls._instance
|
||||
|
||||
# Create new instance
|
||||
instance = super().__new__(cls)
|
||||
cls._instance = instance
|
||||
cls._instance_base_dir = resolved_base_dir
|
||||
return instance
|
||||
|
||||
def __init__(self, base_dir: str | Path | None = None):
|
||||
# Only initialize if not already initialized
|
||||
if hasattr(self, "_base_dir"):
|
||||
return
|
||||
|
||||
self._base_dir = self._resolve_base_dir(base_dir)
|
||||
logger.debug("ConfigManager initialized with base_dir: %s", self._base_dir)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_base_dir(base_dir: str | Path | None) -> Path:
|
||||
if base_dir is not None:
|
||||
return Path(base_dir).resolve()
|
||||
return (Path(__file__).parent / "configs").resolve()
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "ConfigManager":
|
||||
if cls._instance is None:
|
||||
raise RuntimeError(
|
||||
"ConfigManager instance has not been created. "
|
||||
"Call ConfigManager(base_dir=...) first to initialize."
|
||||
)
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def reset_instance(cls) -> None:
|
||||
"""For testing purposes only."""
|
||||
cls._instance = None
|
||||
cls._instance_base_dir = None
|
||||
|
||||
def get_config_file_path(self, kernel_name: str) -> Path:
|
||||
return self._base_dir / f"{kernel_name}.json"
|
||||
|
||||
def ensure_base_dir_exists(self) -> Path:
|
||||
self._base_dir.mkdir(parents=True, exist_ok=True)
|
||||
return self._base_dir
|
||||
|
||||
def load_config_set(self, kernel_name: str) -> ConfigSet:
|
||||
config_path = self.get_config_file_path(kernel_name)
|
||||
if not config_path.exists():
|
||||
return ConfigSet.from_dict(kernel_name, {})
|
||||
|
||||
try:
|
||||
with open(config_path) as f:
|
||||
data = json.load(f)
|
||||
return ConfigSet.from_dict(kernel_name, data)
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
logger.error("Failed to load config file %s: %s", config_path, e)
|
||||
return ConfigSet.from_dict(kernel_name, {})
|
||||
|
||||
def get_platform_configs(
|
||||
self, kernel_name: str, platform: str
|
||||
) -> dict[str, helion.Config]:
|
||||
config_set = self.load_config_set(kernel_name)
|
||||
config_keys = config_set.get_config_keys(platform)
|
||||
|
||||
return {
|
||||
config_key: config_set.get_config(platform, config_key)
|
||||
for config_key in config_keys
|
||||
}
|
||||
|
||||
def save_config_set(self, config_set: ConfigSet) -> Path:
|
||||
config_path = self.get_config_file_path(config_set.kernel_name)
|
||||
config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config_set.to_dict(), f, indent=2)
|
||||
|
||||
logger.info("Saved config to: %s", config_path)
|
||||
return config_path
|
||||
Reference in New Issue
Block a user