diff --git a/scripts/autotune_helion_kernels.py b/scripts/autotune_helion_kernels.py index 755ba3115..c02d2a020 100644 --- a/scripts/autotune_helion_kernels.py +++ b/scripts/autotune_helion_kernels.py @@ -27,6 +27,7 @@ import time from dataclasses import dataclass import torch +from torch._subclasses.fake_tensor import FakeTensorMode try: import helion @@ -109,7 +110,8 @@ def autotune_kernel( ) try: - inputs_dict = kernel_wrapper.get_inputs() + with FakeTensorMode(): + all_config_keys = list(kernel_wrapper.get_inputs().keys()) except NotImplementedError: error_msg = f"Kernel '{kernel_name}' has no input generator registered" logger.error(error_msg) @@ -126,15 +128,15 @@ def autotune_kernel( "Autotuning kernel '%s' for platform '%s' with %d configs", kernel_name, platform, - len(inputs_dict), + len(all_config_keys), ) - configs_to_autotune = {} if not force: existing_configs = config_manager.get_platform_configs( kernel_name, platform ) - for config_key, inputs in inputs_dict.items(): + keys_to_autotune = [] + for config_key in all_config_keys: if config_key in existing_configs: logger.debug( "Config '%s' already exists for platform '%s', skipping", @@ -142,12 +144,12 @@ def autotune_kernel( platform, ) else: - configs_to_autotune[config_key] = inputs + keys_to_autotune.append(config_key) else: logger.debug("Force mode enabled, will re-autotune all configs") - configs_to_autotune = inputs_dict + keys_to_autotune = all_config_keys - if not configs_to_autotune: + if not keys_to_autotune: logger.info( "All configs already exist for kernel '%s' on platform '%s'. " "Use --force to re-autotune.", @@ -162,6 +164,9 @@ def autotune_kernel( configs={}, ) + inputs_dict = kernel_wrapper.get_inputs() + configs_to_autotune = {k: inputs_dict[k] for k in keys_to_autotune} + total_start_time = time.time() autotuned_configs = {} failed_configs = []