[Misc][BE] Turn on strict type coverage for vllm/compilation (#31756)
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
This commit is contained in:
@@ -100,6 +100,13 @@ ignore_missing_imports = true
|
||||
check_untyped_defs = true
|
||||
follow_imports = "silent"
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "vllm.compilation.*"
|
||||
disallow_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
warn_return_any = true
|
||||
follow_imports = "silent"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
markers = [
|
||||
"slow_test",
|
||||
|
||||
@@ -28,7 +28,7 @@ def test_bad_callable():
|
||||
pass_manager.configure(config)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
pass_manager.add(simple_callable)
|
||||
pass_manager.add(simple_callable) # type: ignore[arg-type]
|
||||
|
||||
|
||||
# Pass that inherits from InductorPass
|
||||
|
||||
@@ -77,6 +77,11 @@ EXCLUDE = [
|
||||
"vllm/v1/attention/ops",
|
||||
]
|
||||
|
||||
# Directories that should be checked with --strict
|
||||
STRICT_DIRS = [
|
||||
"vllm/compilation",
|
||||
]
|
||||
|
||||
|
||||
def group_files(changed_files: list[str]) -> dict[str, list[str]]:
|
||||
"""
|
||||
@@ -108,11 +113,17 @@ def group_files(changed_files: list[str]) -> dict[str, list[str]]:
|
||||
return file_groups
|
||||
|
||||
|
||||
def is_strict_file(filepath: str) -> bool:
    """Check if a file should be checked with strict mode.

    A file qualifies when its path is one of the entries in ``STRICT_DIRS``
    or lies underneath one of them. The comparison is anchored on a ``/``
    directory boundary so that a sibling path which merely shares a textual
    prefix with a strict directory (e.g. ``vllm/compilation_utils/x.py`` vs.
    ``vllm/compilation``) is not pulled into strict checking by accident.

    Args:
        filepath: Path of the file, using ``/`` separators and the same
            base as the entries in ``STRICT_DIRS``.

    Returns:
        True if the file belongs to a strict directory, False otherwise.
    """
    return any(
        # Exact match keeps the previous behavior for a path equal to the
        # directory entry itself; the "/"-terminated prefix restricts all
        # other matches to real descendants of that directory.
        filepath == strict_dir or filepath.startswith(strict_dir.rstrip("/") + "/")
        for strict_dir in STRICT_DIRS
    )
|
||||
|
||||
|
||||
def mypy(
|
||||
targets: list[str],
|
||||
python_version: str | None,
|
||||
follow_imports: str | None,
|
||||
file_group: str,
|
||||
strict: bool = False,
|
||||
) -> int:
|
||||
"""
|
||||
Run mypy on the given targets.
|
||||
@@ -124,6 +135,7 @@ def mypy(
|
||||
follow_imports: Value for the --follow-imports option or None to use
|
||||
the default mypy behavior.
|
||||
file_group: The file group name for logging purposes.
|
||||
strict: If True, run mypy with --strict flag.
|
||||
|
||||
Returns:
|
||||
The return code from mypy.
|
||||
@@ -133,6 +145,8 @@ def mypy(
|
||||
args += ["--python-version", python_version]
|
||||
if follow_imports is not None:
|
||||
args += ["--follow-imports", follow_imports]
|
||||
if strict:
|
||||
args += ["--strict"]
|
||||
print(f"$ {' '.join(args)} {file_group}")
|
||||
return subprocess.run(args + targets, check=False).returncode
|
||||
|
||||
@@ -149,9 +163,29 @@ def main():
|
||||
for file_group, changed_files in file_groups.items():
|
||||
follow_imports = None if ci and file_group == "" else "skip"
|
||||
if changed_files:
|
||||
returncode |= mypy(
|
||||
changed_files, python_version, follow_imports, file_group
|
||||
)
|
||||
# Separate files into strict and non-strict groups
|
||||
strict_files = [f for f in changed_files if is_strict_file(f)]
|
||||
non_strict_files = [f for f in changed_files if not is_strict_file(f)]
|
||||
|
||||
# Run mypy on non-strict files
|
||||
if non_strict_files:
|
||||
returncode |= mypy(
|
||||
non_strict_files,
|
||||
python_version,
|
||||
follow_imports,
|
||||
file_group,
|
||||
strict=False,
|
||||
)
|
||||
|
||||
# Run mypy on strict files with --strict flag
|
||||
if strict_files:
|
||||
returncode |= mypy(
|
||||
strict_files,
|
||||
python_version,
|
||||
follow_imports,
|
||||
f"{file_group} (strict)",
|
||||
strict=True,
|
||||
)
|
||||
return returncode
|
||||
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ def make_copy_and_call(
|
||||
A wrapper function that copies inputs and calls the compiled function
|
||||
"""
|
||||
|
||||
def copy_and_call(*args):
|
||||
def copy_and_call(*args: Any) -> Any:
|
||||
list_args = list(args)
|
||||
for i, index in enumerate(sym_tensor_indices):
|
||||
runtime_tensor = list_args[index]
|
||||
|
||||
@@ -43,15 +43,15 @@ class StandaloneCompiledArtifacts:
|
||||
split on attn)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
# dict from submodule name to byte hash
|
||||
self.submodule_bytes = {}
|
||||
self.submodule_bytes: dict[str, str] = {}
|
||||
# dict from byte hash to bytes
|
||||
self.submodule_bytes_store = {}
|
||||
self.submodule_bytes_store: dict[str, bytes] = {}
|
||||
# dict from byte hash to loaded module
|
||||
self.loaded_submodule_store = {}
|
||||
self.loaded_submodule_store: dict[str, Any] = {}
|
||||
|
||||
def insert(self, submod_name: str, shape: str, entry: bytes):
|
||||
def insert(self, submod_name: str, shape: str, entry: bytes) -> None:
|
||||
hasher = hashlib.sha256()
|
||||
hasher.update(entry)
|
||||
hex_digest = hasher.hexdigest()
|
||||
@@ -86,7 +86,7 @@ class StandaloneCompiledArtifacts:
|
||||
self.submodule_bytes[f"{submod_name}_{shape}"]
|
||||
]
|
||||
|
||||
def get_loaded(self, submod_name: str, shape: str):
|
||||
def get_loaded(self, submod_name: str, shape: str) -> Any:
|
||||
logger.debug(
|
||||
"getting artifact for submod %s with shape %s",
|
||||
submod_name,
|
||||
@@ -119,7 +119,7 @@ class StandaloneCompiledArtifacts:
|
||||
|
||||
from torch._inductor.standalone_compile import AOTCompiledArtifact
|
||||
|
||||
def _load_entry(entry_bytes) -> AOTCompiledArtifact:
|
||||
def _load_entry(entry_bytes: bytes) -> AOTCompiledArtifact:
|
||||
entry = pickle.loads(entry_bytes)
|
||||
return AOTCompiledArtifact.deserialize(entry)
|
||||
|
||||
@@ -132,13 +132,13 @@ class StandaloneCompiledArtifacts:
|
||||
|
||||
logger.debug("loaded all %s submodules", self.num_artifacts())
|
||||
|
||||
def __getstate__(self):
|
||||
def __getstate__(self) -> dict[str, dict[str, str] | dict[str, bytes]]:
|
||||
return {
|
||||
"submodule_bytes": self.submodule_bytes,
|
||||
"submodule_bytes_store": self.submodule_bytes_store,
|
||||
}
|
||||
|
||||
def __setstate__(self, state):
|
||||
def __setstate__(self, state: dict[str, dict[str, Any]]) -> None:
|
||||
self.submodule_bytes = state["submodule_bytes"]
|
||||
self.submodule_bytes_store = state["submodule_bytes_store"]
|
||||
self.loaded_submodule_store = {}
|
||||
@@ -387,7 +387,7 @@ def reconstruct_serializable_fn_from_mega_artifact(
|
||||
standalone_compile_artifacts.load_all()
|
||||
|
||||
submod_names = standalone_compile_artifacts.submodule_names()
|
||||
compiled_callables: dict[str, dict[str, Callable]] = {}
|
||||
compiled_callables: dict[str, dict[str, Callable[..., Any]]] = {}
|
||||
|
||||
for cache_key in standalone_compile_artifacts.submodule_bytes:
|
||||
submod_name, shape_str = cache_key.rsplit("_", 1)
|
||||
@@ -495,9 +495,10 @@ def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str:
|
||||
# e.g. exec(). We can't actually check these.
|
||||
continue
|
||||
hash_content.append(content)
|
||||
return safe_hash(
|
||||
result: str = safe_hash(
|
||||
"\n".join(hash_content).encode(), usedforsecurity=False
|
||||
).hexdigest()
|
||||
return result
|
||||
|
||||
|
||||
def _compute_code_hash(files: set[str]) -> str:
|
||||
|
||||
@@ -30,19 +30,15 @@ from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
|
||||
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
flashinfer_comm: ModuleType | None = None
|
||||
if find_spec("flashinfer"):
|
||||
try:
|
||||
import flashinfer.comm as flashinfer_comm
|
||||
import flashinfer.comm as _flashinfer_comm
|
||||
|
||||
flashinfer_comm: ModuleType | None = ( # type: ignore[no-redef]
|
||||
flashinfer_comm
|
||||
if hasattr(flashinfer_comm, "trtllm_allreduce_fusion")
|
||||
else None
|
||||
)
|
||||
if hasattr(_flashinfer_comm, "trtllm_allreduce_fusion"):
|
||||
flashinfer_comm = _flashinfer_comm
|
||||
except ImportError:
|
||||
flashinfer_comm = None # type: ignore[assignment]
|
||||
else:
|
||||
flashinfer_comm = None # type: ignore[assignment]
|
||||
pass
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -441,7 +437,7 @@ class AsyncTPPass(VllmPatternMatcherPass):
|
||||
):
|
||||
return True
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
return compile_range.is_single_size() and compile_range.end % tp_size == 0
|
||||
return bool(compile_range.is_single_size() and compile_range.end % tp_size == 0)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
@@ -516,7 +512,7 @@ if flashinfer_comm is not None:
|
||||
# Get one shot input size limit for the current world size
|
||||
# for the current device capability
|
||||
max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES_MB.get(
|
||||
device_capability, # type: ignore[arg-type]
|
||||
device_capability, # type: ignore[arg-type, unused-ignore]
|
||||
{},
|
||||
).get(world_size, None)
|
||||
# Use one shot if no max size is specified
|
||||
@@ -666,6 +662,7 @@ class AllReduceRMSNormPattern(BasePattern):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
residual = torch.zeros_like(input)
|
||||
rms_result = torch.empty_like(input)
|
||||
assert flashinfer_comm is not None, "FlashInfer must be enabled"
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
@@ -722,6 +719,7 @@ class AllReduceFusedAddRMSNormPattern(BasePattern):
|
||||
def replacement(
|
||||
residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert flashinfer_comm is not None, "FlashInfer must be enabled"
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
@@ -800,6 +798,7 @@ class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
|
||||
residual = torch.zeros_like(input)
|
||||
result_rms = torch.empty_like(input)
|
||||
result_quant = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
assert flashinfer_comm is not None, "FlashInfer must be enabled"
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
@@ -875,6 +874,7 @@ class AllReduceFusedAddRMSNormStaticQuantFP8Pattern(BasePattern):
|
||||
scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
result_quant = torch.empty_like(input, dtype=self.quant_dtype)
|
||||
assert flashinfer_comm is not None, "FlashInfer must be enabled"
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
@@ -960,6 +960,7 @@ class AllReduceFusedRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
residual = torch.zeros_like(input)
|
||||
result_rms = torch.empty_like(input)
|
||||
assert flashinfer_comm is not None, "FlashInfer must be enabled"
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
@@ -1055,6 +1056,7 @@ class AllReduceFusedAddRMSNormStaticQuantNVFP4Pattern(BasePattern):
|
||||
weight: torch.Tensor,
|
||||
input_global_scale: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
assert flashinfer_comm is not None, "FlashInfer must be enabled"
|
||||
allreduce = auto_functionalized(
|
||||
flashinfer_trtllm_fused_allreduce_norm,
|
||||
allreduce_in=input,
|
||||
@@ -1131,7 +1133,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
|
||||
)
|
||||
|
||||
self.ipc_handles, workspace_tensor = (
|
||||
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion( # type: ignore[misc]
|
||||
flashinfer_comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
|
||||
tp_rank=rank,
|
||||
tp_size=self.tp_size,
|
||||
max_token_num=self.max_token_num,
|
||||
@@ -1204,7 +1206,7 @@ class AllReduceFusionPass(VllmPatternMatcherPass):
|
||||
if self.disabled:
|
||||
logger.warning_once("AllReduce fusion pass is disabled.")
|
||||
return False
|
||||
return compile_range.end <= self.max_token_num
|
||||
return bool(compile_range.end <= self.max_token_num)
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
|
||||
@@ -201,9 +201,9 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
factors = get_inductor_factors()
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[
|
||||
:10
|
||||
]
|
||||
hash_str: str = safe_hash(
|
||||
str(factors).encode(), usedforsecurity=False
|
||||
).hexdigest()[:10]
|
||||
return hash_str
|
||||
|
||||
def initialize_cache(
|
||||
@@ -319,9 +319,9 @@ class InductorAdaptor(CompilerInterface):
|
||||
|
||||
def compute_hash(self, vllm_config: VllmConfig) -> str:
|
||||
factors = get_inductor_factors()
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()[
|
||||
:10
|
||||
]
|
||||
hash_str: str = safe_hash(
|
||||
str(factors).encode(), usedforsecurity=False
|
||||
).hexdigest()[:10]
|
||||
return hash_str
|
||||
|
||||
def initialize_cache(
|
||||
|
||||
@@ -45,10 +45,10 @@ logger = init_logger(__name__)
|
||||
|
||||
IGNORE_COMPILE_KEY = "_ignore_compile_vllm"
|
||||
|
||||
_T = TypeVar("_T", bound=type[nn.Module])
|
||||
_T = TypeVar("_T", bound=nn.Module)
|
||||
|
||||
|
||||
def ignore_torch_compile(cls: _T) -> _T:
|
||||
def ignore_torch_compile(cls: type[_T]) -> type[_T]:
|
||||
"""
|
||||
A decorator to ignore support_torch_compile decorator
|
||||
on the class. This is useful when a parent class has
|
||||
@@ -68,7 +68,7 @@ def ignore_torch_compile(cls: _T) -> _T:
|
||||
return cls
|
||||
|
||||
|
||||
def _should_ignore_torch_compile(cls: _T) -> bool:
|
||||
def _should_ignore_torch_compile(cls: type[_T]) -> bool:
|
||||
"""
|
||||
Check if the class should be ignored for torch.compile.
|
||||
"""
|
||||
@@ -79,21 +79,21 @@ def _should_ignore_torch_compile(cls: _T) -> bool:
|
||||
def support_torch_compile(
|
||||
*,
|
||||
enable_if: Callable[[VllmConfig], bool] | None = None,
|
||||
) -> Callable[[_T], _T]: ...
|
||||
) -> Callable[[type[_T]], type[_T]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def support_torch_compile(
|
||||
*,
|
||||
dynamic_arg_dims: dict[str, int | list[int]] | None,
|
||||
) -> Callable[[_T], _T]: ...
|
||||
) -> Callable[[type[_T]], type[_T]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def support_torch_compile(
|
||||
*,
|
||||
mark_unbacked_dims: dict[str, int | list[int]] | None,
|
||||
) -> Callable[[_T], _T]: ...
|
||||
) -> Callable[[type[_T]], type[_T]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
@@ -101,21 +101,21 @@ def support_torch_compile(
|
||||
*,
|
||||
dynamic_arg_dims: dict[str, int | list[int]] | None,
|
||||
mark_unbacked_dims: dict[str, int | list[int]] | None,
|
||||
) -> Callable[[_T], _T]: ...
|
||||
) -> Callable[[type[_T]], type[_T]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def support_torch_compile(cls: _T) -> _T: ...
|
||||
def support_torch_compile(cls: type[_T]) -> type[_T]: ...
|
||||
|
||||
|
||||
def support_torch_compile(
|
||||
cls: _T | None = None,
|
||||
cls: type[_T] | None = None,
|
||||
*,
|
||||
dynamic_arg_dims: dict[str, int | list[int]] | None = None,
|
||||
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
|
||||
enable_if: Callable[[VllmConfig], bool] | None = None,
|
||||
shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
|
||||
) -> Callable[[_T], _T] | _T:
|
||||
) -> Callable[[type[_T]], type[_T]] | type[_T]:
|
||||
"""
|
||||
A decorator to add support for compiling the forward method of a class.
|
||||
|
||||
@@ -182,7 +182,7 @@ def support_torch_compile(
|
||||
errors.
|
||||
"""
|
||||
|
||||
def cls_decorator_helper(cls: _T) -> _T:
|
||||
def cls_decorator_helper(cls: type[_T]) -> type[_T]:
|
||||
# helper to pass `dynamic_arg_dims` to `_support_torch_compile`
|
||||
# to avoid too much indentation for `_support_torch_compile`
|
||||
if not hasattr(cls, "forward"):
|
||||
@@ -263,12 +263,12 @@ def _verify_source_unchanged(
|
||||
|
||||
|
||||
def _support_torch_compile(
|
||||
cls: _T,
|
||||
cls: type[_T],
|
||||
dynamic_arg_dims: dict[str, int | list[int]],
|
||||
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
|
||||
enable_if: Callable[[VllmConfig], bool] | None = None,
|
||||
shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
|
||||
) -> _T:
|
||||
) -> type[_T]:
|
||||
"""
|
||||
A decorator to add support for compiling the forward method of a class.
|
||||
"""
|
||||
@@ -325,12 +325,12 @@ def _support_torch_compile(
|
||||
self.compiled = False
|
||||
|
||||
# Handled by monkeypatching `TorchCompileWithNoGuardsWrapper` into base class
|
||||
TorchCompileWithNoGuardsWrapper.__init__(self) # type: ignore[arg-type]
|
||||
TorchCompileWithNoGuardsWrapper.__init__(self)
|
||||
|
||||
cls.__init__ = __init__
|
||||
|
||||
def _mark_dynamic_inputs(
|
||||
mod: _T, ds_type: DynamicShapesType, *args: Any, **kwargs: Any
|
||||
mod: type[_T], ds_type: DynamicShapesType, *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
def mark_dynamic(arg: torch.Tensor, dims: list[int]) -> None:
|
||||
if ds_type == DynamicShapesType.UNBACKED:
|
||||
@@ -382,7 +382,7 @@ def _support_torch_compile(
|
||||
else:
|
||||
torch._dynamo.decorators.mark_unbacked(arg, dims)
|
||||
|
||||
def __call__(self: _T, *args: Any, **kwargs: Any) -> Any:
|
||||
def __call__(self: type[_T], *args: Any, **kwargs: Any) -> Any:
|
||||
# torch.compiler.is_compiling() means we are inside the compilation
|
||||
# e.g. TPU has the compilation logic in model runner, so we don't
|
||||
# need to compile the model inside.
|
||||
@@ -564,7 +564,7 @@ def _support_torch_compile(
|
||||
return output
|
||||
|
||||
# triggers VllmSerializableFunction.serialize()
|
||||
def save_aot_compiled_function(self):
|
||||
def save_aot_compiled_function(self: type[_T]) -> None:
|
||||
if self.was_aot_compile_fn_loaded_from_disk:
|
||||
logger.debug("AOT compiled function was loaded from cache, skipping save")
|
||||
return
|
||||
|
||||
@@ -141,15 +141,18 @@ class MatcherRotaryEmbedding(MatcherCustomOp):
|
||||
key: torch.Tensor | None,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
return RotaryEmbedding.forward_static(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.rotary_dim,
|
||||
cos_sin_cache,
|
||||
self.is_neox,
|
||||
result: tuple[torch.Tensor, torch.Tensor | None] = (
|
||||
RotaryEmbedding.forward_static(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
self.head_size,
|
||||
self.rotary_dim,
|
||||
cos_sin_cache,
|
||||
self.is_neox,
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class MatcherRMSNorm(MatcherCustomOp):
|
||||
@@ -275,9 +278,10 @@ class MatcherFusedAddRMSNorm(MatcherCustomOp):
|
||||
weight: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return RMSNorm.forward_static(
|
||||
result: tuple[torch.Tensor, torch.Tensor] = RMSNorm.forward_static(
|
||||
input, self.epsilon, input.size(-1), self.model_dtype, weight, residual
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class MatcherQuantFP8(MatcherCustomOp):
|
||||
|
||||
@@ -25,7 +25,7 @@ logger = init_logger(__name__)
|
||||
class RangeEntry:
|
||||
compile_range: Range
|
||||
compiled: bool = False
|
||||
runnable: Callable = None # type: ignore
|
||||
runnable: Callable[..., Any] = None # type: ignore
|
||||
|
||||
|
||||
class PiecewiseBackend:
|
||||
@@ -38,7 +38,7 @@ class PiecewiseBackend:
|
||||
sym_shape_indices: list[int],
|
||||
vllm_backend: VllmBackend,
|
||||
returns_tuple: bool,
|
||||
compiled_runnables: dict[str, Callable] | None = None,
|
||||
compiled_runnables: dict[str, Callable[..., Any]] | None = None,
|
||||
):
|
||||
"""
|
||||
The backend for piecewise compilation.
|
||||
@@ -138,8 +138,10 @@ class PiecewiseBackend:
|
||||
|
||||
self.on_compilation_complete = _on_compilation_complete_callback.get()
|
||||
|
||||
def get_compiled_graph_wrapper(self, compiled_graph):
|
||||
def compiled_graph_wrapper(*args):
|
||||
def get_compiled_graph_wrapper(
|
||||
self, compiled_graph: Callable[..., Any]
|
||||
) -> Callable[..., Any]:
|
||||
def compiled_graph_wrapper(*args: Any) -> Any:
|
||||
graph_output = compiled_graph(*args)
|
||||
# unpack the tuple if needed
|
||||
# TODO(rzou): the implication is that we're not
|
||||
@@ -163,7 +165,7 @@ class PiecewiseBackend:
|
||||
|
||||
def to_bytes(self) -> dict[str, bytes]:
|
||||
class StandaloneCompiledArtifactsPickler(Pickler):
|
||||
def reducer_override(self, obj):
|
||||
def reducer_override(self, obj: object) -> Any:
|
||||
if isinstance(obj, CachingAutotuner):
|
||||
obj.prepare_for_pickle()
|
||||
return pickle.loads, (
|
||||
@@ -173,7 +175,7 @@ class PiecewiseBackend:
|
||||
)
|
||||
return NotImplemented
|
||||
|
||||
def serialize(fn) -> bytes:
|
||||
def serialize(fn: Callable[..., Any]) -> bytes:
|
||||
assert hasattr(fn, "serialize"), "fn must have serialize method"
|
||||
with torch._functorch.config.patch("bundled_autograd_cache", True):
|
||||
entry = fn.serialize()
|
||||
|
||||
@@ -358,7 +358,10 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
|
||||
):
|
||||
return True
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
return (compile_range.is_single_size()) and (compile_range.end % tp_size == 0)
|
||||
result: bool = (compile_range.is_single_size()) and (
|
||||
compile_range.end % tp_size == 0
|
||||
)
|
||||
return result
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
|
||||
Reference in New Issue
Block a user