[Kernel] [Helion] [12/N] Use FakeTensorMode to avoid GPU allocation during config key computation (#36563)
Signed-off-by: Yanan Cao <gmagogsfm@gmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -27,6 +27,7 @@ import time
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import helion
|
import helion
|
||||||
@@ -109,7 +110,8 @@ def autotune_kernel(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
inputs_dict = kernel_wrapper.get_inputs()
|
with FakeTensorMode():
|
||||||
|
all_config_keys = list(kernel_wrapper.get_inputs().keys())
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
error_msg = f"Kernel '{kernel_name}' has no input generator registered"
|
error_msg = f"Kernel '{kernel_name}' has no input generator registered"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
@@ -126,15 +128,15 @@ def autotune_kernel(
|
|||||||
"Autotuning kernel '%s' for platform '%s' with %d configs",
|
"Autotuning kernel '%s' for platform '%s' with %d configs",
|
||||||
kernel_name,
|
kernel_name,
|
||||||
platform,
|
platform,
|
||||||
len(inputs_dict),
|
len(all_config_keys),
|
||||||
)
|
)
|
||||||
|
|
||||||
configs_to_autotune = {}
|
|
||||||
if not force:
|
if not force:
|
||||||
existing_configs = config_manager.get_platform_configs(
|
existing_configs = config_manager.get_platform_configs(
|
||||||
kernel_name, platform
|
kernel_name, platform
|
||||||
)
|
)
|
||||||
for config_key, inputs in inputs_dict.items():
|
keys_to_autotune = []
|
||||||
|
for config_key in all_config_keys:
|
||||||
if config_key in existing_configs:
|
if config_key in existing_configs:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Config '%s' already exists for platform '%s', skipping",
|
"Config '%s' already exists for platform '%s', skipping",
|
||||||
@@ -142,12 +144,12 @@ def autotune_kernel(
|
|||||||
platform,
|
platform,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
configs_to_autotune[config_key] = inputs
|
keys_to_autotune.append(config_key)
|
||||||
else:
|
else:
|
||||||
logger.debug("Force mode enabled, will re-autotune all configs")
|
logger.debug("Force mode enabled, will re-autotune all configs")
|
||||||
configs_to_autotune = inputs_dict
|
keys_to_autotune = all_config_keys
|
||||||
|
|
||||||
if not configs_to_autotune:
|
if not keys_to_autotune:
|
||||||
logger.info(
|
logger.info(
|
||||||
"All configs already exist for kernel '%s' on platform '%s'. "
|
"All configs already exist for kernel '%s' on platform '%s'. "
|
||||||
"Use --force to re-autotune.",
|
"Use --force to re-autotune.",
|
||||||
@@ -162,6 +164,9 @@ def autotune_kernel(
|
|||||||
configs={},
|
configs={},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
inputs_dict = kernel_wrapper.get_inputs()
|
||||||
|
configs_to_autotune = {k: inputs_dict[k] for k in keys_to_autotune}
|
||||||
|
|
||||||
total_start_time = time.time()
|
total_start_time = time.time()
|
||||||
autotuned_configs = {}
|
autotuned_configs = {}
|
||||||
failed_configs = []
|
failed_configs = []
|
||||||
|
|||||||
Reference in New Issue
Block a user