Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -4,8 +4,9 @@ import contextlib
|
||||
import copy
|
||||
import hashlib
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Callable, Optional
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
@@ -62,9 +63,9 @@ class CompilerInterface:
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
runtime_shape: Optional[int] = None,
|
||||
key: Optional[str] = None,
|
||||
) -> tuple[Optional[Callable], Optional[Any]]:
|
||||
runtime_shape: int | None = None,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
"""
|
||||
Compile the graph with the given example inputs and compiler config,
|
||||
with a runtime shape. If the `runtime_shape` is None, it means
|
||||
@@ -97,7 +98,7 @@ class CompilerInterface:
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: Optional[int] = None,
|
||||
runtime_shape: int | None = None,
|
||||
) -> Callable:
|
||||
"""
|
||||
Load the compiled function from the handle.
|
||||
@@ -191,9 +192,9 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
runtime_shape: Optional[int] = None,
|
||||
key: Optional[str] = None,
|
||||
) -> tuple[Optional[Callable], Optional[Any]]:
|
||||
runtime_shape: int | None = None,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
compilation_counter.num_inductor_compiles += 1
|
||||
current_config = {}
|
||||
if compiler_config is not None:
|
||||
@@ -229,7 +230,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: Optional[int] = None,
|
||||
runtime_shape: int | None = None,
|
||||
) -> Callable:
|
||||
assert isinstance(handle, tuple)
|
||||
assert isinstance(handle[0], str)
|
||||
@@ -293,9 +294,9 @@ class InductorAdaptor(CompilerInterface):
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
runtime_shape: Optional[int] = None,
|
||||
key: Optional[str] = None,
|
||||
) -> tuple[Optional[Callable], Optional[Any]]:
|
||||
runtime_shape: int | None = None,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
compilation_counter.num_inductor_compiles += 1
|
||||
from torch._inductor.compile_fx import compile_fx
|
||||
|
||||
@@ -492,7 +493,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
graph_index: int,
|
||||
runtime_shape: Optional[int] = None,
|
||||
runtime_shape: int | None = None,
|
||||
) -> Callable:
|
||||
assert isinstance(handle, tuple)
|
||||
assert isinstance(handle[0], str)
|
||||
@@ -610,9 +611,9 @@ class EagerAdaptor(CompilerInterface):
|
||||
graph: fx.GraphModule,
|
||||
example_inputs: list[Any],
|
||||
compiler_config: dict[str, Any],
|
||||
runtime_shape: Optional[int] = None,
|
||||
key: Optional[str] = None,
|
||||
) -> tuple[Optional[Callable], Optional[Any]]:
|
||||
runtime_shape: int | None = None,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable | None, Any | None]:
|
||||
compilation_counter.num_eager_compiles += 1
|
||||
# we don't need to compile the graph, just return the graph itself.
|
||||
# It does not support caching, return None for the handle.
|
||||
|
||||
Reference in New Issue
Block a user