[Kernel] [Helion] [12/N] Use FakeTensorMode to avoid GPU allocation during config key computation (#36563)
Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -27,6 +27,7 @@ import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||
|
||||
try:
|
||||
import helion
|
||||
@@ -109,7 +110,8 @@ def autotune_kernel(
|
||||
)
|
||||
|
||||
try:
|
||||
inputs_dict = kernel_wrapper.get_inputs()
|
||||
with FakeTensorMode():
|
||||
all_config_keys = list(kernel_wrapper.get_inputs().keys())
|
||||
except NotImplementedError:
|
||||
error_msg = f"Kernel '{kernel_name}' has no input generator registered"
|
||||
logger.error(error_msg)
|
||||
@@ -126,15 +128,15 @@ def autotune_kernel(
|
||||
"Autotuning kernel '%s' for platform '%s' with %d configs",
|
||||
kernel_name,
|
||||
platform,
|
||||
len(inputs_dict),
|
||||
len(all_config_keys),
|
||||
)
|
||||
|
||||
configs_to_autotune = {}
|
||||
if not force:
|
||||
existing_configs = config_manager.get_platform_configs(
|
||||
kernel_name, platform
|
||||
)
|
||||
for config_key, inputs in inputs_dict.items():
|
||||
keys_to_autotune = []
|
||||
for config_key in all_config_keys:
|
||||
if config_key in existing_configs:
|
||||
logger.debug(
|
||||
"Config '%s' already exists for platform '%s', skipping",
|
||||
@@ -142,12 +144,12 @@ def autotune_kernel(
|
||||
platform,
|
||||
)
|
||||
else:
|
||||
configs_to_autotune[config_key] = inputs
|
||||
keys_to_autotune.append(config_key)
|
||||
else:
|
||||
logger.debug("Force mode enabled, will re-autotune all configs")
|
||||
configs_to_autotune = inputs_dict
|
||||
keys_to_autotune = all_config_keys
|
||||
|
||||
if not configs_to_autotune:
|
||||
if not keys_to_autotune:
|
||||
logger.info(
|
||||
"All configs already exist for kernel '%s' on platform '%s'. "
|
||||
"Use --force to re-autotune.",
|
||||
@@ -162,6 +164,9 @@ def autotune_kernel(
|
||||
configs={},
|
||||
)
|
||||
|
||||
inputs_dict = kernel_wrapper.get_inputs()
|
||||
configs_to_autotune = {k: inputs_dict[k] for k in keys_to_autotune}
|
||||
|
||||
total_start_time = time.time()
|
||||
autotuned_configs = {}
|
||||
failed_configs = []
|
||||
|
||||
Reference in New Issue
Block a user