[Kernel] [Helion] [7/N] Use HOP to represent Helion Kernel call to enable fx tracing and pattern matching (#34390)

Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
This commit is contained in:
Yanan Cao
2026-02-27 09:21:35 -08:00
committed by GitHub
parent 876312f0b5
commit 9098ce690c
3 changed files with 368 additions and 26 deletions

View File

@@ -0,0 +1,203 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test make_fx tracing and inductor pattern matching with HelionKernelWrapper."""
import contextlib
from unittest.mock import Mock, patch
import pytest
import torch
from vllm.utils.import_utils import has_helion
if not has_helion():
pytest.skip(
"Helion is not installed. Install with: pip install vllm[helion]",
allow_module_level=True,
)
import helion
import helion.language as hl
from helion._compat import requires_torch_version
if not requires_torch_version("2.11"):
pytest.skip(
"HigherOrderOp requires PyTorch >= 2.11",
allow_module_level=True,
)
from helion._compiler._dynamo.higher_order_ops import (
helion_kernel_side_table,
helion_kernel_wrapper_mutation,
)
from torch._inductor.pattern_matcher import (
PatternMatcherPass,
fwd_only,
register_replacement,
select_decomp_table,
)
from torch.fx.experimental.proxy_tensor import make_fx
from vllm.kernels.helion.config_manager import ConfigManager
from vllm.kernels.helion.register import HelionKernelWrapper
@contextlib.contextmanager
def _helion_mock_context():
configs = {
"default": helion.Config(block_sizes=[64], num_warps=2, num_stages=2),
}
mock_config_manager = Mock(spec=ConfigManager)
mock_config_manager.get_platform_configs = Mock(return_value=configs)
with (
patch(
"vllm.kernels.helion.config_manager.ConfigManager.get_instance",
return_value=mock_config_manager,
),
patch(
"vllm.kernels.helion.utils.get_canonical_gpu_name",
return_value="nvidia_h200",
),
):
yield
class TestMakeFxHop:
def setup_method(self):
helion_kernel_side_table.reset_table()
def test_make_fx_symbolic(self):
def raw_add_scale(
x: torch.Tensor, y: torch.Tensor, scale: float
) -> tuple[torch.Tensor, int, torch.Tensor]:
out_x = torch.empty_like(x)
out_y = torch.empty_like(x)
for tile in hl.tile(x.size()):
out_x[tile] = x[tile] + y[tile] * scale
out_y[tile] = out_x[tile] * 2.0
return out_x, 42, out_y
input_x = torch.randn(7, 13)
input_y = torch.randn(7, 13)
scale = 0.5
with _helion_mock_context():
wrapper = HelionKernelWrapper(
raw_kernel_func=raw_add_scale,
op_name="test_make_fx",
fake_impl=lambda *a, **kw: None,
)
wrapper.register_config_picker(lambda args, keys: "default")
def fn(x, y):
return wrapper(x, y, scale)
gm = make_fx(fn, tracing_mode="symbolic")(input_x, input_y)
hop_nodes = [
n
for n in gm.graph.nodes
if n.op == "call_function" and n.target is helion_kernel_wrapper_mutation
]
assert len(hop_nodes) == 1
node = hop_nodes[0]
assert node.kwargs["constant_args"]["scale"] == scale
assert set(node.kwargs["tensor_args"]) == {"x", "y"}
specs = node.kwargs["output_spec"]["leaf_specs"]
tensor_specs = [s for s in specs if s["type"] == "tensor"]
scalar_specs = [s for s in specs if s["type"] == "scalar"]
assert len(tensor_specs) == 2
assert len(scalar_specs) == 1
for spec in tensor_specs:
assert spec["dtype"] == input_x.dtype
assert scalar_specs[0]["scalar_value"] == 42
for val in node.meta["val"]:
assert all(isinstance(s, torch.SymInt) for s in val.shape)
# Both out_x and out_y are empty_like(x), so output shapes == input shape
input_node = next(n for n in gm.graph.nodes if n.op == "placeholder")
input_shape = input_node.meta["val"].shape
for val in node.meta["val"]:
assert len(val.shape) == len(input_shape)
for out_s, in_s in zip(val.shape, input_shape):
assert out_s == in_s
def test_pattern_matcher_replaces_with_helion_hop(self):
def raw_silu_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
M, N = x.size()
out = torch.empty_like(x)
for tile_m, tile_n in hl.tile([M, N]):
out[tile_m, tile_n] = (
torch.nn.functional.silu(x[tile_m, tile_n]) * y[tile_m, tile_n]
)
return out
with _helion_mock_context():
wrapper = HelionKernelWrapper(
raw_kernel_func=raw_silu_mul,
op_name="test_pm_silu_mul",
fake_impl=lambda *a, **kw: None,
)
wrapper.register_config_picker(lambda args, keys: "default")
def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.silu(x) * y
def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return wrapper(x, y)
inputs = [torch.randn(8, 16), torch.randn(8, 16)]
pm_pass = PatternMatcherPass(pass_name="test_helion_replacement")
register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
def model(x, y):
return torch.nn.functional.silu(x) * y
decompositions = select_decomp_table()
input_x = torch.randn(8, 16)
input_y = torch.randn(8, 16)
gm = make_fx(model, decompositions, tracing_mode="symbolic")(
input_x, input_y
)
def count_hop_nodes(graph):
return sum(
1
for n in graph.nodes
if n.op == "call_function"
and n.target is helion_kernel_wrapper_mutation
)
assert count_hop_nodes(gm.graph) == 0
match_count = pm_pass.apply(gm.graph)
gm.graph.lint()
gm.recompile()
assert match_count == 1
assert count_hop_nodes(gm.graph) == 1
hop_node = next(
n
for n in gm.graph.nodes
if n.op == "call_function"
and n.target is helion_kernel_wrapper_mutation
)
# raw_silu_mul returns empty_like(x), so output shape == input shape
for val in hop_node.meta["val"]:
assert all(isinstance(s, torch.SymInt) for s in val.shape)
input_node = next(n for n in gm.graph.nodes if n.op == "placeholder")
input_shape = input_node.meta["val"].shape
output_shape = hop_node.meta["val"][0].shape
assert len(output_shape) == len(input_shape)
for out_s, in_s in zip(output_shape, input_shape):
assert out_s == in_s

View File

@@ -4,8 +4,7 @@
Unit tests for Helion kernel registration.
Tests ConfiguredHelionKernel, HelionKernelWrapper, and PresetConfigSearch
including config picker registration, custom autotuner integration, and
PyTorch op registration.
including config picker registration and custom autotuner integration.
"""
from unittest.mock import Mock, patch
@@ -25,6 +24,7 @@ import helion
from vllm.kernels.helion.config_manager import ConfigManager
from vllm.kernels.helion.register import (
_HOP_AVAILABLE,
ConfiguredHelionKernel,
HelionKernelWrapper,
get_kernel_by_name,
@@ -451,8 +451,10 @@ class TestHelionKernelWrapper:
):
wrapper.get_configured_op()
def test_get_configured_op_returns_cached_op(self, sample_kernel, sample_configs):
"""Test get_configured_op returns cached op when already registered."""
def test_get_configured_op_returns_cached_kernel(
self, sample_kernel, sample_configs
):
"""Test get_configured_op returns cached ConfiguredHelionKernel."""
def fake_impl(*args, **kwargs):
return torch.zeros_like(args[0])
@@ -470,6 +472,46 @@ class TestHelionKernelWrapper:
mock_config_manager = Mock(spec=ConfigManager)
mock_config_manager.get_platform_configs = Mock(return_value=sample_configs)
with (
patch(
"vllm.kernels.helion.config_manager.ConfigManager.get_instance",
return_value=mock_config_manager,
),
patch(
"vllm.kernels.helion.utils.get_canonical_gpu_name",
return_value="nvidia_h200",
),
patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
):
mock_decorated = Mock()
mock_kernel.return_value = Mock(return_value=mock_decorated)
result1 = wrapper.get_configured_op()
result2 = wrapper.get_configured_op()
assert result1 is result2
@pytest.mark.skipif(
_HOP_AVAILABLE, reason="CustomOp path not used when HOP available"
)
def test_get_or_register_custom_op_returns_cached_op(
self, sample_kernel, sample_configs
):
def fake_impl(*args, **kwargs):
return torch.zeros_like(args[0])
def default_picker(args, config_keys):
return "default"
wrapper = HelionKernelWrapper(
raw_kernel_func=sample_kernel,
op_name="test_kernel",
fake_impl=fake_impl,
)
wrapper._config_picker = default_picker
mock_config_manager = Mock(spec=ConfigManager)
mock_config_manager.get_platform_configs = Mock(return_value=sample_configs)
existing_op = Mock()
mock_namespace = Mock()
mock_namespace.test_kernel = existing_op
@@ -488,12 +530,15 @@ class TestHelionKernelWrapper:
):
mock_decorated = Mock()
mock_kernel.return_value = Mock(return_value=mock_decorated)
result = wrapper.get_configured_op()
result = wrapper._get_or_register_custom_op()
assert result is existing_op
def test_get_configured_op_registers_new_op(self, sample_kernel, sample_configs):
"""Test get_configured_op creates and registers new op."""
@pytest.mark.skipif(
_HOP_AVAILABLE, reason="CustomOp path not used when HOP available"
)
def test_get_or_register_custom_op_registers_new_op(
self, sample_kernel, sample_configs
):
def fake_impl(*args, **kwargs):
return torch.zeros_like(args[0])
@@ -542,11 +587,10 @@ class TestHelionKernelWrapper:
):
mock_decorated = Mock()
mock_kernel.return_value = Mock(return_value=mock_decorated)
result = wrapper.get_configured_op()
result = wrapper._get_or_register_custom_op()
mock_register.assert_called_once()
assert result is new_op
# Check that op_func is the decorated kernel, not ConfiguredHelionKernel
assert mock_register.call_args[1]["op_func"] is mock_decorated

