[Kernel] [Helion] [15/N] Split config files into per-platform files (#36698)

Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Yanan Cao
2026-03-11 14:25:29 -07:00
committed by GitHub
parent a3774a8198
commit cf632499ee
5 changed files with 27832 additions and 27803 deletions

View File

@@ -160,10 +160,11 @@ class TestConfigManager:
"""Test getting config file path for a kernel."""
manager = ConfigManager(base_dir="/tmp")
file_path = manager.get_config_file_path("silu_mul_fp8")
dir_path = manager.get_config_file_path("silu_mul_fp8")
assert dir_path == Path("/tmp/silu_mul_fp8")
expected_path = Path("/tmp/silu_mul_fp8.json")
assert file_path == expected_path
file_path = manager.get_config_file_path("silu_mul_fp8", "nvidia_h100")
assert file_path == Path("/tmp/silu_mul_fp8/nvidia_h100.json")
def test_ensure_base_dir_exists(self):
"""Test ensuring base directory exists."""
@@ -189,19 +190,19 @@ class TestConfigManager:
assert config_set.get_platforms() == []
def test_load_config_set_valid_file(self):
"""Test loading config set from valid file."""
"""Test loading config set from per-platform files."""
with tempfile.TemporaryDirectory() as temp_dir:
# Use realistic config data
kernel_config = {
"block_sizes": [128, 64],
"num_warps": 8,
"num_stages": 6,
"pid_type": "persistent_interleaved",
}
config_data = {"h100": {"batch_32_hidden_4096": kernel_config}}
config_file = Path(temp_dir) / "test_kernel.json"
with open(config_file, "w") as f:
json.dump(config_data, f)
kernel_dir = Path(temp_dir) / "test_kernel"
kernel_dir.mkdir()
platform_file = kernel_dir / "h100.json"
with open(platform_file, "w") as f:
json.dump({"batch_32_hidden_4096": kernel_config}, f)
manager = ConfigManager(base_dir=temp_dir)
config_set = manager.load_config_set("test_kernel")
@@ -210,7 +211,6 @@ class TestConfigManager:
assert config_set.kernel_name == "test_kernel"
assert config_set.get_platforms() == ["h100"]
# Verify the config was loaded correctly
config = config_set.get_config("h100", "batch_32_hidden_4096")
assert isinstance(config, helion.Config)
assert config.block_sizes == [128, 64]
@@ -219,7 +219,9 @@ class TestConfigManager:
def test_load_config_set_invalid_json(self):
"""Test loading config set from file with invalid JSON."""
with tempfile.TemporaryDirectory() as temp_dir:
config_file = Path(temp_dir) / "test_kernel.json"
kernel_dir = Path(temp_dir) / "test_kernel"
kernel_dir.mkdir()
config_file = kernel_dir / "h100.json"
with open(config_file, "w") as f:
f.write("invalid json content {")
@@ -231,9 +233,8 @@ class TestConfigManager:
assert config_set.get_platforms() == []
def test_save_config_set(self):
"""Test saving ConfigSet to file."""
"""Test saving ConfigSet to per-platform files."""
with tempfile.TemporaryDirectory() as temp_dir:
# Use realistic config data
kernel_config = {
"block_sizes": [256, 128],
"num_warps": 16,
@@ -246,31 +247,34 @@ class TestConfigManager:
manager = ConfigManager(base_dir=temp_dir)
saved_path = manager.save_config_set(config_set)
expected_path = Path(temp_dir) / "test_kernel.json"
assert saved_path == expected_path
assert saved_path.exists()
expected_dir = Path(temp_dir) / "test_kernel"
assert saved_path == expected_dir
assert saved_path.is_dir()
with open(saved_path) as f:
platform_file = expected_dir / "h100.json"
assert platform_file.exists()
with open(platform_file) as f:
loaded_data = json.load(f)
assert loaded_data == data
assert loaded_data == data["h100"]
def test_save_config_set_creates_directory(self):
"""Test that save_config_set creates parent directories if needed."""
with tempfile.TemporaryDirectory() as temp_dir:
nested_dir = Path(temp_dir) / "nested" / "configs"
config_set = ConfigSet("test_kernel")
data = {"h100": {"default": {"num_warps": 4}}}
config_set = ConfigSet.from_dict("test_kernel", data)
manager = ConfigManager(base_dir=nested_dir)
saved_path = manager.save_config_set(config_set)
assert nested_dir.exists()
assert nested_dir.is_dir()
assert saved_path.exists()
assert saved_path.is_dir()
assert (saved_path / "h100.json").exists()
def test_get_platform_configs(self):
"""Test getting all configs for a specific platform."""
with tempfile.TemporaryDirectory() as temp_dir:
# Use realistic config data
config_1 = {"num_warps": 4, "num_stages": 3, "block_sizes": [64, 32]}
config_2 = {"num_warps": 8, "num_stages": 5, "block_sizes": [128, 64]}
default_config = {
@@ -280,17 +284,19 @@ class TestConfigManager:
}
config_3 = {"num_warps": 2, "num_stages": 2, "block_sizes": [32, 16]}
config_data = {
"h100": {
"batch_32_hidden_4096": config_1,
"batch_64_hidden_2048": config_2,
"default": default_config,
},
"a100": {"batch_16_hidden_1024": config_3},
}
config_file = Path(temp_dir) / "test_kernel.json"
with open(config_file, "w") as f:
json.dump(config_data, f)
kernel_dir = Path(temp_dir) / "test_kernel"
kernel_dir.mkdir()
with open(kernel_dir / "h100.json", "w") as f:
json.dump(
{
"batch_32_hidden_4096": config_1,
"batch_64_hidden_2048": config_2,
"default": default_config,
},
f,
)
with open(kernel_dir / "a100.json", "w") as f:
json.dump({"batch_16_hidden_1024": config_3}, f)
manager = ConfigManager(base_dir=temp_dir)
@@ -302,7 +308,6 @@ class TestConfigManager:
for config in h100_configs.values():
assert isinstance(config, helion.Config)
# Verify specific config details
assert h100_configs["batch_32_hidden_4096"].num_warps == 4
assert h100_configs["default"].num_stages == 7