diff --git a/tests/kernels/helion/test_register.py b/tests/kernels/helion/test_register.py index cb1e66d9e..c7f93993c 100644 --- a/tests/kernels/helion/test_register.py +++ b/tests/kernels/helion/test_register.py @@ -35,6 +35,11 @@ from vllm.kernels.helion.register import ( validate_helion_settings, ) +if _HOP_AVAILABLE: + from helion._compiler._dynamo.higher_order_ops import ( + helion_kernel_wrapper_mutation, + ) + def _add_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: out = torch.empty_like(x) @@ -941,3 +946,60 @@ class TestKernelRegistry: registered = get_registered_kernels() assert "disabled_kernel" in registered assert registered["disabled_kernel"] is wrapper + + +@pytest.mark.skipif(not _HOP_AVAILABLE, reason="Requires PyTorch >= 2.11 for HOP") +class TestTorchCompileHOP: + """Test that HelionKernelWrapper emits the correct HOP under torch.compile.""" + + def test_compiled_graph_contains_helion_hop(self): + """Verify torch.compile on a HelionKernelWrapper emits a + helion_kernel_wrapper_mutation HOP node in the FX graph.""" + configs = {"default": helion.Config(block_sizes=[4, 4])} + + with dummy_kernel_registry(configs=configs) as register: + add_helion_kernel = register( + op_name="test_torch_compile_add_kernel", + config_picker=lambda args, keys: "default", + )(_add_kernel) + + captured_graph: torch.fx.GraphModule | None = None + + def capturing_backend(gm, example_inputs): + nonlocal captured_graph + assert captured_graph is None, "Backend called multiple times" + captured_graph = gm + return gm.forward + + def f(x, y): + return add_helion_kernel(x, y) + + torch._dynamo.reset() + compiled_f = torch.compile(f, backend=capturing_backend, fullgraph=True) + + x = torch.randn(4, 4, device="cuda") + y = torch.randn(4, 4, device="cuda") + + # Run compiled version and capture graph + compiled_result = compiled_f(x, y) + + assert captured_graph is not None + hop_nodes = [ + node + for node in captured_graph.graph.nodes + if node.op == "call_function" + and node.target is helion_kernel_wrapper_mutation + ] + assert len(hop_nodes) > 0, ( + "Expected helion_kernel_wrapper_mutation HOP node in compiled graph, " + f"but found none. Graph nodes: " + f"{[(n.op, n.target) for n in captured_graph.graph.nodes]}" + ) + + # Verify compiled result matches eager execution + eager_result = f(x, y) # Run in eager mode + + assert torch.allclose(compiled_result, eager_result, atol=1e-5, rtol=1e-5), ( + "Compiled execution result doesn't match eager execution. " + f"Max difference: {torch.max(torch.abs(compiled_result - eager_result))}" + ) diff --git a/vllm/kernels/helion/register.py b/vllm/kernels/helion/register.py index ba98e87ca..1557a36b2 100644 --- a/vllm/kernels/helion/register.py +++ b/vllm/kernels/helion/register.py @@ -63,16 +63,11 @@ from helion.runtime.settings import default_autotuner_fn _HOP_AVAILABLE = requires_torch_version("2.11") if _HOP_AVAILABLE: - import torch.utils._pytree as pytree - from helion._compiler._dynamo.higher_order_ops import ( - helion_kernel_side_table, - helion_kernel_wrapper_mutation, - ) - from helion._compiler._dynamo.variables import infer_output_spec - from torch.fx.experimental.proxy_tensor import ( - disable_proxy_modes_tracing, - get_proxy_mode, - ) + from helion._compiler._dynamo.higher_order_ops import helion_kernel_side_table + from helion._compiler._dynamo.variables import HelionKernelVariable + from torch._dynamo.guards import GuardBuilder + from torch._dynamo.variables.builder import VariableBuilder + logger = init_logger(__name__) @@ -298,76 +293,12 @@ class HelionKernelWrapper: f"Kernel '{self.op_name}' was not initialized. " "Please open an issue on GitHub." ) - if get_proxy_mode() is not None: - return self._call_via_hop(args, kwargs) + + # During Dynamo tracing, this call will be intercepted by our custom + # HelionKernelWrapperVariable and handled via proper HOP emission. + # During eager execution, call the kernel directly. return self._configured_kernel(*args, **kwargs) - def _call_via_hop( - self, - args: tuple[Any, ...], - kwargs: dict[str, Any], - ) -> Any: - kernel = self.get_configured_op()._decorated_kernel - kernel_idx = helion_kernel_side_table.add_kernel(kernel) - - constant_args, tensor_args = self._partition_args(kernel, args, kwargs) - - all_named = {**constant_args, **tensor_args} - full_args = tuple( - all_named.get(n, p.default) - for n, p in kernel.signature.parameters.items() # type: ignore[attr-defined] - if n in all_named or p.default is not p.empty - ) - - with disable_proxy_modes_tracing(): - output_spec = infer_output_spec(kernel, full_args) - - hop_result = helion_kernel_wrapper_mutation( - kernel_idx=kernel_idx, - constant_args=constant_args, - tensor_args=tensor_args, - output_spec=output_spec, - ) - - tree_spec_str = output_spec.get("tree_spec_str") - if tree_spec_str is None: - return None - tree_spec = pytree.treespec_loads(tree_spec_str) - - hop_iter = iter(hop_result) - reconstructed = [] - for spec in output_spec["leaf_specs"]: - is_constant_scalar = spec["type"] == "scalar" and not isinstance( - spec.get("scalar_value"), torch.SymInt - ) - if is_constant_scalar: - reconstructed.append(spec["scalar_value"]) - else: - reconstructed.append(next(hop_iter)) - return pytree.tree_unflatten(reconstructed, tree_spec) - - @staticmethod - def _partition_args( - kernel: Any, - args: tuple[Any, ...], - kwargs: dict[str, Any], - ) -> tuple[dict[str, Any], dict[str, Any]]: - constant_args: dict[str, Any] = {} - tensor_args: dict[str, Any] = {} - params = list(kernel.signature.parameters.keys()) - for i, val in enumerate(args): - name = params[i] - if isinstance(val, torch.Tensor): - tensor_args[name] = val - else: - constant_args[name] = val - for name, val in kwargs.items(): - if isinstance(val, torch.Tensor): - tensor_args[name] = val - else: - constant_args[name] = val - return constant_args, tensor_args - def get_inputs(self) -> dict[str, tuple[Any, ...]]: if self._input_generator is None: raise NotImplementedError( @@ -535,3 +466,30 @@ def register_kernel( return kernel_wrapper return decorator + + +# Register HelionKernelWrapper with Dynamo's variable tracker system +if _HOP_AVAILABLE: + + def _register_vllm_helion_dynamo_variable(): + """Register HelionKernelWrapper with Dynamo's VariableBuilder. + + When Dynamo encounters a HelionKernelWrapper during tracing, this + extracts the underlying Helion Kernel, registers it in the side table, + and returns Helion's own HelionKernelVariable to handle HOP emission. + """ + + def wrap_helion_kernel_wrapper( + builder: VariableBuilder, value: HelionKernelWrapper + ): + kernel = value.get_configured_op()._decorated_kernel + kernel_idx = helion_kernel_side_table.add_kernel(kernel) + builder.install_guards(GuardBuilder.ID_MATCH) + return HelionKernelVariable(kernel, kernel_idx, source=builder.source) + + # Register with Dynamo's type dispatch system + dispatch = VariableBuilder._type_dispatch() + dispatch[HelionKernelWrapper] = wrap_helion_kernel_wrapper + + # Register immediately when the module is imported + _register_vllm_helion_dynamo_variable()