View File

@@ -31,8 +31,8 @@ by key matches the config returned by the autotuner.
Key Classes
-----------
- HelionKernelWrapper: Wraps raw kernel + config_picker, creates configured ops
- ConfiguredHelionKernel: Platform-specific kernel registered as PyTorch custom op
- HelionKernelWrapper: Wraps raw kernel + config_picker, creates configured kernels
- ConfiguredHelionKernel: Platform-specific kernel with pre-tuned configs
- PresetConfigSearch: Custom autotuner that returns pre-tuned configs
"""
@@ -53,10 +53,27 @@ if not has_helion():
)
import helion
from helion._compat import requires_torch_version
from helion.autotuner.base_search import BaseAutotuner
from helion.runtime.config import Config
from helion.runtime.settings import default_autotuner_fn
# TODO(gmagogsfm): Remove CustomOp fallback path (_get_or_register_custom_op,
# vllm_helion_lib, direct_register_custom_op) once vLLM requires PyTorch >= 2.11.
_HOP_AVAILABLE = requires_torch_version("2.11")
if _HOP_AVAILABLE:
import torch.utils._pytree as pytree
from helion._compiler._dynamo.higher_order_ops import (
helion_kernel_side_table,
helion_kernel_wrapper_mutation,
)
from helion._compiler._dynamo.variables import infer_output_spec
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
get_proxy_mode,
)
logger = init_logger(__name__)
vllm_helion_lib = Library("vllm_helion", "FRAGMENT") # noqa
@@ -233,7 +250,7 @@ class ConfiguredHelionKernel:
class HelionKernelWrapper:
"""Wrapper for Helion kernels that creates config-specific PyTorch custom ops."""
"""Wrapper for Helion kernels with pre-tuned config selection and HOP support."""
def __init__(
self,
@@ -252,11 +269,86 @@ class HelionKernelWrapper:
self._config_picker: (
Callable[[tuple[Any, ...], list[str]], str | None] | None
) = None
self._configured_kernel: ConfiguredHelionKernel | None = None
self._input_generator: Callable[[], dict[str, tuple[Any, ...]]] | None = None
def __call__(self, *args, **kwargs):
configured_op = self.get_configured_op()
return configured_op(*args, **kwargs)
# CustomOp fallback: register as torch custom op for torch.compile
# compatibility on older PyTorch lacking HOP/EffectType support
if not _HOP_AVAILABLE:
custom_op = self._get_or_register_custom_op()
return custom_op(*args, **kwargs)
# HOP tracing: record HigherOrderOp in the FX graph
if get_proxy_mode() is not None:
return self._call_via_hop(args, kwargs)
# Eager: run the configured kernel directly
return self.get_configured_op()(*args, **kwargs)
def _call_via_hop(
self,
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> Any:
kernel = self.get_configured_op()._decorated_kernel
kernel_idx = helion_kernel_side_table.add_kernel(kernel)
constant_args, tensor_args = self._partition_args(kernel, args, kwargs)
all_named = {**constant_args, **tensor_args}
full_args = tuple(
all_named.get(n, p.default)
for n, p in kernel.signature.parameters.items() # type: ignore[attr-defined]
if n in all_named or p.default is not p.empty
)
with disable_proxy_modes_tracing():
output_spec = infer_output_spec(kernel, full_args)
hop_result = helion_kernel_wrapper_mutation(
kernel_idx=kernel_idx,
constant_args=constant_args,
tensor_args=tensor_args,
output_spec=output_spec,
)
tree_spec_str = output_spec.get("tree_spec_str")
if tree_spec_str is None:
return None
tree_spec = pytree.treespec_loads(tree_spec_str)
hop_iter = iter(hop_result)
reconstructed = []
for spec in output_spec["leaf_specs"]:
is_constant_scalar = spec["type"] == "scalar" and not isinstance(
spec.get("scalar_value"), torch.SymInt
)
if is_constant_scalar:
reconstructed.append(spec["scalar_value"])
else:
reconstructed.append(next(hop_iter))
return pytree.tree_unflatten(reconstructed, tree_spec)
@staticmethod
def _partition_args(
kernel: Any,
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
constant_args: dict[str, Any] = {}
tensor_args: dict[str, Any] = {}
params = list(kernel.signature.parameters.keys())
for i, val in enumerate(args):
name = params[i]
if isinstance(val, torch.Tensor):
tensor_args[name] = val
else:
constant_args[name] = val
for name, val in kwargs.items():
if isinstance(val, torch.Tensor):
tensor_args[name] = val
else:
constant_args[name] = val
return constant_args, tensor_args
def register_config_picker(
self, picker_func: Callable[[tuple[Any, ...], list[str]], str | None]
@@ -309,29 +401,32 @@ class HelionKernelWrapper:
)
return autotune_kernel.autotune(inputs)
def get_configured_op(self) -> Any:
def get_configured_op(self) -> ConfiguredHelionKernel:
assert self._config_picker is not None, (
f"No config picker registered for kernel '{self.op_name}'. "
f"Use @{self.op_name}.register_config_picker to register one."
)
if self._configured_kernel is None:
self._configured_kernel = ConfiguredHelionKernel(
op_name=self.op_name,
config_picker=self._config_picker,
raw_kernel_func=self.raw_kernel_func,
helion_settings=self.helion_settings,
)
return self._configured_kernel
def _get_or_register_custom_op(self) -> Any:
if hasattr(torch.ops.vllm_helion, self.op_name):
logger.debug("Op vllm_helion::%s already registered", self.op_name)
return getattr(torch.ops.vllm_helion, self.op_name)
configured_kernel = ConfiguredHelionKernel(
op_name=self.op_name,
config_picker=self._config_picker,
raw_kernel_func=self.raw_kernel_func,
helion_settings=self.helion_settings,
)
configured_kernel = self.get_configured_op()
logger.info("Registering op: vllm_helion::%s", self.op_name)
direct_register_custom_op(
op_name=self.op_name,
op_func=configured_kernel._decorated_kernel, # Register decorated kernel
# TODO(gmagogsfm): Implement automatic mutation/aliasing detection
# for Helion kernels.
op_func=configured_kernel._decorated_kernel,
mutates_args=None,
fake_impl=self._fake_impl,
target_lib=vllm_helion_lib,