[Kernel] [Helion] [15/N] Split config files into per-platform files (#36698)
Signed-off-by: Yanan Cao <gmagogsfm@gmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -160,10 +160,11 @@ class TestConfigManager:
|
||||
"""Test getting config file path for a kernel."""
|
||||
manager = ConfigManager(base_dir="/tmp")
|
||||
|
||||
file_path = manager.get_config_file_path("silu_mul_fp8")
|
||||
dir_path = manager.get_config_file_path("silu_mul_fp8")
|
||||
assert dir_path == Path("/tmp/silu_mul_fp8")
|
||||
|
||||
expected_path = Path("/tmp/silu_mul_fp8.json")
|
||||
assert file_path == expected_path
|
||||
file_path = manager.get_config_file_path("silu_mul_fp8", "nvidia_h100")
|
||||
assert file_path == Path("/tmp/silu_mul_fp8/nvidia_h100.json")
|
||||
|
||||
def test_ensure_base_dir_exists(self):
|
||||
"""Test ensuring base directory exists."""
|
||||
@@ -189,19 +190,19 @@ class TestConfigManager:
|
||||
assert config_set.get_platforms() == []
|
||||
|
||||
def test_load_config_set_valid_file(self):
|
||||
"""Test loading config set from valid file."""
|
||||
"""Test loading config set from per-platform files."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Use realistic config data
|
||||
kernel_config = {
|
||||
"block_sizes": [128, 64],
|
||||
"num_warps": 8,
|
||||
"num_stages": 6,
|
||||
"pid_type": "persistent_interleaved",
|
||||
}
|
||||
config_data = {"h100": {"batch_32_hidden_4096": kernel_config}}
|
||||
config_file = Path(temp_dir) / "test_kernel.json"
|
||||
with open(config_file, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
kernel_dir = Path(temp_dir) / "test_kernel"
|
||||
kernel_dir.mkdir()
|
||||
platform_file = kernel_dir / "h100.json"
|
||||
with open(platform_file, "w") as f:
|
||||
json.dump({"batch_32_hidden_4096": kernel_config}, f)
|
||||
|
||||
manager = ConfigManager(base_dir=temp_dir)
|
||||
config_set = manager.load_config_set("test_kernel")
|
||||
@@ -210,7 +211,6 @@ class TestConfigManager:
|
||||
assert config_set.kernel_name == "test_kernel"
|
||||
assert config_set.get_platforms() == ["h100"]
|
||||
|
||||
# Verify the config was loaded correctly
|
||||
config = config_set.get_config("h100", "batch_32_hidden_4096")
|
||||
assert isinstance(config, helion.Config)
|
||||
assert config.block_sizes == [128, 64]
|
||||
@@ -219,7 +219,9 @@ class TestConfigManager:
|
||||
def test_load_config_set_invalid_json(self):
|
||||
"""Test loading config set from file with invalid JSON."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
config_file = Path(temp_dir) / "test_kernel.json"
|
||||
kernel_dir = Path(temp_dir) / "test_kernel"
|
||||
kernel_dir.mkdir()
|
||||
config_file = kernel_dir / "h100.json"
|
||||
with open(config_file, "w") as f:
|
||||
f.write("invalid json content {")
|
||||
|
||||
@@ -231,9 +233,8 @@ class TestConfigManager:
|
||||
assert config_set.get_platforms() == []
|
||||
|
||||
def test_save_config_set(self):
|
||||
"""Test saving ConfigSet to file."""
|
||||
"""Test saving ConfigSet to per-platform files."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Use realistic config data
|
||||
kernel_config = {
|
||||
"block_sizes": [256, 128],
|
||||
"num_warps": 16,
|
||||
@@ -246,31 +247,34 @@ class TestConfigManager:
|
||||
manager = ConfigManager(base_dir=temp_dir)
|
||||
saved_path = manager.save_config_set(config_set)
|
||||
|
||||
expected_path = Path(temp_dir) / "test_kernel.json"
|
||||
assert saved_path == expected_path
|
||||
assert saved_path.exists()
|
||||
expected_dir = Path(temp_dir) / "test_kernel"
|
||||
assert saved_path == expected_dir
|
||||
assert saved_path.is_dir()
|
||||
|
||||
with open(saved_path) as f:
|
||||
platform_file = expected_dir / "h100.json"
|
||||
assert platform_file.exists()
|
||||
with open(platform_file) as f:
|
||||
loaded_data = json.load(f)
|
||||
assert loaded_data == data
|
||||
assert loaded_data == data["h100"]
|
||||
|
||||
def test_save_config_set_creates_directory(self):
|
||||
"""Test that save_config_set creates parent directories if needed."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
nested_dir = Path(temp_dir) / "nested" / "configs"
|
||||
config_set = ConfigSet("test_kernel")
|
||||
data = {"h100": {"default": {"num_warps": 4}}}
|
||||
config_set = ConfigSet.from_dict("test_kernel", data)
|
||||
|
||||
manager = ConfigManager(base_dir=nested_dir)
|
||||
saved_path = manager.save_config_set(config_set)
|
||||
|
||||
assert nested_dir.exists()
|
||||
assert nested_dir.is_dir()
|
||||
assert saved_path.exists()
|
||||
assert saved_path.is_dir()
|
||||
assert (saved_path / "h100.json").exists()
|
||||
|
||||
def test_get_platform_configs(self):
|
||||
"""Test getting all configs for a specific platform."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Use realistic config data
|
||||
config_1 = {"num_warps": 4, "num_stages": 3, "block_sizes": [64, 32]}
|
||||
config_2 = {"num_warps": 8, "num_stages": 5, "block_sizes": [128, 64]}
|
||||
default_config = {
|
||||
@@ -280,17 +284,19 @@ class TestConfigManager:
|
||||
}
|
||||
config_3 = {"num_warps": 2, "num_stages": 2, "block_sizes": [32, 16]}
|
||||
|
||||
config_data = {
|
||||
"h100": {
|
||||
"batch_32_hidden_4096": config_1,
|
||||
"batch_64_hidden_2048": config_2,
|
||||
"default": default_config,
|
||||
},
|
||||
"a100": {"batch_16_hidden_1024": config_3},
|
||||
}
|
||||
config_file = Path(temp_dir) / "test_kernel.json"
|
||||
with open(config_file, "w") as f:
|
||||
json.dump(config_data, f)
|
||||
kernel_dir = Path(temp_dir) / "test_kernel"
|
||||
kernel_dir.mkdir()
|
||||
with open(kernel_dir / "h100.json", "w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"batch_32_hidden_4096": config_1,
|
||||
"batch_64_hidden_2048": config_2,
|
||||
"default": default_config,
|
||||
},
|
||||
f,
|
||||
)
|
||||
with open(kernel_dir / "a100.json", "w") as f:
|
||||
json.dump({"batch_16_hidden_1024": config_3}, f)
|
||||
|
||||
manager = ConfigManager(base_dir=temp_dir)
|
||||
|
||||
@@ -302,7 +308,6 @@ class TestConfigManager:
|
||||
for config in h100_configs.values():
|
||||
assert isinstance(config, helion.Config)
|
||||
|
||||
# Verify specific config details
|
||||
assert h100_configs["batch_32_hidden_4096"].num_warps == 4
|
||||
assert h100_configs["default"].num_stages == 7
|
||||
|
||||
|
||||
@@ -8,23 +8,15 @@ operations, including naming conventions, directory resolution, and file I/O.
|
||||
|
||||
Config File Structure
|
||||
---------------------
|
||||
Each kernel has a single JSON config file: {kernel_name}.json
|
||||
Each kernel has a directory: {kernel_name}/
|
||||
Inside, each GPU platform has its own JSON file: {kernel_name}/{platform}.json
|
||||
|
||||
The file uses a simplified 2-layer hierarchical structure:
|
||||
{
|
||||
"h100": { # GPU platform
|
||||
"default": { ... }, # Fallback configuration
|
||||
"batch_32_hidden_4096": { ... },
|
||||
"batch_64_hidden_8192": { ... }
|
||||
},
|
||||
"a100": {
|
||||
"default": { ... },
|
||||
"batch_16_hidden_2048": { ... }
|
||||
}
|
||||
}
|
||||
|
||||
Example file: silu_mul_fp8.json
|
||||
For example:
|
||||
silu_mul_fp8/
|
||||
nvidia_h100.json # { "default": {...}, "batch_32_hidden_4096": {...} }
|
||||
nvidia_h200.json # { "batch_16_hidden_2048": {...} }
|
||||
|
||||
Each platform file maps config keys to Helion config objects.
|
||||
Config keys should be structured strings that encode the relevant
|
||||
parameters (e.g., "batch_32_hidden_4096", "seq_512_heads_16", "fp8_batch_64", etc.).
|
||||
|
||||
@@ -212,8 +204,15 @@ class ConfigManager:
|
||||
cls._instance = None
|
||||
cls._instance_base_dir = None
|
||||
|
||||
def get_config_file_path(self, kernel_name: str) -> Path:
|
||||
return self._base_dir / f"{kernel_name}.json"
|
||||
def get_kernel_dir(self, kernel_name: str) -> Path:
    """Return the directory that holds a kernel's per-platform config files.

    The path is ``<base_dir>/<kernel_name>``; nothing is created on disk.
    """
    return self._base_dir.joinpath(kernel_name)
|
||||
|
||||
def get_config_file_path(
    self, kernel_name: str, platform: str | None = None
) -> Path:
    """Return the config location for a kernel, optionally for one platform.

    Args:
        kernel_name: Kernel whose configs are being located.
        platform: When given, selects that platform's JSON file inside the
            kernel directory.

    Returns:
        ``<kernel_dir>/<platform>.json`` when ``platform`` is provided,
        otherwise the kernel directory itself (a directory, not a file).
    """
    kernel_dir = self.get_kernel_dir(kernel_name)
    if platform is None:
        return kernel_dir
    return kernel_dir / f"{platform}.json"
|
||||
|
||||
def ensure_base_dir_exists(self) -> Path:
|
||||
self._base_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -230,39 +229,59 @@ class ConfigManager:
|
||||
f"Config directory '{self._base_dir}' is not writable: {e}"
|
||||
) from e
|
||||
|
||||
def load_config_set(self, kernel_name: str) -> ConfigSet:
|
||||
config_path = self.get_config_file_path(kernel_name)
|
||||
def _load_platform_file(self, kernel_name: str, platform: str) -> dict[str, Any]:
    """Read the raw JSON config mapping for one kernel/platform pair.

    Args:
        kernel_name: Kernel whose config directory is consulted.
        platform: Platform name; selects ``<platform>.json`` in that directory.

    Returns:
        The decoded JSON object (config key -> raw config data). An empty
        dict is returned when the file does not exist, cannot be read or
        decoded, or does not contain a JSON object at the top level. Failures
        other than a missing file are logged, never raised.
    """
    config_path = self.get_config_file_path(kernel_name, platform)
    if not config_path.exists():
        return {}

    try:
        # JSON text is UTF-8; do not depend on the process locale encoding.
        with open(config_path, encoding="utf-8") as f:
            data = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError, OSError) as e:
        logger.error("Failed to load config file %s: %s", config_path, e)
        return {}

    # json.load may legally yield a list, scalar or None. Callers index and
    # assign into the result by config key, so only a dict is usable here.
    if not isinstance(data, dict):
        logger.error(
            "Failed to load config file %s: %s",
            config_path,
            f"expected a JSON object at the top level, got {type(data).__name__}",
        )
        return {}
    return data
|
||||
|
||||
def load_config_set(self, kernel_name: str) -> ConfigSet:
    """Build a ConfigSet from every per-platform JSON file of a kernel.

    Each ``*.json`` file in the kernel directory contributes one platform,
    named after the file stem. Loading is best-effort: a file that cannot be
    read, decoded, or that does not hold a JSON object is logged and skipped
    so the remaining platforms still load.

    Args:
        kernel_name: Kernel whose config directory is scanned.

    Returns:
        A ConfigSet built from the platforms that loaded successfully; built
        from an empty mapping when the kernel directory does not exist.
    """
    kernel_dir = self.get_kernel_dir(kernel_name)
    if not kernel_dir.is_dir():
        return ConfigSet.from_dict(kernel_name, {})

    data: dict[str, Any] = {}
    # sorted() keeps platform insertion order deterministic between runs.
    for platform_file in sorted(kernel_dir.glob("*.json")):
        try:
            # JSON text is UTF-8; do not depend on the process locale encoding.
            with open(platform_file, encoding="utf-8") as f:
                platform_data = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError, OSError) as e:
            logger.error("Failed to load config file %s: %s", platform_file, e)
            continue

        # json.load may yield a list, scalar or None; a platform entry must
        # be a mapping of config keys, so anything else is skipped.
        if not isinstance(platform_data, dict):
            logger.error(
                "Failed to load config file %s: %s",
                platform_file,
                "expected a JSON object at the top level, "
                f"got {type(platform_data).__name__}",
            )
            continue

        data[platform_file.stem] = platform_data

    return ConfigSet.from_dict(kernel_name, data)
|
||||
|
||||
def get_platform_configs(
    self, kernel_name: str, platform: str
) -> dict[str, helion.Config]:
    """Return every config stored for one kernel on one platform.

    Only that platform's file is read; other platforms are not loaded.

    Args:
        kernel_name: Kernel to look up.
        platform: Platform whose configs are wanted.

    Returns:
        Mapping of config key to config object, or an empty dict when the
        platform file yields no data.
    """
    raw_configs = self._load_platform_file(kernel_name, platform)
    if not raw_configs:
        return {}

    # Wrap the single platform in a ConfigSet so raw entries are converted
    # the same way as in a full load.
    single_platform = ConfigSet.from_dict(kernel_name, {platform: raw_configs})

    resolved = {}
    for key in single_platform.get_config_keys(platform):
        resolved[key] = single_platform.get_config(platform, key)
    return resolved
|
||||
|
||||
def save_config_set(self, config_set: ConfigSet) -> Path:
    """Write a ConfigSet to disk as one JSON file per platform.

    The kernel directory is created if needed. Files belonging to platforms
    that are absent from ``config_set`` are left untouched.

    Args:
        config_set: Configs to persist; its ``kernel_name`` picks the directory.

    Returns:
        The kernel directory the platform files were written into.
    """
    target_dir = self.get_kernel_dir(config_set.kernel_name)
    target_dir.mkdir(parents=True, exist_ok=True)

    for platform_name, payload in config_set.to_dict().items():
        out_path = target_dir / f"{platform_name}.json"
        with open(out_path, "w") as out:
            json.dump(payload, out, indent=2)
        logger.info("Saved config to: %s", out_path)

    return target_dir
|
||||
|
||||
def save_configs(
|
||||
self,
|
||||
@@ -271,11 +290,18 @@ class ConfigManager:
|
||||
configs: dict[str, "helion.Config"],
|
||||
) -> Path:
|
||||
"""Save configs for a kernel/platform, merging with existing."""
|
||||
config_set = self.load_config_set(kernel_name)
|
||||
platform_data = self._load_platform_file(kernel_name, platform)
|
||||
for config_key, config in configs.items():
|
||||
config_set.set_config(platform, config_key, config)
|
||||
return self.save_config_set(config_set)
|
||||
platform_data[config_key] = json.loads(config.to_json())
|
||||
|
||||
platform_path = self.get_config_file_path(kernel_name, platform)
|
||||
platform_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(platform_path, "w") as f:
|
||||
json.dump(platform_data, f, indent=2)
|
||||
|
||||
logger.info("Saved config to: %s", platform_path)
|
||||
return platform_path
|
||||
|
||||
def config_exists(self, kernel_name: str, platform: str, config_key: str) -> bool:
    """Report whether a config key is stored for the given kernel and platform.

    Only the requested platform's file is consulted.
    """
    return config_key in self._load_platform_file(kernel_name, platform)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
13866
vllm/kernels/helion/configs/silu_mul_fp8/nvidia_h100.json
Normal file
13866
vllm/kernels/helion/configs/silu_mul_fp8/nvidia_h100.json
Normal file
File diff suppressed because it is too large
Load Diff
13866
vllm/kernels/helion/configs/silu_mul_fp8/nvidia_h200.json
Normal file
13866
vllm/kernels/helion/configs/silu_mul_fp8/nvidia_h200.json
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user