[Kernel] [Helion] [7/N] Use HOP to represent Helion Kernel call to enable fx tracing and pattern matching (#34390)
Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
This commit is contained in:
203
tests/kernels/helion/test_pattern_matching.py
Normal file
203
tests/kernels/helion/test_pattern_matching.py
Normal file
@@ -0,0 +1,203 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Test make_fx tracing and inductor pattern matching with HelionKernelWrapper."""
|
||||
|
||||
import contextlib
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.utils.import_utils import has_helion
|
||||
|
||||
if not has_helion():
|
||||
pytest.skip(
|
||||
"Helion is not installed. Install with: pip install vllm[helion]",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
import helion
|
||||
import helion.language as hl
|
||||
from helion._compat import requires_torch_version
|
||||
|
||||
if not requires_torch_version("2.11"):
|
||||
pytest.skip(
|
||||
"HigherOrderOp requires PyTorch >= 2.11",
|
||||
allow_module_level=True,
|
||||
)
|
||||
|
||||
from helion._compiler._dynamo.higher_order_ops import (
|
||||
helion_kernel_side_table,
|
||||
helion_kernel_wrapper_mutation,
|
||||
)
|
||||
from torch._inductor.pattern_matcher import (
|
||||
PatternMatcherPass,
|
||||
fwd_only,
|
||||
register_replacement,
|
||||
select_decomp_table,
|
||||
)
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
|
||||
from vllm.kernels.helion.config_manager import ConfigManager
|
||||
from vllm.kernels.helion.register import HelionKernelWrapper
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _helion_mock_context():
|
||||
configs = {
|
||||
"default": helion.Config(block_sizes=[64], num_warps=2, num_stages=2),
|
||||
}
|
||||
mock_config_manager = Mock(spec=ConfigManager)
|
||||
mock_config_manager.get_platform_configs = Mock(return_value=configs)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"vllm.kernels.helion.config_manager.ConfigManager.get_instance",
|
||||
return_value=mock_config_manager,
|
||||
),
|
||||
patch(
|
||||
"vllm.kernels.helion.utils.get_canonical_gpu_name",
|
||||
return_value="nvidia_h200",
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
class TestMakeFxHop:
|
||||
def setup_method(self):
|
||||
helion_kernel_side_table.reset_table()
|
||||
|
||||
def test_make_fx_symbolic(self):
|
||||
def raw_add_scale(
|
||||
x: torch.Tensor, y: torch.Tensor, scale: float
|
||||
) -> tuple[torch.Tensor, int, torch.Tensor]:
|
||||
out_x = torch.empty_like(x)
|
||||
out_y = torch.empty_like(x)
|
||||
for tile in hl.tile(x.size()):
|
||||
out_x[tile] = x[tile] + y[tile] * scale
|
||||
out_y[tile] = out_x[tile] * 2.0
|
||||
return out_x, 42, out_y
|
||||
|
||||
input_x = torch.randn(7, 13)
|
||||
input_y = torch.randn(7, 13)
|
||||
scale = 0.5
|
||||
|
||||
with _helion_mock_context():
|
||||
wrapper = HelionKernelWrapper(
|
||||
raw_kernel_func=raw_add_scale,
|
||||
op_name="test_make_fx",
|
||||
fake_impl=lambda *a, **kw: None,
|
||||
)
|
||||
wrapper.register_config_picker(lambda args, keys: "default")
|
||||
|
||||
def fn(x, y):
|
||||
return wrapper(x, y, scale)
|
||||
|
||||
gm = make_fx(fn, tracing_mode="symbolic")(input_x, input_y)
|
||||
|
||||
hop_nodes = [
|
||||
n
|
||||
for n in gm.graph.nodes
|
||||
if n.op == "call_function" and n.target is helion_kernel_wrapper_mutation
|
||||
]
|
||||
assert len(hop_nodes) == 1
|
||||
node = hop_nodes[0]
|
||||
|
||||
assert node.kwargs["constant_args"]["scale"] == scale
|
||||
assert set(node.kwargs["tensor_args"]) == {"x", "y"}
|
||||
|
||||
specs = node.kwargs["output_spec"]["leaf_specs"]
|
||||
tensor_specs = [s for s in specs if s["type"] == "tensor"]
|
||||
scalar_specs = [s for s in specs if s["type"] == "scalar"]
|
||||
assert len(tensor_specs) == 2
|
||||
assert len(scalar_specs) == 1
|
||||
|
||||
for spec in tensor_specs:
|
||||
assert spec["dtype"] == input_x.dtype
|
||||
|
||||
assert scalar_specs[0]["scalar_value"] == 42
|
||||
|
||||
for val in node.meta["val"]:
|
||||
assert all(isinstance(s, torch.SymInt) for s in val.shape)
|
||||
|
||||
# Both out_x and out_y are empty_like(x), so output shapes == input shape
|
||||
input_node = next(n for n in gm.graph.nodes if n.op == "placeholder")
|
||||
input_shape = input_node.meta["val"].shape
|
||||
for val in node.meta["val"]:
|
||||
assert len(val.shape) == len(input_shape)
|
||||
for out_s, in_s in zip(val.shape, input_shape):
|
||||
assert out_s == in_s
|
||||
|
||||
def test_pattern_matcher_replaces_with_helion_hop(self):
|
||||
def raw_silu_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
M, N = x.size()
|
||||
out = torch.empty_like(x)
|
||||
for tile_m, tile_n in hl.tile([M, N]):
|
||||
out[tile_m, tile_n] = (
|
||||
torch.nn.functional.silu(x[tile_m, tile_n]) * y[tile_m, tile_n]
|
||||
)
|
||||
return out
|
||||
|
||||
with _helion_mock_context():
|
||||
wrapper = HelionKernelWrapper(
|
||||
raw_kernel_func=raw_silu_mul,
|
||||
op_name="test_pm_silu_mul",
|
||||
fake_impl=lambda *a, **kw: None,
|
||||
)
|
||||
wrapper.register_config_picker(lambda args, keys: "default")
|
||||
|
||||
def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return torch.nn.functional.silu(x) * y
|
||||
|
||||
def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return wrapper(x, y)
|
||||
|
||||
inputs = [torch.randn(8, 16), torch.randn(8, 16)]
|
||||
|
||||
pm_pass = PatternMatcherPass(pass_name="test_helion_replacement")
|
||||
register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
|
||||
|
||||
def model(x, y):
|
||||
return torch.nn.functional.silu(x) * y
|
||||
|
||||
decompositions = select_decomp_table()
|
||||
input_x = torch.randn(8, 16)
|
||||
input_y = torch.randn(8, 16)
|
||||
gm = make_fx(model, decompositions, tracing_mode="symbolic")(
|
||||
input_x, input_y
|
||||
)
|
||||
|
||||
def count_hop_nodes(graph):
|
||||
return sum(
|
||||
1
|
||||
for n in graph.nodes
|
||||
if n.op == "call_function"
|
||||
and n.target is helion_kernel_wrapper_mutation
|
||||
)
|
||||
|
||||
assert count_hop_nodes(gm.graph) == 0
|
||||
|
||||
match_count = pm_pass.apply(gm.graph)
|
||||
gm.graph.lint()
|
||||
gm.recompile()
|
||||
|
||||
assert match_count == 1
|
||||
assert count_hop_nodes(gm.graph) == 1
|
||||
|
||||
hop_node = next(
|
||||
n
|
||||
for n in gm.graph.nodes
|
||||
if n.op == "call_function"
|
||||
and n.target is helion_kernel_wrapper_mutation
|
||||
)
|
||||
|
||||
# raw_silu_mul returns empty_like(x), so output shape == input shape
|
||||
for val in hop_node.meta["val"]:
|
||||
assert all(isinstance(s, torch.SymInt) for s in val.shape)
|
||||
|
||||
input_node = next(n for n in gm.graph.nodes if n.op == "placeholder")
|
||||
input_shape = input_node.meta["val"].shape
|
||||
output_shape = hop_node.meta["val"][0].shape
|
||||
assert len(output_shape) == len(input_shape)
|
||||
for out_s, in_s in zip(output_shape, input_shape):
|
||||
assert out_s == in_s
|
||||
@@ -4,8 +4,7 @@
|
||||
Unit tests for Helion kernel registration.
|
||||
|
||||
Tests ConfiguredHelionKernel, HelionKernelWrapper, and PresetConfigSearch
|
||||
including config picker registration, custom autotuner integration, and
|
||||
PyTorch op registration.
|
||||
including config picker registration and custom autotuner integration.
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
@@ -25,6 +24,7 @@ import helion
|
||||
|
||||
from vllm.kernels.helion.config_manager import ConfigManager
|
||||
from vllm.kernels.helion.register import (
|
||||
_HOP_AVAILABLE,
|
||||
ConfiguredHelionKernel,
|
||||
HelionKernelWrapper,
|
||||
get_kernel_by_name,
|
||||
@@ -451,8 +451,10 @@ class TestHelionKernelWrapper:
|
||||
):
|
||||
wrapper.get_configured_op()
|
||||
|
||||
def test_get_configured_op_returns_cached_op(self, sample_kernel, sample_configs):
|
||||
"""Test get_configured_op returns cached op when already registered."""
|
||||
def test_get_configured_op_returns_cached_kernel(
|
||||
self, sample_kernel, sample_configs
|
||||
):
|
||||
"""Test get_configured_op returns cached ConfiguredHelionKernel."""
|
||||
|
||||
def fake_impl(*args, **kwargs):
|
||||
return torch.zeros_like(args[0])
|
||||
@@ -470,6 +472,46 @@ class TestHelionKernelWrapper:
|
||||
mock_config_manager = Mock(spec=ConfigManager)
|
||||
mock_config_manager.get_platform_configs = Mock(return_value=sample_configs)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"vllm.kernels.helion.config_manager.ConfigManager.get_instance",
|
||||
return_value=mock_config_manager,
|
||||
),
|
||||
patch(
|
||||
"vllm.kernels.helion.utils.get_canonical_gpu_name",
|
||||
return_value="nvidia_h200",
|
||||
),
|
||||
patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
|
||||
):
|
||||
mock_decorated = Mock()
|
||||
mock_kernel.return_value = Mock(return_value=mock_decorated)
|
||||
|
||||
result1 = wrapper.get_configured_op()
|
||||
result2 = wrapper.get_configured_op()
|
||||
assert result1 is result2
|
||||
|
||||
@pytest.mark.skipif(
|
||||
_HOP_AVAILABLE, reason="CustomOp path not used when HOP available"
|
||||
)
|
||||
def test_get_or_register_custom_op_returns_cached_op(
|
||||
self, sample_kernel, sample_configs
|
||||
):
|
||||
def fake_impl(*args, **kwargs):
|
||||
return torch.zeros_like(args[0])
|
||||
|
||||
def default_picker(args, config_keys):
|
||||
return "default"
|
||||
|
||||
wrapper = HelionKernelWrapper(
|
||||
raw_kernel_func=sample_kernel,
|
||||
op_name="test_kernel",
|
||||
fake_impl=fake_impl,
|
||||
)
|
||||
wrapper._config_picker = default_picker
|
||||
|
||||
mock_config_manager = Mock(spec=ConfigManager)
|
||||
mock_config_manager.get_platform_configs = Mock(return_value=sample_configs)
|
||||
|
||||
existing_op = Mock()
|
||||
mock_namespace = Mock()
|
||||
mock_namespace.test_kernel = existing_op
|
||||
@@ -488,12 +530,15 @@ class TestHelionKernelWrapper:
|
||||
):
|
||||
mock_decorated = Mock()
|
||||
mock_kernel.return_value = Mock(return_value=mock_decorated)
|
||||
result = wrapper.get_configured_op()
|
||||
result = wrapper._get_or_register_custom_op()
|
||||
assert result is existing_op
|
||||
|
||||
def test_get_configured_op_registers_new_op(self, sample_kernel, sample_configs):
|
||||
"""Test get_configured_op creates and registers new op."""
|
||||
|
||||
@pytest.mark.skipif(
|
||||
_HOP_AVAILABLE, reason="CustomOp path not used when HOP available"
|
||||
)
|
||||
def test_get_or_register_custom_op_registers_new_op(
|
||||
self, sample_kernel, sample_configs
|
||||
):
|
||||
def fake_impl(*args, **kwargs):
|
||||
return torch.zeros_like(args[0])
|
||||
|
||||
@@ -542,11 +587,10 @@ class TestHelionKernelWrapper:
|
||||
):
|
||||
mock_decorated = Mock()
|
||||
mock_kernel.return_value = Mock(return_value=mock_decorated)
|
||||
result = wrapper.get_configured_op()
|
||||
result = wrapper._get_or_register_custom_op()
|
||||
|
||||
mock_register.assert_called_once()
|
||||
assert result is new_op
|
||||
# Check that op_func is the decorated kernel, not ConfiguredHelionKernel
|
||||
assert mock_register.call_args[1]["op_func"] is mock_decorated
|
||||
|
||||
|
||||
|
||||
@@ -31,8 +31,8 @@ by key matches the config returned by the autotuner.
|
||||
|
||||
Key Classes
|
||||
-----------
|
||||
- HelionKernelWrapper: Wraps raw kernel + config_picker, creates configured ops
|
||||
- ConfiguredHelionKernel: Platform-specific kernel registered as PyTorch custom op
|
||||
- HelionKernelWrapper: Wraps raw kernel + config_picker, creates configured kernels
|
||||
- ConfiguredHelionKernel: Platform-specific kernel with pre-tuned configs
|
||||
- PresetConfigSearch: Custom autotuner that returns pre-tuned configs
|
||||
"""
|
||||
|
||||
@@ -53,10 +53,27 @@ if not has_helion():
|
||||
)
|
||||
|
||||
import helion
|
||||
from helion._compat import requires_torch_version
|
||||
from helion.autotuner.base_search import BaseAutotuner
|
||||
from helion.runtime.config import Config
|
||||
from helion.runtime.settings import default_autotuner_fn
|
||||
|
||||
# TODO(gmagogsfm): Remove CustomOp fallback path (_get_or_register_custom_op,
|
||||
# vllm_helion_lib, direct_register_custom_op) once vLLM requires PyTorch >= 2.11.
|
||||
_HOP_AVAILABLE = requires_torch_version("2.11")
|
||||
|
||||
if _HOP_AVAILABLE:
|
||||
import torch.utils._pytree as pytree
|
||||
from helion._compiler._dynamo.higher_order_ops import (
|
||||
helion_kernel_side_table,
|
||||
helion_kernel_wrapper_mutation,
|
||||
)
|
||||
from helion._compiler._dynamo.variables import infer_output_spec
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
disable_proxy_modes_tracing,
|
||||
get_proxy_mode,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
vllm_helion_lib = Library("vllm_helion", "FRAGMENT") # noqa
|
||||
@@ -233,7 +250,7 @@ class ConfiguredHelionKernel:
|
||||
|
||||
|
||||
class HelionKernelWrapper:
|
||||
"""Wrapper for Helion kernels that creates config-specific PyTorch custom ops."""
|
||||
"""Wrapper for Helion kernels with pre-tuned config selection and HOP support."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -252,11 +269,86 @@ class HelionKernelWrapper:
|
||||
self._config_picker: (
|
||||
Callable[[tuple[Any, ...], list[str]], str | None] | None
|
||||
) = None
|
||||
self._configured_kernel: ConfiguredHelionKernel | None = None
|
||||
self._input_generator: Callable[[], dict[str, tuple[Any, ...]]] | None = None
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
configured_op = self.get_configured_op()
|
||||
return configured_op(*args, **kwargs)
|
||||
# CustomOp fallback: register as torch custom op for torch.compile
|
||||
# compatibility on older PyTorch lacking HOP/EffectType support
|
||||
if not _HOP_AVAILABLE:
|
||||
custom_op = self._get_or_register_custom_op()
|
||||
return custom_op(*args, **kwargs)
|
||||
# HOP tracing: record HigherOrderOp in the FX graph
|
||||
if get_proxy_mode() is not None:
|
||||
return self._call_via_hop(args, kwargs)
|
||||
# Eager: run the configured kernel directly
|
||||
return self.get_configured_op()(*args, **kwargs)
|
||||
|
||||
def _call_via_hop(
|
||||
self,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
) -> Any:
|
||||
kernel = self.get_configured_op()._decorated_kernel
|
||||
kernel_idx = helion_kernel_side_table.add_kernel(kernel)
|
||||
|
||||
constant_args, tensor_args = self._partition_args(kernel, args, kwargs)
|
||||
|
||||
all_named = {**constant_args, **tensor_args}
|
||||
full_args = tuple(
|
||||
all_named.get(n, p.default)
|
||||
for n, p in kernel.signature.parameters.items() # type: ignore[attr-defined]
|
||||
if n in all_named or p.default is not p.empty
|
||||
)
|
||||
|
||||
with disable_proxy_modes_tracing():
|
||||
output_spec = infer_output_spec(kernel, full_args)
|
||||
|
||||
hop_result = helion_kernel_wrapper_mutation(
|
||||
kernel_idx=kernel_idx,
|
||||
constant_args=constant_args,
|
||||
tensor_args=tensor_args,
|
||||
output_spec=output_spec,
|
||||
)
|
||||
|
||||
tree_spec_str = output_spec.get("tree_spec_str")
|
||||
if tree_spec_str is None:
|
||||
return None
|
||||
tree_spec = pytree.treespec_loads(tree_spec_str)
|
||||
|
||||
hop_iter = iter(hop_result)
|
||||
reconstructed = []
|
||||
for spec in output_spec["leaf_specs"]:
|
||||
is_constant_scalar = spec["type"] == "scalar" and not isinstance(
|
||||
spec.get("scalar_value"), torch.SymInt
|
||||
)
|
||||
if is_constant_scalar:
|
||||
reconstructed.append(spec["scalar_value"])
|
||||
else:
|
||||
reconstructed.append(next(hop_iter))
|
||||
return pytree.tree_unflatten(reconstructed, tree_spec)
|
||||
|
||||
@staticmethod
|
||||
def _partition_args(
|
||||
kernel: Any,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
constant_args: dict[str, Any] = {}
|
||||
tensor_args: dict[str, Any] = {}
|
||||
params = list(kernel.signature.parameters.keys())
|
||||
for i, val in enumerate(args):
|
||||
name = params[i]
|
||||
if isinstance(val, torch.Tensor):
|
||||
tensor_args[name] = val
|
||||
else:
|
||||
constant_args[name] = val
|
||||
for name, val in kwargs.items():
|
||||
if isinstance(val, torch.Tensor):
|
||||
tensor_args[name] = val
|
||||
else:
|
||||
constant_args[name] = val
|
||||
return constant_args, tensor_args
|
||||
|
||||
def register_config_picker(
|
||||
self, picker_func: Callable[[tuple[Any, ...], list[str]], str | None]
|
||||
@@ -309,29 +401,32 @@ class HelionKernelWrapper:
|
||||
)
|
||||
return autotune_kernel.autotune(inputs)
|
||||
|
||||
def get_configured_op(self) -> Any:
|
||||
def get_configured_op(self) -> ConfiguredHelionKernel:
|
||||
assert self._config_picker is not None, (
|
||||
f"No config picker registered for kernel '{self.op_name}'. "
|
||||
f"Use @{self.op_name}.register_config_picker to register one."
|
||||
)
|
||||
|
||||
if self._configured_kernel is None:
|
||||
self._configured_kernel = ConfiguredHelionKernel(
|
||||
op_name=self.op_name,
|
||||
config_picker=self._config_picker,
|
||||
raw_kernel_func=self.raw_kernel_func,
|
||||
helion_settings=self.helion_settings,
|
||||
)
|
||||
|
||||
return self._configured_kernel
|
||||
|
||||
def _get_or_register_custom_op(self) -> Any:
|
||||
if hasattr(torch.ops.vllm_helion, self.op_name):
|
||||
logger.debug("Op vllm_helion::%s already registered", self.op_name)
|
||||
return getattr(torch.ops.vllm_helion, self.op_name)
|
||||
|
||||
configured_kernel = ConfiguredHelionKernel(
|
||||
op_name=self.op_name,
|
||||
config_picker=self._config_picker,
|
||||
raw_kernel_func=self.raw_kernel_func,
|
||||
helion_settings=self.helion_settings,
|
||||
)
|
||||
configured_kernel = self.get_configured_op()
|
||||
|
||||
logger.info("Registering op: vllm_helion::%s", self.op_name)
|
||||
direct_register_custom_op(
|
||||
op_name=self.op_name,
|
||||
op_func=configured_kernel._decorated_kernel, # Register decorated kernel
|
||||
# TODO(gmagogsfm): Implement automatic mutation/aliasing detection
|
||||
# for Helion kernels.
|
||||
op_func=configured_kernel._decorated_kernel,
|
||||
mutates_args=None,
|
||||
fake_impl=self._fake_impl,
|
||||
target_lib=vllm_helion_lib,
|
||||
|
||||
Reference in New Issue
Block a user