[Bugfix] Add fake mode around passes (#23349)

Signed-off-by: angelayi <yiangela7@gmail.com>
This commit is contained in:
Angela Yi
2025-08-28 08:25:56 -07:00
committed by GitHub
parent 95089607fa
commit db74d60490
6 changed files with 64 additions and 39 deletions

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import hashlib
import inspect
import json
@@ -10,6 +11,8 @@ from typing import Any, Callable, Optional, Union
import torch
from torch import fx
from torch._subclasses.fake_tensor import (FakeTensorMode,
unset_fake_temporarily)
from vllm.utils import is_torch_equal_or_newer
@@ -114,3 +117,20 @@ class CallableInductorPass(InductorPass):
def uuid(self) -> Any:
return self._uuid
def enable_fake_mode(fn: Callable[..., Any]) -> Callable[..., Any]:
"""
Applies a FakeTensorMode context. This is useful when you don't want to
create or run things with real tensors.
"""
@functools.wraps(fn)
def fn_new(*args, **kwargs) -> Any:
with torch._guards.tracing(
None), unset_fake_temporarily(), FakeTensorMode():
result = fn(*args, **kwargs)
return result
return fn_new