diff --git a/tests/kernels/helion/test_register.py b/tests/kernels/helion/test_register.py index faac2765c..02b05be74 100644 --- a/tests/kernels/helion/test_register.py +++ b/tests/kernels/helion/test_register.py @@ -554,10 +554,18 @@ class TestKernelRegistry: """Test suite for kernel registry functionality.""" def setup_method(self): - """Clear the registry before each test.""" + """Save and clear the registry before each test.""" + from vllm.kernels.helion.register import _REGISTERED_KERNELS + + self._saved_registry = dict(_REGISTERED_KERNELS) + _REGISTERED_KERNELS.clear() + + def teardown_method(self): + """Restore the registry after each test.""" from vllm.kernels.helion.register import _REGISTERED_KERNELS _REGISTERED_KERNELS.clear() + _REGISTERED_KERNELS.update(self._saved_registry) def test_get_registered_kernels_returns_copy(self): """Test get_registered_kernels returns copy of registry.""" diff --git a/tests/kernels/helion/test_silu_mul_fp8.py b/tests/kernels/helion/test_silu_mul_fp8.py new file mode 100644 index 000000000..da6405d6c --- /dev/null +++ b/tests/kernels/helion/test_silu_mul_fp8.py @@ -0,0 +1,331 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch +import torch.nn.functional as F + +from vllm.utils.import_utils import has_helion + +if not has_helion(): + pytest.skip( + "Helion is not installed. Install with: pip install vllm[helion]", + allow_module_level=True, + ) + +from vllm.kernels.helion.config_manager import ConfigManager +from vllm.kernels.helion.ops.silu_mul_fp8 import ( + pick_silu_mul_fp8_config, + silu_mul_fp8, + silu_mul_fp8_baseline, +) + + +def skip_if_platform_unsupported(): + try: + from vllm.kernels.helion.utils import get_canonical_gpu_name + + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + platform = get_canonical_gpu_name() + + try: + config_manager = ConfigManager.get_instance() + except RuntimeError: + config_manager = ConfigManager() + + configs = config_manager.get_platform_configs("silu_mul_fp8", platform) + if len(configs) == 0: + pytest.skip("Current GPU platform not supported for silu_mul_fp8 kernel") + + except (ImportError, RuntimeError, KeyError): + pytest.skip("Error detecting platform support for silu_mul_fp8 kernel") + + +@pytest.fixture(autouse=True) +def reset_config_manager_singleton(): + ConfigManager.reset_instance() + ConfigManager() + yield + ConfigManager.reset_instance() + + +class TestSiluMulFp8ConfigPicker: + def test_config_picker_exact_match(self): + config_keys = [ + "intermediate_2048_batchsize_256", + "intermediate_4096_batchsize_256", + ] + + input_tensor = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda") + scale = torch.tensor([0.5], dtype=torch.float32, device="cuda") + args = (input_tensor, scale) + + selected_key = pick_silu_mul_fp8_config(args, config_keys) + assert selected_key == "intermediate_2048_batchsize_256" + + def test_config_picker_closest_match(self): + config_keys = [ + "intermediate_2048_batchsize_256", + "intermediate_4096_batchsize_256", + ] + # Use 7000 (intermediate_size=3500) which is closer to 4096 than 2048 + input_tensor = torch.randn(32, 7000, dtype=torch.bfloat16, device="cuda") + scale = torch.tensor([0.5], dtype=torch.float32, device="cuda") + args = (input_tensor, scale) + + selected_key = pick_silu_mul_fp8_config(args, config_keys) + assert selected_key == "intermediate_4096_batchsize_256" + + def test_config_picker_fallback_to_default(self): + config_keys = ["default", "some_other_key"] + + input_tensor = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda") + scale = torch.tensor([0.5], dtype=torch.float32, device="cuda") + args = (input_tensor, scale) + + selected_key = pick_silu_mul_fp8_config(args, config_keys) + assert selected_key == "default" + + def test_config_picker_no_configs(self): + config_keys: list[str] = [] + + input_tensor = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda") + scale = torch.tensor([0.5], dtype=torch.float32, device="cuda") + args = (input_tensor, scale) + + selected_key = pick_silu_mul_fp8_config(args, config_keys) + assert selected_key is None + + @pytest.mark.parametrize("intermediate_size", [2048, 4096, 5120]) + def test_config_picker_different_sizes(self, intermediate_size): + config_keys = [ + "intermediate_2048_batchsize_256", + "intermediate_4096_batchsize_256", + "intermediate_5120_batchsize_256", + ] + + input_tensor = torch.randn( + 32, 2 * intermediate_size, dtype=torch.bfloat16, device="cuda" + ) + scale = torch.tensor([0.5], dtype=torch.float32, device="cuda") + args = (input_tensor, scale) + + selected_key = pick_silu_mul_fp8_config(args, config_keys) + expected_key = f"intermediate_{intermediate_size}_batchsize_256" + assert selected_key == expected_key + + +class TestSiluMulFp8Correctness: + @pytest.mark.parametrize("batch_size", [1, 8, 32, 128]) + @pytest.mark.parametrize("intermediate_size", [2048, 3000, 3500, 4096, 5000]) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_silu_mul_fp8_correctness(self, batch_size, intermediate_size, dtype): + skip_if_platform_unsupported() + + input_size = 2 * intermediate_size + input_tensor = torch.randn(batch_size, input_size, dtype=dtype, device="cuda") + scale = torch.tensor([0.5], dtype=torch.float32, device="cuda") + + reference_output = silu_mul_fp8_baseline(input_tensor, scale) + helion_output = silu_mul_fp8(input_tensor, scale) + + assert helion_output.shape == reference_output.shape + assert helion_output.dtype == torch.float8_e4m3fn + assert reference_output.dtype == torch.float8_e4m3fn + + ref_f32 = reference_output.to(torch.float32) + helion_f32 = helion_output.to(torch.float32) + # FP8 E4M3 has limited precision. Values near quantization boundaries + # can round differently due to intermediate precision differences. + torch.testing.assert_close( + helion_f32, + ref_f32, + atol=0.05, + rtol=0.05, + msg=f"Mismatch at batch={batch_size}, size={intermediate_size}", + ) + + def test_silu_mul_fp8_shape_inference(self): + skip_if_platform_unsupported() + batch_size, input_size = 32, 8192 + intermediate_size = input_size // 2 + + input_tensor = torch.randn( + batch_size, input_size, dtype=torch.bfloat16, device="cuda" + ) + scale = torch.tensor([0.5], dtype=torch.float32, device="cuda") + + output = silu_mul_fp8(input_tensor, scale) + + expected_shape = (batch_size, intermediate_size) + assert output.shape == expected_shape + assert output.dtype == torch.float8_e4m3fn + + def test_silu_mul_fp8_scale_variations(self): + skip_if_platform_unsupported() + batch_size, input_size = 16, 4096 + + input_tensor = torch.randn( + batch_size, input_size, dtype=torch.bfloat16, device="cuda" + ) + + scales = [0.1, 0.5, 1.0, 2.0, 10.0] + + for scale_val in scales: + scale = torch.tensor([scale_val], dtype=torch.float32, device="cuda") + + reference_output = silu_mul_fp8_baseline(input_tensor, scale) + helion_output = silu_mul_fp8(input_tensor, scale) + ref_f32 = reference_output.to(torch.float32) + helion_f32 = helion_output.to(torch.float32) + + torch.testing.assert_close( + helion_f32, + ref_f32, + atol=0.05, + rtol=0.05, + msg=f"Mismatch for scale={scale_val}", + ) + + @pytest.mark.parametrize( + "shape", + [ + (1, 4096), + (16, 4096), + (128, 4096), + (1024, 4096), + (1, 8192), + (16, 8192), + (128, 8192), + ], + ) + def test_silu_mul_fp8_various_shapes(self, shape): + skip_if_platform_unsupported() + + input_tensor = torch.randn(*shape, dtype=torch.bfloat16, device="cuda") + scale = torch.tensor([0.5], dtype=torch.float32, device="cuda") + + reference_output = silu_mul_fp8_baseline(input_tensor, scale) + helion_output = silu_mul_fp8(input_tensor, scale) + + assert helion_output.shape == reference_output.shape + + ref_f32 = reference_output.to(torch.float32) + helion_f32 = helion_output.to(torch.float32) + + torch.testing.assert_close( + helion_f32, ref_f32, atol=0.05, rtol=0.05, msg=f"Mismatch for shape={shape}" + ) + + +def silu_mul_fp8_pytorch(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """Pure PyTorch reference using F.silu. + + This matches vLLM's SiluAndMul.forward_native exactly: + F.silu(x[..., :d]) * x[..., d:] + """ + d = input.shape[-1] // 2 + result = F.silu(input[..., :d]) * input[..., d:] + return (result.to(torch.float32) / scale).to(torch.float8_e4m3fn) + + +class TestSiluMulFp8PytorchReference: + """Tests comparing Helion kernel against pure PyTorch implementation. + + Uses tighter tolerance since both use PyTorch's FP8 conversion + (same rounding mode), unlike the vLLM C++ baseline which uses + NVIDIA's hardware FP8 conversion with different rounding. + """ + + @pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 256]) + @pytest.mark.parametrize("intermediate_size", [1024, 2048, 4096]) + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_silu_mul_fp8_vs_pytorch(self, batch_size, intermediate_size, dtype): + skip_if_platform_unsupported() + + input_tensor = torch.randn( + batch_size, 2 * intermediate_size, dtype=dtype, device="cuda" + ) + scale = torch.tensor([0.5], dtype=torch.float32, device="cuda") + + pytorch_output = silu_mul_fp8_pytorch(input_tensor, scale) + helion_output = silu_mul_fp8(input_tensor, scale) + + assert helion_output.shape == pytorch_output.shape + assert helion_output.dtype == torch.float8_e4m3fn + + pytorch_f32 = pytorch_output.to(torch.float32) + helion_f32 = helion_output.to(torch.float32) + + # Tolerance accounts for FP8 quantization boundary effects + torch.testing.assert_close( + helion_f32, + pytorch_f32, + atol=0.05, + rtol=0.05, + msg=( + f"Mismatch at batch={batch_size}, size={intermediate_size}, " + f"dtype={dtype}" + ), + ) + + @pytest.mark.parametrize( + "shape", + [ + (1, 2, 4096), # 3D input + (2, 4, 2048), # 3D input + (1, 1, 1, 8192), # 4D input + ], + ) + def test_silu_mul_fp8_multidim_vs_pytorch(self, shape): + skip_if_platform_unsupported() + + input_tensor = torch.randn(*shape, dtype=torch.bfloat16, device="cuda") + scale = torch.tensor([0.5], dtype=torch.float32, device="cuda") + + pytorch_output = silu_mul_fp8_pytorch(input_tensor, scale) + helion_output = silu_mul_fp8(input_tensor, scale) + + assert helion_output.shape == pytorch_output.shape + + pytorch_f32 = pytorch_output.to(torch.float32) + helion_f32 = helion_output.to(torch.float32) + + torch.testing.assert_close( + helion_f32, + pytorch_f32, + atol=0.05, + rtol=0.05, + msg=f"Mismatch for shape={shape}", + ) + + +class TestSiluMulFp8Integration: + def test_kernel_registration_integration(self): + from vllm.kernels.helion.register import get_registered_kernels + + registered_kernels = get_registered_kernels() + assert "silu_mul_fp8" in registered_kernels + + kernel_wrapper = registered_kernels["silu_mul_fp8"] + assert kernel_wrapper.op_name == "silu_mul_fp8" + assert kernel_wrapper._config_picker is not None + + def test_fake_impl_functionality(self): + skip_if_platform_unsupported() + from vllm.kernels.helion.register import get_registered_kernels + + input_tensor = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda") + scale = torch.tensor([0.5], dtype=torch.float32, device="cuda") + registered_kernels = get_registered_kernels() + kernel_wrapper = registered_kernels["silu_mul_fp8"] + fake_impl = kernel_wrapper._fake_impl + + fake_output = fake_impl(input_tensor, scale) + + expected_shape = (32, 2048) + assert fake_output.shape == expected_shape + assert fake_output.dtype == torch.float8_e4m3fn + assert fake_output.device == input_tensor.device diff --git a/vllm/kernels/helion/__init__.py b/vllm/kernels/helion/__init__.py index dfbf28b8d..2568baa20 100644 --- a/vllm/kernels/helion/__init__.py +++ b/vllm/kernels/helion/__init__.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Helion integration for vLLM.""" +import vllm.kernels.helion.ops # noqa: F401 Auto-register all Helion ops from vllm.kernels.helion.config_manager import ( ConfigManager, ConfigSet, diff --git a/vllm/kernels/helion/config_manager.py b/vllm/kernels/helion/config_manager.py index 59d5bf430..63560761e 100644 --- a/vllm/kernels/helion/config_manager.py +++ b/vllm/kernels/helion/config_manager.py @@ -104,9 +104,6 @@ class ConfigSet: result[platform] = {} for config_key, config in config_keys_dict.items(): - # Convert helion.Config to dict using to_json() + json.loads() - import json - result[platform][config_key] = json.loads(config.to_json()) return result diff --git a/vllm/kernels/helion/configs/silu_mul_fp8.json b/vllm/kernels/helion/configs/silu_mul_fp8.json new file mode 100644 index 000000000..c26ca087d --- /dev/null +++ b/vllm/kernels/helion/configs/silu_mul_fp8.json @@ -0,0 +1,550 @@ +{ + "nvidia_h200": { + "intermediate_2048_batchsize_256": { + "block_sizes": [ + 64, + 128 + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "flatten_loops": [ + true + ], + "l2_groupings": [ + 2 + ], + "range_unroll_factors": [ + 0 + ], + "range_num_stages": [ + 0 + ], + "range_multi_buffers": [ + null + ], + "range_flattens": [ + null + ], + "load_eviction_policies": [ + "", + "", + "" + ], + "num_warps": 32, + "num_stages": 1, + "indexing": [ + "pointer", + "tensor_descriptor", + "pointer", + "pointer" + ], + "pid_type": "flat", + "range_warp_specializes": [] + }, + "intermediate_4096_batchsize_256": { + "block_sizes": [ + 16, + 64 + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "flatten_loops": [ + false + ], + "l2_groupings": [ + 1 + ], + "range_unroll_factors": [ + 0 + ], + "range_num_stages": [ + 0 + ], + "range_multi_buffers": [ + null + ], + "range_flattens": [ + null + ], + "load_eviction_policies": [ + "", + "", + "" + ], + "num_warps": 2, + "num_stages": 1, + "indexing": [ + "pointer", + "pointer", + "pointer", + "pointer" + ], + "pid_type": "flat", + "range_warp_specializes": [] + }, + "default": { + "block_sizes": [ + 1, + 512 + ], + "loop_orders": [ + [ + 1, + 0 + ] + ], + "flatten_loops": [ + false + ], + "l2_groupings": [ + 4 + ], + "range_unroll_factors": [ + 0 + ], + "range_num_stages": [ + 0 + ], + "range_multi_buffers": [ + null + ], + "range_flattens": [ + null + ], + "load_eviction_policies": [ + "", + "first", + "" + ], + "num_warps": 8, + "num_stages": 2, + "indexing": [ + "tensor_descriptor", + "tensor_descriptor", + "tensor_descriptor", + "pointer" + ], + "pid_type": "flat", + "range_warp_specializes": [] + } + }, + "nvidia_h100_pcie": { + "intermediate_2048_batchsize_256": { + "block_sizes": [ + 1, + 512 + ], + "loop_orders": [ + [ + 1, + 0 + ] + ], + "flatten_loops": [ + false + ], + "l2_groupings": [ + 4 + ], + "range_unroll_factors": [ + 0 + ], + "range_num_stages": [ + 0 + ], + "range_multi_buffers": [ + null + ], + "range_flattens": [ + null + ], + "load_eviction_policies": [ + "", + "first", + "" + ], + "num_warps": 8, + "num_stages": 2, + "indexing": [ + "tensor_descriptor", + "tensor_descriptor", + "tensor_descriptor", + "pointer" + ], + "pid_type": "flat", + "range_warp_specializes": [] + }, + "intermediate_4096_batchsize_256": { + "block_sizes": [ + 256, + 128 + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "flatten_loops": [ + true + ], + "l2_groupings": [ + 1 + ], + "range_unroll_factors": [ + 2 + ], + "range_num_stages": [ + 3 + ], + "range_multi_buffers": [ + false + ], + "range_flattens": [ + true + ], + "load_eviction_policies": [ + "last", + "last", + "" + ], + "num_warps": 32, + "num_stages": 3, + "indexing": [ + "tensor_descriptor", + "tensor_descriptor", + "tensor_descriptor", + "pointer" + ], + "pid_type": "persistent_blocked", + "range_warp_specializes": [] + }, + "default": { + "block_sizes": [ + 1, + 512 + ], + "loop_orders": [ + [ + 1, + 0 + ] + ], + "flatten_loops": [ + false + ], + "l2_groupings": [ + 4 + ], + "range_unroll_factors": [ + 0 + ], + "range_num_stages": [ + 0 + ], + "range_multi_buffers": [ + null + ], + "range_flattens": [ + null + ], + "load_eviction_policies": [ + "", + "first", + "" + ], + "num_warps": 8, + "num_stages": 2, + "indexing": [ + "tensor_descriptor", + "tensor_descriptor", + "tensor_descriptor", + "pointer" + ], + "pid_type": "flat", + "range_warp_specializes": [] + } + }, + "nvidia_h100_sxm5": { + "intermediate_2048_batchsize_256": { + "block_sizes": [ + 1, + 512 + ], + "loop_orders": [ + [ + 1, + 0 + ] + ], + "flatten_loops": [ + false + ], + "l2_groupings": [ + 4 + ], + "range_unroll_factors": [ + 0 + ], + "range_num_stages": [ + 0 + ], + "range_multi_buffers": [ + null + ], + "range_flattens": [ + null + ], + "load_eviction_policies": [ + "", + "first", + "" + ], + "num_warps": 8, + "num_stages": 2, + "indexing": [ + "tensor_descriptor", + "tensor_descriptor", + "tensor_descriptor", + "pointer" + ], + "pid_type": "flat", + "range_warp_specializes": [] + }, + "intermediate_4096_batchsize_256": { + "block_sizes": [ + 256, + 128 + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "flatten_loops": [ + true + ], + "l2_groupings": [ + 1 + ], + "range_unroll_factors": [ + 2 + ], + "range_num_stages": [ + 3 + ], + "range_multi_buffers": [ + false + ], + "range_flattens": [ + true + ], + "load_eviction_policies": [ + "last", + "last", + "" + ], + "num_warps": 32, + "num_stages": 3, + "indexing": [ + "tensor_descriptor", + "tensor_descriptor", + "tensor_descriptor", + "pointer" + ], + "pid_type": "persistent_blocked", + "range_warp_specializes": [] + }, + "default": { + "block_sizes": [ + 1, + 512 + ], + "loop_orders": [ + [ + 1, + 0 + ] + ], + "flatten_loops": [ + false + ], + "l2_groupings": [ + 4 + ], + "range_unroll_factors": [ + 0 + ], + "range_num_stages": [ + 0 + ], + "range_multi_buffers": [ + null + ], + "range_flattens": [ + null + ], + "load_eviction_policies": [ + "", + "first", + "" + ], + "num_warps": 8, + "num_stages": 2, + "indexing": [ + "tensor_descriptor", + "tensor_descriptor", + "tensor_descriptor", + "pointer" + ], + "pid_type": "flat", + "range_warp_specializes": [] + } + }, + "nvidia_h100": { + "intermediate_2048_batchsize_256": { + "block_sizes": [ + 1, + 512 + ], + "loop_orders": [ + [ + 1, + 0 + ] + ], + "flatten_loops": [ + false + ], + "l2_groupings": [ + 4 + ], + "range_unroll_factors": [ + 0 + ], + "range_num_stages": [ + 0 + ], + "range_multi_buffers": [ + null + ], + "range_flattens": [ + null + ], + "load_eviction_policies": [ + "", + "first", + "" + ], + "num_warps": 8, + "num_stages": 2, + "indexing": [ + "tensor_descriptor", + "tensor_descriptor", + "tensor_descriptor", + "pointer" + ], + "pid_type": "flat", + "range_warp_specializes": [] + }, + "intermediate_4096_batchsize_256": { + "block_sizes": [ + 256, + 128 + ], + "loop_orders": [ + [ + 0, + 1 + ] + ], + "flatten_loops": [ + true + ], + "l2_groupings": [ + 1 + ], + "range_unroll_factors": [ + 2 + ], + "range_num_stages": [ + 3 + ], + "range_multi_buffers": [ + false + ], + "range_flattens": [ + true + ], + "load_eviction_policies": [ + "last", + "last", + "" + ], + "num_warps": 32, + "num_stages": 3, + "indexing": [ + "tensor_descriptor", + "tensor_descriptor", + "tensor_descriptor", + "pointer" + ], + "pid_type": "persistent_blocked", + "range_warp_specializes": [] + }, + "default": { + "block_sizes": [ + 1, + 512 + ], + "loop_orders": [ + [ + 1, + 0 + ] + ], + "flatten_loops": [ + false + ], + "l2_groupings": [ + 4 + ], + "range_unroll_factors": [ + 0 + ], + "range_num_stages": [ + 0 + ], + "range_multi_buffers": [ + null + ], + "range_flattens": [ + null + ], + "load_eviction_policies": [ + "", + "first", + "" + ], + "num_warps": 8, + "num_stages": 2, + "indexing": [ + "tensor_descriptor", + "tensor_descriptor", + "tensor_descriptor", + "pointer" + ], + "pid_type": "flat", + "range_warp_specializes": [] + } + } +} \ No newline at end of file diff --git a/vllm/kernels/helion/ops/__init__.py b/vllm/kernels/helion/ops/__init__.py new file mode 100644 index 000000000..eacd483bb --- /dev/null +++ b/vllm/kernels/helion/ops/__init__.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Auto-import all Helion op modules to trigger kernel registration.""" + +import importlib +import pkgutil + +# Automatically import all submodules so that @register_kernel +# decorators execute and register ops with torch.ops.vllm_helion. +for _module_info in pkgutil.iter_modules(__path__): + importlib.import_module(f"{__name__}.{_module_info.name}") diff --git a/vllm/kernels/helion/ops/silu_mul_fp8.py b/vllm/kernels/helion/ops/silu_mul_fp8.py new file mode 100644 index 000000000..a45943b1a --- /dev/null +++ b/vllm/kernels/helion/ops/silu_mul_fp8.py @@ -0,0 +1,100 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any + +import torch + +from vllm.logger import init_logger +from vllm.utils.import_utils import has_helion + +if not has_helion(): + raise ImportError( + "silu_mul_fp8 Helion kernel requires helion to be installed. " + "Install it with: pip install helion" + ) + +import helion.language as hl + +from vllm.kernels.helion.register import register_kernel + +logger = init_logger(__name__) + + +@register_kernel # type: ignore[misc] +def silu_mul_fp8(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + original_shape = input.shape + two_d = hl.specialize(original_shape[-1]) + d = two_d // 2 + output_shape = original_shape[:-1] + (d,) + + input_2d = input.view(-1, original_shape[-1]) + m = input_2d.shape[0] + + # TODO(gmagogsfm): Support for more float8 subtypes (e4m3fnuz, e5m2) coming + out = torch.empty((m, d), device=input.device, dtype=torch.float8_e4m3fn) + + input_part_a = input_2d[:, :d] + input_part_b = input_2d[:, d:] + + assert scale.numel() == 1, "Scale must be a scalar Tensor" + + for tile_m, tile_n in hl.tile([m, d]): + a_vals = input_part_a[tile_m, tile_n] + silu_result = torch.nn.functional.silu(a_vals) + b_vals = input_part_b[tile_m, tile_n] + result = silu_result * b_vals + result_f32 = result.to(torch.float32) + scale_val = hl.load(scale, [0]) + inv_scale = 1.0 / scale_val + result_scaled = result_f32 * inv_scale + out[tile_m, tile_n] = result_scaled.to(out.dtype) + + return out.view(output_shape) + + +@silu_mul_fp8.register_config_picker # type: ignore[misc] +def pick_silu_mul_fp8_config( + args: tuple[Any, ...], config_keys: list[str] +) -> str | None: + if not config_keys: + return None + + input_tensor, scale = args + intermediate_size = input_tensor.shape[-1] // 2 + + # TODO(gmagosfm): Rerun autotuning to capture config for + # other batch sizes. + target_key = f"intermediate_{intermediate_size}_batchsize_256" + if target_key in config_keys: + return target_key + + intermediate_sizes = [] + for key in config_keys: + if key.startswith("intermediate_") and "_batchsize_256" in key: + try: + size_str = key.split("_")[1] + size = int(size_str) + intermediate_sizes.append((abs(size - intermediate_size), key)) + except (ValueError, IndexError): + continue + + if intermediate_sizes: + _, best_key = min(intermediate_sizes) + logger.debug( + "No exact config for intermediate_size=%d, using closest match: %s", + intermediate_size, + best_key, + ) + return best_key + if "default" in config_keys: + return "default" + + return None + + +def silu_mul_fp8_baseline(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + output_shape = input.shape[:-1] + (input.shape[-1] // 2,) + out = torch.empty(output_shape, dtype=torch.float8_e4m3fn, device=input.device) + torch.ops._C.silu_and_mul_quant(out, input, scale) + return out