diff --git a/tests/kernels/helion/test_pattern_matching.py b/tests/kernels/helion/test_pattern_matching.py new file mode 100644 index 000000000..1cab249a1 --- /dev/null +++ b/tests/kernels/helion/test_pattern_matching.py @@ -0,0 +1,203 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Test make_fx tracing and inductor pattern matching with HelionKernelWrapper.""" + +import contextlib +from unittest.mock import Mock, patch + +import pytest +import torch + +from vllm.utils.import_utils import has_helion + +if not has_helion(): + pytest.skip( + "Helion is not installed. Install with: pip install vllm[helion]", + allow_module_level=True, + ) + +import helion +import helion.language as hl +from helion._compat import requires_torch_version + +if not requires_torch_version("2.11"): + pytest.skip( + "HigherOrderOp requires PyTorch >= 2.11", + allow_module_level=True, + ) + +from helion._compiler._dynamo.higher_order_ops import ( + helion_kernel_side_table, + helion_kernel_wrapper_mutation, +) +from torch._inductor.pattern_matcher import ( + PatternMatcherPass, + fwd_only, + register_replacement, + select_decomp_table, +) +from torch.fx.experimental.proxy_tensor import make_fx + +from vllm.kernels.helion.config_manager import ConfigManager +from vllm.kernels.helion.register import HelionKernelWrapper + + +@contextlib.contextmanager +def _helion_mock_context(): + configs = { + "default": helion.Config(block_sizes=[64], num_warps=2, num_stages=2), + } + mock_config_manager = Mock(spec=ConfigManager) + mock_config_manager.get_platform_configs = Mock(return_value=configs) + + with ( + patch( + "vllm.kernels.helion.config_manager.ConfigManager.get_instance", + return_value=mock_config_manager, + ), + patch( + "vllm.kernels.helion.utils.get_canonical_gpu_name", + return_value="nvidia_h200", + ), + ): + yield + + +class TestMakeFxHop: + def setup_method(self): + helion_kernel_side_table.reset_table() + + def test_make_fx_symbolic(self): + def raw_add_scale( + x: torch.Tensor, y: torch.Tensor, scale: float + ) -> tuple[torch.Tensor, int, torch.Tensor]: + out_x = torch.empty_like(x) + out_y = torch.empty_like(x) + for tile in hl.tile(x.size()): + out_x[tile] = x[tile] + y[tile] * scale + out_y[tile] = out_x[tile] * 2.0 + return out_x, 42, out_y + + input_x = torch.randn(7, 13) + input_y = torch.randn(7, 13) + scale = 0.5 + + with _helion_mock_context(): + wrapper = HelionKernelWrapper( + raw_kernel_func=raw_add_scale, + op_name="test_make_fx", + fake_impl=lambda *a, **kw: None, + ) + wrapper.register_config_picker(lambda args, keys: "default") + + def fn(x, y): + return wrapper(x, y, scale) + + gm = make_fx(fn, tracing_mode="symbolic")(input_x, input_y) + + hop_nodes = [ + n + for n in gm.graph.nodes + if n.op == "call_function" and n.target is helion_kernel_wrapper_mutation + ] + assert len(hop_nodes) == 1 + node = hop_nodes[0] + + assert node.kwargs["constant_args"]["scale"] == scale + assert set(node.kwargs["tensor_args"]) == {"x", "y"} + + specs = node.kwargs["output_spec"]["leaf_specs"] + tensor_specs = [s for s in specs if s["type"] == "tensor"] + scalar_specs = [s for s in specs if s["type"] == "scalar"] + assert len(tensor_specs) == 2 + assert len(scalar_specs) == 1 + + for spec in tensor_specs: + assert spec["dtype"] == input_x.dtype + + assert scalar_specs[0]["scalar_value"] == 42 + + for val in node.meta["val"]: + assert all(isinstance(s, torch.SymInt) for s in val.shape) + + # Both out_x and out_y are empty_like(x), so output shapes == input shape + input_node = next(n for n in gm.graph.nodes if n.op == "placeholder") + input_shape = input_node.meta["val"].shape + for val in node.meta["val"]: + assert len(val.shape) == len(input_shape) + for out_s, in_s in zip(val.shape, input_shape): + assert out_s == in_s + + def test_pattern_matcher_replaces_with_helion_hop(self): + def raw_silu_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + M, N = x.size() + out = torch.empty_like(x) + for tile_m, tile_n in hl.tile([M, N]): + out[tile_m, tile_n] = ( + torch.nn.functional.silu(x[tile_m, tile_n]) * y[tile_m, tile_n] + ) + return out + + with _helion_mock_context(): + wrapper = HelionKernelWrapper( + raw_kernel_func=raw_silu_mul, + op_name="test_pm_silu_mul", + fake_impl=lambda *a, **kw: None, + ) + wrapper.register_config_picker(lambda args, keys: "default") + + def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.silu(x) * y + + def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return wrapper(x, y) + + inputs = [torch.randn(8, 16), torch.randn(8, 16)] + + pm_pass = PatternMatcherPass(pass_name="test_helion_replacement") + register_replacement(pattern, replacement, inputs, fwd_only, pm_pass) + + def model(x, y): + return torch.nn.functional.silu(x) * y + + decompositions = select_decomp_table() + input_x = torch.randn(8, 16) + input_y = torch.randn(8, 16) + gm = make_fx(model, decompositions, tracing_mode="symbolic")( + input_x, input_y + ) + + def count_hop_nodes(graph): + return sum( + 1 + for n in graph.nodes + if n.op == "call_function" + and n.target is helion_kernel_wrapper_mutation + ) + + assert count_hop_nodes(gm.graph) == 0 + + match_count = pm_pass.apply(gm.graph) + gm.graph.lint() + gm.recompile() + + assert match_count == 1 + assert count_hop_nodes(gm.graph) == 1 + + hop_node = next( + n + for n in gm.graph.nodes + if n.op == "call_function" + and n.target is helion_kernel_wrapper_mutation + ) + + # raw_silu_mul returns empty_like(x), so output shape == input shape + for val in hop_node.meta["val"]: + assert all(isinstance(s, torch.SymInt) for s in val.shape) + + input_node = next(n for n in gm.graph.nodes if n.op == "placeholder") + input_shape = input_node.meta["val"].shape + output_shape = hop_node.meta["val"][0].shape + assert len(output_shape) == len(input_shape) + for out_s, in_s in zip(output_shape, input_shape): + assert out_s == in_s diff --git a/tests/kernels/helion/test_register.py b/tests/kernels/helion/test_register.py index 02b05be74..bee72d58a 100644 --- a/tests/kernels/helion/test_register.py +++ b/tests/kernels/helion/test_register.py @@ -4,8 +4,7 @@ Unit tests for Helion kernel registration. Tests ConfiguredHelionKernel, HelionKernelWrapper, and PresetConfigSearch -including config picker registration, custom autotuner integration, and -PyTorch op registration. +including config picker registration and custom autotuner integration. """ from unittest.mock import Mock, patch @@ -25,6 +24,7 @@ import helion from vllm.kernels.helion.config_manager import ConfigManager from vllm.kernels.helion.register import ( + _HOP_AVAILABLE, ConfiguredHelionKernel, HelionKernelWrapper, get_kernel_by_name, @@ -451,8 +451,10 @@ class TestHelionKernelWrapper: ): wrapper.get_configured_op() - def test_get_configured_op_returns_cached_op(self, sample_kernel, sample_configs): - """Test get_configured_op returns cached op when already registered.""" + def test_get_configured_op_returns_cached_kernel( + self, sample_kernel, sample_configs + ): + """Test get_configured_op returns cached ConfiguredHelionKernel.""" def fake_impl(*args, **kwargs): return torch.zeros_like(args[0]) @@ -470,6 +472,46 @@ class TestHelionKernelWrapper: mock_config_manager = Mock(spec=ConfigManager) mock_config_manager.get_platform_configs = Mock(return_value=sample_configs) + with ( + patch( + "vllm.kernels.helion.config_manager.ConfigManager.get_instance", + return_value=mock_config_manager, + ), + patch( + "vllm.kernels.helion.utils.get_canonical_gpu_name", + return_value="nvidia_h200", + ), + patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel, + ): + mock_decorated = Mock() + mock_kernel.return_value = Mock(return_value=mock_decorated) + + result1 = wrapper.get_configured_op() + result2 = wrapper.get_configured_op() + assert result1 is result2 + + @pytest.mark.skipif( + _HOP_AVAILABLE, reason="CustomOp path not used when HOP available" + ) + def test_get_or_register_custom_op_returns_cached_op( + self, sample_kernel, sample_configs + ): + def fake_impl(*args, **kwargs): + return torch.zeros_like(args[0]) + + def default_picker(args, config_keys): + return "default" + + wrapper = HelionKernelWrapper( + raw_kernel_func=sample_kernel, + op_name="test_kernel", + fake_impl=fake_impl, + ) + wrapper._config_picker = default_picker + + mock_config_manager = Mock(spec=ConfigManager) + mock_config_manager.get_platform_configs = Mock(return_value=sample_configs) + existing_op = Mock() mock_namespace = Mock() mock_namespace.test_kernel = existing_op @@ -488,12 +530,15 @@ class TestHelionKernelWrapper: ): mock_decorated = Mock() mock_kernel.return_value = Mock(return_value=mock_decorated) - result = wrapper.get_configured_op() + result = wrapper._get_or_register_custom_op() assert result is existing_op - def test_get_configured_op_registers_new_op(self, sample_kernel, sample_configs): - """Test get_configured_op creates and registers new op.""" - + @pytest.mark.skipif( + _HOP_AVAILABLE, reason="CustomOp path not used when HOP available" + ) + def test_get_or_register_custom_op_registers_new_op( + self, sample_kernel, sample_configs + ): def fake_impl(*args, **kwargs): return torch.zeros_like(args[0]) @@ -542,11 +587,10 @@ class TestHelionKernelWrapper: ): mock_decorated = Mock() mock_kernel.return_value = Mock(return_value=mock_decorated) - result = wrapper.get_configured_op() + result = wrapper._get_or_register_custom_op() mock_register.assert_called_once() assert result is new_op - # Check that op_func is the decorated kernel, not ConfiguredHelionKernel assert mock_register.call_args[1]["op_func"] is mock_decorated diff --git a/vllm/kernels/helion/register.py b/vllm/kernels/helion/register.py index 3114631dd..cd0ef83fc 100644 --- a/vllm/kernels/helion/register.py +++ b/vllm/kernels/helion/register.py @@ -31,8 +31,8 @@ by key matches the config returned by the autotuner. Key Classes ----------- -- HelionKernelWrapper: Wraps raw kernel + config_picker, creates configured ops -- ConfiguredHelionKernel: Platform-specific kernel registered as PyTorch custom op +- HelionKernelWrapper: Wraps raw kernel + config_picker, creates configured kernels +- ConfiguredHelionKernel: Platform-specific kernel with pre-tuned configs - PresetConfigSearch: Custom autotuner that returns pre-tuned configs """ @@ -53,10 +53,27 @@ if not has_helion(): ) import helion +from helion._compat import requires_torch_version from helion.autotuner.base_search import BaseAutotuner from helion.runtime.config import Config from helion.runtime.settings import default_autotuner_fn +# TODO(gmagogsfm): Remove CustomOp fallback path (_get_or_register_custom_op, +# vllm_helion_lib, direct_register_custom_op) once vLLM requires PyTorch >= 2.11. +_HOP_AVAILABLE = requires_torch_version("2.11") + +if _HOP_AVAILABLE: + import torch.utils._pytree as pytree + from helion._compiler._dynamo.higher_order_ops import ( + helion_kernel_side_table, + helion_kernel_wrapper_mutation, + ) + from helion._compiler._dynamo.variables import infer_output_spec + from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + get_proxy_mode, + ) + logger = init_logger(__name__) vllm_helion_lib = Library("vllm_helion", "FRAGMENT") # noqa @@ -233,7 +250,7 @@ class ConfiguredHelionKernel: class HelionKernelWrapper: - """Wrapper for Helion kernels that creates config-specific PyTorch custom ops.""" + """Wrapper for Helion kernels with pre-tuned config selection and HOP support.""" def __init__( self, @@ -252,11 +269,86 @@ class HelionKernelWrapper: self._config_picker: ( Callable[[tuple[Any, ...], list[str]], str | None] | None ) = None + self._configured_kernel: ConfiguredHelionKernel | None = None self._input_generator: Callable[[], dict[str, tuple[Any, ...]]] | None = None def __call__(self, *args, **kwargs): - configured_op = self.get_configured_op() - return configured_op(*args, **kwargs) + # CustomOp fallback: register as torch custom op for torch.compile + # compatibility on older PyTorch lacking HOP/EffectType support + if not _HOP_AVAILABLE: + custom_op = self._get_or_register_custom_op() + return custom_op(*args, **kwargs) + # HOP tracing: record HigherOrderOp in the FX graph + if get_proxy_mode() is not None: + return self._call_via_hop(args, kwargs) + # Eager: run the configured kernel directly + return self.get_configured_op()(*args, **kwargs) + + def _call_via_hop( + self, + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> Any: + kernel = self.get_configured_op()._decorated_kernel + kernel_idx = helion_kernel_side_table.add_kernel(kernel) + + constant_args, tensor_args = self._partition_args(kernel, args, kwargs) + + all_named = {**constant_args, **tensor_args} + full_args = tuple( + all_named.get(n, p.default) + for n, p in kernel.signature.parameters.items() # type: ignore[attr-defined] + if n in all_named or p.default is not p.empty + ) + + with disable_proxy_modes_tracing(): + output_spec = infer_output_spec(kernel, full_args) + + hop_result = helion_kernel_wrapper_mutation( + kernel_idx=kernel_idx, + constant_args=constant_args, + tensor_args=tensor_args, + output_spec=output_spec, + ) + + tree_spec_str = output_spec.get("tree_spec_str") + if tree_spec_str is None: + return None + tree_spec = pytree.treespec_loads(tree_spec_str) + + hop_iter = iter(hop_result) + reconstructed = [] + for spec in output_spec["leaf_specs"]: + is_constant_scalar = spec["type"] == "scalar" and not isinstance( + spec.get("scalar_value"), torch.SymInt + ) + if is_constant_scalar: + reconstructed.append(spec["scalar_value"]) + else: + reconstructed.append(next(hop_iter)) + return pytree.tree_unflatten(reconstructed, tree_spec) + + @staticmethod + def _partition_args( + kernel: Any, + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> tuple[dict[str, Any], dict[str, Any]]: + constant_args: dict[str, Any] = {} + tensor_args: dict[str, Any] = {} + params = list(kernel.signature.parameters.keys()) + for i, val in enumerate(args): + name = params[i] + if isinstance(val, torch.Tensor): + tensor_args[name] = val + else: + constant_args[name] = val + for name, val in kwargs.items(): + if isinstance(val, torch.Tensor): + tensor_args[name] = val + else: + constant_args[name] = val + return constant_args, tensor_args def register_config_picker( self, picker_func: Callable[[tuple[Any, ...], list[str]], str | None] @@ -309,29 +401,32 @@ class HelionKernelWrapper: ) return autotune_kernel.autotune(inputs) - def get_configured_op(self) -> Any: + def get_configured_op(self) -> ConfiguredHelionKernel: assert self._config_picker is not None, ( f"No config picker registered for kernel '{self.op_name}'. " f"Use @{self.op_name}.register_config_picker to register one." ) + if self._configured_kernel is None: + self._configured_kernel = ConfiguredHelionKernel( + op_name=self.op_name, + config_picker=self._config_picker, + raw_kernel_func=self.raw_kernel_func, + helion_settings=self.helion_settings, + ) + + return self._configured_kernel + + def _get_or_register_custom_op(self) -> Any: if hasattr(torch.ops.vllm_helion, self.op_name): - logger.debug("Op vllm_helion::%s already registered", self.op_name) return getattr(torch.ops.vllm_helion, self.op_name) - configured_kernel = ConfiguredHelionKernel( - op_name=self.op_name, - config_picker=self._config_picker, - raw_kernel_func=self.raw_kernel_func, - helion_settings=self.helion_settings, - ) + configured_kernel = self.get_configured_op() logger.info("Registering op: vllm_helion::%s", self.op_name) direct_register_custom_op( op_name=self.op_name, - op_func=configured_kernel._decorated_kernel, # Register decorated kernel - # TODO(gmagogsfm): Implement automatic mutation/aliasing detection - # for Helion kernels. + op_func=configured_kernel._decorated_kernel, mutates_args=None, fake_impl=self._fake_impl, target_lib=vllm_helion_lib,