[Kernel] [Helion] [17/N] Add Helion kernel torch.compile support (#38592)
Signed-off-by: Yanan Cao <gmagogsfm@gmail.com> Co-authored-by: Claude Sonnet 4 <noreply@anthropic.com>
This commit is contained in:
@@ -35,6 +35,11 @@ from vllm.kernels.helion.register import (
|
||||
validate_helion_settings,
|
||||
)
|
||||
|
||||
if _HOP_AVAILABLE:
|
||||
from helion._compiler._dynamo.higher_order_ops import (
|
||||
helion_kernel_wrapper_mutation,
|
||||
)
|
||||
|
||||
|
||||
def _add_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
out = torch.empty_like(x)
|
||||
@@ -941,3 +946,60 @@ class TestKernelRegistry:
|
||||
registered = get_registered_kernels()
|
||||
assert "disabled_kernel" in registered
|
||||
assert registered["disabled_kernel"] is wrapper
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _HOP_AVAILABLE, reason="Requires PyTorch >= 2.11 for HOP")
|
||||
class TestTorchCompileHOP:
|
||||
"""Test that HelionKernelWrapper emits the correct HOP under torch.compile."""
|
||||
|
||||
def test_compiled_graph_contains_helion_hop(self):
|
||||
"""Verify torch.compile on a HelionKernelWrapper emits a
|
||||
helion_kernel_wrapper_mutation HOP node in the FX graph."""
|
||||
configs = {"default": helion.Config(block_sizes=[4, 4])}
|
||||
|
||||
with dummy_kernel_registry(configs=configs) as register:
|
||||
add_helion_kernel = register(
|
||||
op_name="test_torch_compile_add_kernel",
|
||||
config_picker=lambda args, keys: "default",
|
||||
)(_add_kernel)
|
||||
|
||||
captured_graph: torch.fx.GraphModule | None = None
|
||||
|
||||
def capturing_backend(gm, example_inputs):
|
||||
nonlocal captured_graph
|
||||
assert captured_graph is None, "Backend called multiple times"
|
||||
captured_graph = gm
|
||||
return gm.forward
|
||||
|
||||
def f(x, y):
|
||||
return add_helion_kernel(x, y)
|
||||
|
||||
torch._dynamo.reset()
|
||||
compiled_f = torch.compile(f, backend=capturing_backend, fullgraph=True)
|
||||
|
||||
x = torch.randn(4, 4, device="cuda")
|
||||
y = torch.randn(4, 4, device="cuda")
|
||||
|
||||
# Run compiled version and capture graph
|
||||
compiled_result = compiled_f(x, y)
|
||||
|
||||
assert captured_graph is not None
|
||||
hop_nodes = [
|
||||
node
|
||||
for node in captured_graph.graph.nodes
|
||||
if node.op == "call_function"
|
||||
and node.target is helion_kernel_wrapper_mutation
|
||||
]
|
||||
assert len(hop_nodes) > 0, (
|
||||
"Expected helion_kernel_wrapper_mutation HOP node in compiled graph, "
|
||||
f"but found none. Graph nodes: "
|
||||
f"{[(n.op, n.target) for n in captured_graph.graph.nodes]}"
|
||||
)
|
||||
|
||||
# Verify compiled result matches eager execution
|
||||
eager_result = f(x, y) # Run in eager mode
|
||||
|
||||
assert torch.allclose(compiled_result, eager_result, atol=1e-5, rtol=1e-5), (
|
||||
"Compiled execution result doesn't match eager execution. "
|
||||
f"Max difference: {torch.max(torch.abs(compiled_result - eager_result))}"
|
||||
)
|
||||
|
||||
@@ -63,16 +63,11 @@ from helion.runtime.settings import default_autotuner_fn
|
||||
_HOP_AVAILABLE = requires_torch_version("2.11")
|
||||
|
||||
if _HOP_AVAILABLE:
|
||||
import torch.utils._pytree as pytree
|
||||
from helion._compiler._dynamo.higher_order_ops import (
|
||||
helion_kernel_side_table,
|
||||
helion_kernel_wrapper_mutation,
|
||||
)
|
||||
from helion._compiler._dynamo.variables import infer_output_spec
|
||||
from torch.fx.experimental.proxy_tensor import (
|
||||
disable_proxy_modes_tracing,
|
||||
get_proxy_mode,
|
||||
)
|
||||
from helion._compiler._dynamo.higher_order_ops import helion_kernel_side_table
|
||||
from helion._compiler._dynamo.variables import HelionKernelVariable
|
||||
from torch._dynamo.guards import GuardBuilder
|
||||
from torch._dynamo.variables.builder import VariableBuilder
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -298,76 +293,12 @@ class HelionKernelWrapper:
|
||||
f"Kernel '{self.op_name}' was not initialized. "
|
||||
"Please open an issue on GitHub."
|
||||
)
|
||||
if get_proxy_mode() is not None:
|
||||
return self._call_via_hop(args, kwargs)
|
||||
|
||||
# During Dynamo tracing, this call will be intercepted by our custom
|
||||
# HelionKernelWrapperVariable and handled via proper HOP emission.
|
||||
# During eager execution, call the kernel directly.
|
||||
return self._configured_kernel(*args, **kwargs)
|
||||
|
||||
def _call_via_hop(
|
||||
self,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
) -> Any:
|
||||
kernel = self.get_configured_op()._decorated_kernel
|
||||
kernel_idx = helion_kernel_side_table.add_kernel(kernel)
|
||||
|
||||
constant_args, tensor_args = self._partition_args(kernel, args, kwargs)
|
||||
|
||||
all_named = {**constant_args, **tensor_args}
|
||||
full_args = tuple(
|
||||
all_named.get(n, p.default)
|
||||
for n, p in kernel.signature.parameters.items() # type: ignore[attr-defined]
|
||||
if n in all_named or p.default is not p.empty
|
||||
)
|
||||
|
||||
with disable_proxy_modes_tracing():
|
||||
output_spec = infer_output_spec(kernel, full_args)
|
||||
|
||||
hop_result = helion_kernel_wrapper_mutation(
|
||||
kernel_idx=kernel_idx,
|
||||
constant_args=constant_args,
|
||||
tensor_args=tensor_args,
|
||||
output_spec=output_spec,
|
||||
)
|
||||
|
||||
tree_spec_str = output_spec.get("tree_spec_str")
|
||||
if tree_spec_str is None:
|
||||
return None
|
||||
tree_spec = pytree.treespec_loads(tree_spec_str)
|
||||
|
||||
hop_iter = iter(hop_result)
|
||||
reconstructed = []
|
||||
for spec in output_spec["leaf_specs"]:
|
||||
is_constant_scalar = spec["type"] == "scalar" and not isinstance(
|
||||
spec.get("scalar_value"), torch.SymInt
|
||||
)
|
||||
if is_constant_scalar:
|
||||
reconstructed.append(spec["scalar_value"])
|
||||
else:
|
||||
reconstructed.append(next(hop_iter))
|
||||
return pytree.tree_unflatten(reconstructed, tree_spec)
|
||||
|
||||
@staticmethod
|
||||
def _partition_args(
|
||||
kernel: Any,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
constant_args: dict[str, Any] = {}
|
||||
tensor_args: dict[str, Any] = {}
|
||||
params = list(kernel.signature.parameters.keys())
|
||||
for i, val in enumerate(args):
|
||||
name = params[i]
|
||||
if isinstance(val, torch.Tensor):
|
||||
tensor_args[name] = val
|
||||
else:
|
||||
constant_args[name] = val
|
||||
for name, val in kwargs.items():
|
||||
if isinstance(val, torch.Tensor):
|
||||
tensor_args[name] = val
|
||||
else:
|
||||
constant_args[name] = val
|
||||
return constant_args, tensor_args
|
||||
|
||||
def get_inputs(self) -> dict[str, tuple[Any, ...]]:
|
||||
if self._input_generator is None:
|
||||
raise NotImplementedError(
|
||||
@@ -535,3 +466,30 @@ def register_kernel(
|
||||
return kernel_wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
# Register HelionKernelWrapper with Dynamo's variable tracker system
|
||||
if _HOP_AVAILABLE:
|
||||
|
||||
def _register_vllm_helion_dynamo_variable():
|
||||
"""Register HelionKernelWrapper with Dynamo's VariableBuilder.
|
||||
|
||||
When Dynamo encounters a HelionKernelWrapper during tracing, this
|
||||
extracts the underlying Helion Kernel, registers it in the side table,
|
||||
and returns Helion's own HelionKernelVariable to handle HOP emission.
|
||||
"""
|
||||
|
||||
def wrap_helion_kernel_wrapper(
|
||||
builder: VariableBuilder, value: HelionKernelWrapper
|
||||
):
|
||||
kernel = value.get_configured_op()._decorated_kernel
|
||||
kernel_idx = helion_kernel_side_table.add_kernel(kernel)
|
||||
builder.install_guards(GuardBuilder.ID_MATCH)
|
||||
return HelionKernelVariable(kernel, kernel_idx, source=builder.source)
|
||||
|
||||
# Register with Dynamo's type dispatch system
|
||||
dispatch = VariableBuilder._type_dispatch()
|
||||
dispatch[HelionKernelWrapper] = wrap_helion_kernel_wrapper
|
||||
|
||||
# Register immediately when the module is imported
|
||||
_register_vllm_helion_dynamo_variable()
|
||||
|
||||
Reference in New Issue
Block a user