[Kernel] [Helion] [15/N] Split config files into per-platform files (#36698)
Signed-off-by: Yanan Cao <gmagogsfm@gmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -160,10 +160,11 @@ class TestConfigManager:
|
||||
"""Test getting config file path for a kernel."""
|
||||
manager = ConfigManager(base_dir="/tmp")
|
||||
|
||||
file_path = manager.get_config_file_path("silu_mul_fp8")
|
||||
dir_path = manager.get_config_file_path("silu_mul_fp8")
|
||||
assert dir_path == Path("/tmp/silu_mul_fp8")
|
||||
|
||||
expected_path = Path("/tmp/silu_mul_fp8.json")
|
||||
assert file_path == expected_path
|
||||
file_path = manager.get_config_file_path("silu_mul_fp8", "nvidia_h100")
|
||||
assert file_path == Path("/tmp/silu_mul_fp8/nvidia_h100.json")
|
||||
|
||||
def test_ensure_base_dir_exists(self):
|
||||
"""Test ensuring base directory exists."""
|
||||
@@ -189,19 +190,19 @@ class TestConfigManager:
|
||||
assert config_set.get_platforms() == []
|
||||
|
||||
def test_load_config_set_valid_file(self):
|
||||
"""Test loading config set from valid file."""
|
||||
"""Test loading config set from per-platform files."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Use realistic config data
|
||||
kernel_config = {
|
||||
"block_sizes": [128, 64],
|
||||
"num_warps": 8,
|
||||
"num_stages": 6,
|
||||
"pid_type": "persistent_interleaved",
|
||||
}
|
||||
config_data = {"h100": {"batch_32_hidden_4096": kernel_config}}
|
||||
config_file = Path(temp_dir) / "test_kernel.json"
|
||||
with open(config_file, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
kernel_dir = Path(temp_dir) / "test_kernel"
|
||||
kernel_dir.mkdir()
|
||||
platform_file = kernel_dir / "h100.json"
|
||||
with open(platform_file, "w") as f:
|
||||
json.dump({"batch_32_hidden_4096": kernel_config}, f)
|
||||
|
||||
manager = ConfigManager(base_dir=temp_dir)
|
||||
config_set = manager.load_config_set("test_kernel")
|
||||
@@ -210,7 +211,6 @@ class TestConfigManager:
|
||||
assert config_set.kernel_name == "test_kernel"
|
||||
assert config_set.get_platforms() == ["h100"]
|
||||
|
||||
# Verify the config was loaded correctly
|
||||
config = config_set.get_config("h100", "batch_32_hidden_4096")
|
||||
assert isinstance(config, helion.Config)
|
||||
assert config.block_sizes == [128, 64]
|
||||
@@ -219,7 +219,9 @@ class TestConfigManager:
|
||||
def test_load_config_set_invalid_json(self):
|
||||
"""Test loading config set from file with invalid JSON."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
config_file = Path(temp_dir) / "test_kernel.json"
|
||||
kernel_dir = Path(temp_dir) / "test_kernel"
|
||||
kernel_dir.mkdir()
|
||||
config_file = kernel_dir / "h100.json"
|
||||
with open(config_file, "w") as f:
|
||||
f.write("invalid json content {")
|
||||
|
||||
@@ -231,9 +233,8 @@ class TestConfigManager:
|
||||
assert config_set.get_platforms() == []
|
||||
|
||||
def test_save_config_set(self):
|
||||
"""Test saving ConfigSet to file."""
|
||||
"""Test saving ConfigSet to per-platform files."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Use realistic config data
|
||||
kernel_config = {
|
||||
"block_sizes": [256, 128],
|
||||
"num_warps": 16,
|
||||
@@ -246,31 +247,34 @@ class TestConfigManager:
|
||||
manager = ConfigManager(base_dir=temp_dir)
|
||||
saved_path = manager.save_config_set(config_set)
|
||||
|
||||
expected_path = Path(temp_dir) / "test_kernel.json"
|
||||
assert saved_path == expected_path
|
||||
assert saved_path.exists()
|
||||
expected_dir = Path(temp_dir) / "test_kernel"
|
||||
assert saved_path == expected_dir
|
||||
assert saved_path.is_dir()
|
||||
|
||||
with open(saved_path) as f:
|
||||
platform_file = expected_dir / "h100.json"
|
||||
assert platform_file.exists()
|
||||
with open(platform_file) as f:
|
||||
loaded_data = json.load(f)
|
||||
assert loaded_data == data
|
||||
assert loaded_data == data["h100"]
|
||||
|
||||
def test_save_config_set_creates_directory(self):
|
||||
"""Test that save_config_set creates parent directories if needed."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
nested_dir = Path(temp_dir) / "nested" / "configs"
|
||||
config_set = ConfigSet("test_kernel")
|
||||
data = {"h100": {"default": {"num_warps": 4}}}
|
||||
config_set = ConfigSet.from_dict("test_kernel", data)
|
||||
|
||||
manager = ConfigManager(base_dir=nested_dir)
|
||||
saved_path = manager.save_config_set(config_set)
|
||||
|
||||
assert nested_dir.exists()
|
||||
assert nested_dir.is_dir()
|
||||
assert saved_path.exists()
|
||||
assert saved_path.is_dir()
|
||||
assert (saved_path / "h100.json").exists()
|
||||
|
||||
def test_get_platform_configs(self):
|
||||
"""Test getting all configs for a specific platform."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Use realistic config data
|
||||
config_1 = {"num_warps": 4, "num_stages": 3, "block_sizes": [64, 32]}
|
||||
config_2 = {"num_warps": 8, "num_stages": 5, "block_sizes": [128, 64]}
|
||||
default_config = {
|
||||
@@ -280,17 +284,19 @@ class TestConfigManager:
|
||||
}
|
||||
config_3 = {"num_warps": 2, "num_stages": 2, "block_sizes": [32, 16]}
|
||||
|
||||
config_data = {
|
||||
"h100": {
|
||||
"batch_32_hidden_4096": config_1,
|
||||
"batch_64_hidden_2048": config_2,
|
||||
"default": default_config,
|
||||
},
|
||||
"a100": {"batch_16_hidden_1024": config_3},
|
||||
}
|
||||
config_file = Path(temp_dir) / "test_kernel.json"
|
||||
with open(config_file, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
kernel_dir = Path(temp_dir) / "test_kernel"
|
||||
kernel_dir.mkdir()
|
||||
with open(kernel_dir / "h100.json", "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"batch_32_hidden_4096": config_1,
|
||||
"batch_64_hidden_2048": config_2,
|
||||
"default": default_config,
|
||||
},
|
||||
f,
|
||||
)
|
||||
with open(kernel_dir / "a100.json", "w") as f:
|
||||
json.dump({"batch_16_hidden_1024": config_3}, f)
|
||||
|
||||
manager = ConfigManager(base_dir=temp_dir)
|
||||
|
||||
@@ -302,7 +308,6 @@ class TestConfigManager:
|
||||
for config in h100_configs.values():
|
||||
assert isinstance(config, helion.Config)
|
||||
|
||||
# Verify specific config details
|
||||
assert h100_configs["batch_32_hidden_4096"].num_warps == 4
|
||||
assert h100_configs["default"].num_stages == 7
|
||||
|
||||
|
||||
Reference in New Issue
Block a user