[Kernel] [Helion] [17/N] Add Helion kernel torch.compile support (#38592)

Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
Co-authored-by: Claude Sonnet 4 <noreply@anthropic.com>
This commit is contained in:
Yanan Cao
2026-03-31 14:06:42 -07:00
committed by GitHub
parent 856589ed9a
commit cc671cb110
2 changed files with 98 additions and 78 deletions

View File

@@ -35,6 +35,11 @@ from vllm.kernels.helion.register import (
validate_helion_settings,
)
if _HOP_AVAILABLE:
from helion._compiler._dynamo.higher_order_ops import (
helion_kernel_wrapper_mutation,
)
def _add_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
@@ -941,3 +946,60 @@ class TestKernelRegistry:
registered = get_registered_kernels()
assert "disabled_kernel" in registered
assert registered["disabled_kernel"] is wrapper
@pytest.mark.skipif(not _HOP_AVAILABLE, reason="Requires PyTorch >= 2.11 for HOP")
class TestTorchCompileHOP:
    """Checks the FX graph produced when a HelionKernelWrapper is compiled."""

    def test_compiled_graph_contains_helion_hop(self):
        """Compile a function that calls a registered Helion kernel and check
        that the captured graph holds a helion_kernel_wrapper_mutation node
        and that the compiled output agrees with eager execution."""
        configs = {"default": helion.Config(block_sizes=[4, 4])}

        with dummy_kernel_registry(configs=configs) as register:
            decorate = register(
                op_name="test_torch_compile_add_kernel",
                config_picker=lambda args, keys: "default",
            )
            add_helion_kernel = decorate(_add_kernel)

            # The backend stores the graph it receives here; it must be
            # invoked exactly once for the single compiled function.
            seen_graphs: list[torch.fx.GraphModule] = []

            def capturing_backend(gm, example_inputs):
                assert not seen_graphs, "Backend called multiple times"
                seen_graphs.append(gm)
                return gm.forward

            def f(x, y):
                return add_helion_kernel(x, y)

            # Start from a clean Dynamo state before compiling.
            torch._dynamo.reset()
            compiled_f = torch.compile(f, backend=capturing_backend, fullgraph=True)

            x = torch.randn(4, 4, device="cuda")
            y = torch.randn(4, 4, device="cuda")

            # Running the compiled function is what triggers the backend.
            compiled_result = compiled_f(x, y)

            assert len(seen_graphs) == 1
            captured_graph = seen_graphs[0]

            def is_helion_hop(node):
                if node.op != "call_function":
                    return False
                return node.target is helion_kernel_wrapper_mutation

            hop_nodes = list(filter(is_helion_hop, captured_graph.graph.nodes))

            assert hop_nodes, (
                "Expected helion_kernel_wrapper_mutation HOP node in compiled graph, "
                f"but found none. Graph nodes: "
                f"{[(n.op, n.target) for n in captured_graph.graph.nodes]}"
            )

            # The uncompiled function serves as the eager reference.
            eager_result = f(x, y)
            results_match = torch.allclose(
                compiled_result, eager_result, atol=1e-5, rtol=1e-5
            )
            assert results_match, (
                "Compiled execution result doesn't match eager execution. "
                f"Max difference: {torch.max(torch.abs(compiled_result - eager_result))}"
            )

View File

@@ -63,16 +63,11 @@ from helion.runtime.settings import default_autotuner_fn
_HOP_AVAILABLE = requires_torch_version("2.11")
if _HOP_AVAILABLE:
import torch.utils._pytree as pytree
from helion._compiler._dynamo.higher_order_ops import (
helion_kernel_side_table,
helion_kernel_wrapper_mutation,
)
from helion._compiler._dynamo.variables import infer_output_spec
from torch.fx.experimental.proxy_tensor import (
disable_proxy_modes_tracing,
get_proxy_mode,
)
from helion._compiler._dynamo.higher_order_ops import helion_kernel_side_table
from helion._compiler._dynamo.variables import HelionKernelVariable
from torch._dynamo.guards import GuardBuilder
from torch._dynamo.variables.builder import VariableBuilder
logger = init_logger(__name__)
@@ -298,76 +293,12 @@ class HelionKernelWrapper:
f"Kernel '{self.op_name}' was not initialized. "
"Please open an issue on GitHub."
)
if get_proxy_mode() is not None:
return self._call_via_hop(args, kwargs)
# During Dynamo tracing, this call is intercepted through the VariableBuilder
# dispatch entry registered for HelionKernelWrapper in this module, which hands
# the kernel to Helion's HelionKernelVariable for HOP emission.
# During eager execution, call the kernel directly.
return self._configured_kernel(*args, **kwargs)
def _call_via_hop(
    self,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> Any:
    """Invoke the configured kernel through helion_kernel_wrapper_mutation.

    The kernel is added to the Helion side table, the call arguments are
    split into constant and tensor groups, and the output spec is inferred
    with proxy-mode tracing disabled. The HOP results are then rebuilt into
    the pytree structure described by that spec.

    Returns:
        The unflattened kernel output, or None when the output spec has no
        "tree_spec_str" entry.
    """
    kernel = self.get_configured_op()._decorated_kernel
    kernel_idx = helion_kernel_side_table.add_kernel(kernel)

    constant_args, tensor_args = self._partition_args(kernel, args, kwargs)
    supplied = {**constant_args, **tensor_args}

    # Walk the signature in declaration order: use the supplied value when
    # there is one, otherwise the declared default; parameters with neither
    # are left out.
    ordered_values = []
    for name, param in kernel.signature.parameters.items():  # type: ignore[attr-defined]
        if name in supplied:
            ordered_values.append(supplied[name])
        elif param.default is not param.empty:
            ordered_values.append(param.default)
    full_args = tuple(ordered_values)

    with disable_proxy_modes_tracing():
        output_spec = infer_output_spec(kernel, full_args)

    hop_result = helion_kernel_wrapper_mutation(
        kernel_idx=kernel_idx,
        constant_args=constant_args,
        tensor_args=tensor_args,
        output_spec=output_spec,
    )

    tree_spec_str = output_spec.get("tree_spec_str")
    if tree_spec_str is None:
        return None
    tree_spec = pytree.treespec_loads(tree_spec_str)

    # Scalar leaves whose recorded value is not a SymInt are taken from the
    # spec itself; every other leaf consumes the next HOP result in order.
    remaining_outputs = iter(hop_result)
    leaves = []
    for leaf_spec in output_spec["leaf_specs"]:
        comes_from_hop = leaf_spec["type"] != "scalar" or isinstance(
            leaf_spec.get("scalar_value"), torch.SymInt
        )
        if comes_from_hop:
            leaves.append(next(remaining_outputs))
        else:
            leaves.append(leaf_spec["scalar_value"])
    return pytree.tree_unflatten(leaves, tree_spec)
@staticmethod
def _partition_args(
kernel: Any,
args: tuple[Any, ...],
kwargs: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, Any]]:
constant_args: dict[str, Any] = {}
tensor_args: dict[str, Any] = {}
params = list(kernel.signature.parameters.keys())
for i, val in enumerate(args):
name = params[i]
if isinstance(val, torch.Tensor):
tensor_args[name] = val
else:
constant_args[name] = val
for name, val in kwargs.items():
if isinstance(val, torch.Tensor):
tensor_args[name] = val
else:
constant_args[name] = val
return constant_args, tensor_args
def get_inputs(self) -> dict[str, tuple[Any, ...]]:
if self._input_generator is None:
raise NotImplementedError(
@@ -535,3 +466,30 @@ def register_kernel(
return kernel_wrapper
return decorator
# Hook HelionKernelWrapper into Dynamo's variable tracker system
if _HOP_AVAILABLE:

    def _register_vllm_helion_dynamo_variable():
        """Add a HelionKernelWrapper entry to VariableBuilder's type dispatch.

        The installed handler takes the wrapper's configured op, pulls out
        its decorated Helion kernel, adds that kernel to the Helion side
        table, installs an ID_MATCH guard on the builder, and returns a
        HelionKernelVariable built from the kernel, its side-table index and
        the builder's source.
        """

        def wrap_helion_kernel_wrapper(
            builder: VariableBuilder, value: HelionKernelWrapper
        ):
            configured_op = value.get_configured_op()
            kernel = configured_op._decorated_kernel
            kernel_idx = helion_kernel_side_table.add_kernel(kernel)
            builder.install_guards(GuardBuilder.ID_MATCH)
            return HelionKernelVariable(kernel, kernel_idx, source=builder.source)

        # Write the handler into the dispatch table keyed by wrapper type.
        VariableBuilder._type_dispatch()[HelionKernelWrapper] = (
            wrap_helion_kernel_wrapper
        )

    # Runs once, as a side effect of importing this module.
    _register_vllm_helion_dynamo_variable()