[Kernel] [Helion] [15/N] Split config files into per-platform files (#36698)

Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Yanan Cao
2026-03-11 14:25:29 -07:00
committed by GitHub
parent a3774a8198
commit cf632499ee
5 changed files with 27832 additions and 27803 deletions

View File

@@ -160,10 +160,11 @@ class TestConfigManager:
"""Test getting config file path for a kernel."""
manager = ConfigManager(base_dir="/tmp")
file_path = manager.get_config_file_path("silu_mul_fp8")
dir_path = manager.get_config_file_path("silu_mul_fp8")
assert dir_path == Path("/tmp/silu_mul_fp8")
expected_path = Path("/tmp/silu_mul_fp8.json")
assert file_path == expected_path
file_path = manager.get_config_file_path("silu_mul_fp8", "nvidia_h100")
assert file_path == Path("/tmp/silu_mul_fp8/nvidia_h100.json")
def test_ensure_base_dir_exists(self):
"""Test ensuring base directory exists."""
@@ -189,19 +190,19 @@ class TestConfigManager:
assert config_set.get_platforms() == []
def test_load_config_set_valid_file(self):
"""Test loading config set from valid file."""
"""Test loading config set from per-platform files."""
with tempfile.TemporaryDirectory() as temp_dir:
# Use realistic config data
kernel_config = {
"block_sizes": [128, 64],
"num_warps": 8,
"num_stages": 6,
"pid_type": "persistent_interleaved",
}
config_data = {"h100": {"batch_32_hidden_4096": kernel_config}}
config_file = Path(temp_dir) / "test_kernel.json"
with open(config_file, "w") as f:
json.dump(config_data, f)
kernel_dir = Path(temp_dir) / "test_kernel"
kernel_dir.mkdir()
platform_file = kernel_dir / "h100.json"
with open(platform_file, "w") as f:
json.dump({"batch_32_hidden_4096": kernel_config}, f)
manager = ConfigManager(base_dir=temp_dir)
config_set = manager.load_config_set("test_kernel")
@@ -210,7 +211,6 @@ class TestConfigManager:
assert config_set.kernel_name == "test_kernel"
assert config_set.get_platforms() == ["h100"]
# Verify the config was loaded correctly
config = config_set.get_config("h100", "batch_32_hidden_4096")
assert isinstance(config, helion.Config)
assert config.block_sizes == [128, 64]
@@ -219,7 +219,9 @@ class TestConfigManager:
def test_load_config_set_invalid_json(self):
"""Test loading config set from file with invalid JSON."""
with tempfile.TemporaryDirectory() as temp_dir:
config_file = Path(temp_dir) / "test_kernel.json"
kernel_dir = Path(temp_dir) / "test_kernel"
kernel_dir.mkdir()
config_file = kernel_dir / "h100.json"
with open(config_file, "w") as f:
f.write("invalid json content {")
@@ -231,9 +233,8 @@ class TestConfigManager:
assert config_set.get_platforms() == []
def test_save_config_set(self):
"""Test saving ConfigSet to file."""
"""Test saving ConfigSet to per-platform files."""
with tempfile.TemporaryDirectory() as temp_dir:
# Use realistic config data
kernel_config = {
"block_sizes": [256, 128],
"num_warps": 16,
@@ -246,31 +247,34 @@ class TestConfigManager:
manager = ConfigManager(base_dir=temp_dir)
saved_path = manager.save_config_set(config_set)
expected_path = Path(temp_dir) / "test_kernel.json"
assert saved_path == expected_path
assert saved_path.exists()
expected_dir = Path(temp_dir) / "test_kernel"
assert saved_path == expected_dir
assert saved_path.is_dir()
with open(saved_path) as f:
platform_file = expected_dir / "h100.json"
assert platform_file.exists()
with open(platform_file) as f:
loaded_data = json.load(f)
assert loaded_data == data
assert loaded_data == data["h100"]
def test_save_config_set_creates_directory(self):
"""Test that save_config_set creates parent directories if needed."""
with tempfile.TemporaryDirectory() as temp_dir:
nested_dir = Path(temp_dir) / "nested" / "configs"
config_set = ConfigSet("test_kernel")
data = {"h100": {"default": {"num_warps": 4}}}
config_set = ConfigSet.from_dict("test_kernel", data)
manager = ConfigManager(base_dir=nested_dir)
saved_path = manager.save_config_set(config_set)
assert nested_dir.exists()
assert nested_dir.is_dir()
assert saved_path.exists()
assert saved_path.is_dir()
assert (saved_path / "h100.json").exists()
def test_get_platform_configs(self):
"""Test getting all configs for a specific platform."""
with tempfile.TemporaryDirectory() as temp_dir:
# Use realistic config data
config_1 = {"num_warps": 4, "num_stages": 3, "block_sizes": [64, 32]}
config_2 = {"num_warps": 8, "num_stages": 5, "block_sizes": [128, 64]}
default_config = {
@@ -280,17 +284,19 @@ class TestConfigManager:
}
config_3 = {"num_warps": 2, "num_stages": 2, "block_sizes": [32, 16]}
config_data = {
"h100": {
"batch_32_hidden_4096": config_1,
"batch_64_hidden_2048": config_2,
"default": default_config,
},
"a100": {"batch_16_hidden_1024": config_3},
}
config_file = Path(temp_dir) / "test_kernel.json"
with open(config_file, "w") as f:
json.dump(config_data, f)
kernel_dir = Path(temp_dir) / "test_kernel"
kernel_dir.mkdir()
with open(kernel_dir / "h100.json", "w") as f:
json.dump(
{
"batch_32_hidden_4096": config_1,
"batch_64_hidden_2048": config_2,
"default": default_config,
},
f,
)
with open(kernel_dir / "a100.json", "w") as f:
json.dump({"batch_16_hidden_1024": config_3}, f)
manager = ConfigManager(base_dir=temp_dir)
@@ -302,7 +308,6 @@ class TestConfigManager:
for config in h100_configs.values():
assert isinstance(config, helion.Config)
# Verify specific config details
assert h100_configs["batch_32_hidden_4096"].num_warps == 4
assert h100_configs["default"].num_stages == 7

View File

@@ -8,23 +8,15 @@ operations, including naming conventions, directory resolution, and file I/O.
Config File Structure
---------------------
Each kernel has a single JSON config file: {kernel_name}.json
Each kernel has a directory: {kernel_name}/
Inside, each GPU platform has its own JSON file: {kernel_name}/{platform}.json
The file uses a simplified 2-layer hierarchical structure:
{
"h100": { # GPU platform
"default": { ... }, # Fallback configuration
"batch_32_hidden_4096": { ... },
"batch_64_hidden_8192": { ... }
},
"a100": {
"default": { ... },
"batch_16_hidden_2048": { ... }
}
}
Example file: silu_mul_fp8.json
For example:
silu_mul_fp8/
nvidia_h100.json # { "default": {...}, "batch_32_hidden_4096": {...} }
nvidia_h200.json # { "batch_16_hidden_2048": {...} }
Each platform file maps config keys to Helion config objects.
Config keys should be structured strings that encode the relevant
parameters (e.g., "batch_32_hidden_4096", "seq_512_heads_16", or "fp8_batch_64").
@@ -212,8 +204,15 @@ class ConfigManager:
cls._instance = None
cls._instance_base_dir = None
def get_config_file_path(self, kernel_name: str) -> Path:
return self._base_dir / f"{kernel_name}.json"
def get_kernel_dir(self, kernel_name: str) -> Path:
    """Return the directory holding the per-platform config files of a kernel.

    The result is ``<base_dir>/<kernel_name>``. This only builds the path;
    it neither creates the directory nor checks that it exists.
    """
    kernel_dir = self._base_dir / kernel_name
    return kernel_dir
def get_config_file_path(
self, kernel_name: str, platform: str | None = None
) -> Path:
if platform is not None:
return self.get_kernel_dir(kernel_name) / f"{platform}.json"
return self.get_kernel_dir(kernel_name)
def ensure_base_dir_exists(self) -> Path:
self._base_dir.mkdir(parents=True, exist_ok=True)
@@ -230,39 +229,59 @@ class ConfigManager:
f"Config directory '{self._base_dir}' is not writable: {e}"
) from e
def load_config_set(self, kernel_name: str) -> ConfigSet:
config_path = self.get_config_file_path(kernel_name)
def _load_platform_file(self, kernel_name: str, platform: str) -> dict[str, Any]:
config_path = self.get_config_file_path(kernel_name, platform)
if not config_path.exists():
return ConfigSet.from_dict(kernel_name, {})
return {}
try:
with open(config_path) as f:
data = json.load(f)
return ConfigSet.from_dict(kernel_name, data)
return json.load(f)
except (json.JSONDecodeError, OSError) as e:
logger.error("Failed to load config file %s: %s", config_path, e)
return {}
def load_config_set(self, kernel_name: str) -> ConfigSet:
    """Load every per-platform config file of a kernel into one ConfigSet.

    Each ``*.json`` file directly inside the kernel directory contributes one
    platform, named after the file stem. Files are visited in sorted order.

    A file that cannot be read, decoded, or parsed, or whose top-level JSON
    value is not an object, is logged and skipped so that one bad platform
    file does not hide the configs of the other platforms.

    Args:
        kernel_name: Name of the kernel whose directory is scanned.

    Returns:
        A ConfigSet built from the platforms that loaded successfully. When
        the kernel directory does not exist, it is built from an empty mapping.
    """
    kernel_dir = self.get_kernel_dir(kernel_name)
    if not kernel_dir.is_dir():
        return ConfigSet.from_dict(kernel_name, {})

    data: dict[str, Any] = {}
    for platform_file in sorted(kernel_dir.glob("*.json")):
        try:
            with open(platform_file, encoding="utf-8") as f:
                platform_data = json.load(f)
        except (ValueError, OSError) as e:
            # ValueError covers both json.JSONDecodeError and
            # UnicodeDecodeError (a file with undecodable bytes).
            logger.error("Failed to load config file %s: %s", platform_file, e)
            continue

        if not isinstance(platform_data, dict):
            # A platform file must map config keys to config objects.
            logger.error(
                "Failed to load config file %s: %s",
                platform_file,
                "top-level JSON value must be an object",
            )
            continue

        data[platform_file.stem] = platform_data

    return ConfigSet.from_dict(kernel_name, data)
def get_platform_configs(
self, kernel_name: str, platform: str
) -> dict[str, helion.Config]:
config_set = self.load_config_set(kernel_name)
platform_data = self._load_platform_file(kernel_name, platform)
if not platform_data:
return {}
config_set = ConfigSet.from_dict(kernel_name, {platform: platform_data})
config_keys = config_set.get_config_keys(platform)
return {
config_key: config_set.get_config(platform, config_key)
for config_key in config_keys
}
def save_config_set(self, config_set: ConfigSet) -> Path:
config_path = self.get_config_file_path(config_set.kernel_name)
config_path.parent.mkdir(parents=True, exist_ok=True)
kernel_dir = self.get_kernel_dir(config_set.kernel_name)
kernel_dir.mkdir(parents=True, exist_ok=True)
with open(config_path, "w") as f:
json.dump(config_set.to_dict(), f, indent=2)
full_data = config_set.to_dict()
for platform, platform_data in full_data.items():
platform_path = kernel_dir / f"{platform}.json"
with open(platform_path, "w") as f:
json.dump(platform_data, f, indent=2)
logger.info("Saved config to: %s", platform_path)
logger.info("Saved config to: %s", config_path)
return config_path
return kernel_dir
def save_configs(
self,
@@ -271,11 +290,18 @@ class ConfigManager:
configs: dict[str, "helion.Config"],
) -> Path:
"""Save configs for a kernel/platform, merging with existing."""
config_set = self.load_config_set(kernel_name)
platform_data = self._load_platform_file(kernel_name, platform)
for config_key, config in configs.items():
config_set.set_config(platform, config_key, config)
return self.save_config_set(config_set)
platform_data[config_key] = json.loads(config.to_json())
platform_path = self.get_config_file_path(kernel_name, platform)
platform_path.parent.mkdir(parents=True, exist_ok=True)
with open(platform_path, "w") as f:
json.dump(platform_data, f, indent=2)
logger.info("Saved config to: %s", platform_path)
return platform_path
def config_exists(self, kernel_name: str, platform: str, config_key: str) -> bool:
config_set = self.load_config_set(kernel_name)
return config_set.has_config(platform, config_key)
platform_data = self._load_platform_file(kernel_name, platform)
return config_key in platform_data

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff