set assume_32bit_indexing and pass unbacked hints (#30459)
Signed-off-by: Laith Sakka <lsakka@meta.com>
@@ -28,7 +28,7 @@ from vllm.config.compilation import DynamicShapesType
 from vllm.logger import init_logger
 from vllm.sequence import IntermediateTensors
 from vllm.utils.import_utils import resolve_obj_by_qualname
-from vllm.utils.torch_utils import supports_dynamo
+from vllm.utils.torch_utils import is_torch_equal_or_newer, supports_dynamo
 
 from .monitor import start_monitoring_torch_compile
 
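For context, `is_torch_equal_or_newer` is vLLM's helper for gating code paths on the installed PyTorch build; the "2.10.0.dev" string matches 2.10 nightlies as well as final releases. A minimal sketch of the pattern this commit uses (the two branch functions are hypothetical placeholders):

    from vllm.utils.torch_utils import is_torch_equal_or_newer

    def mark_dims(arg, dims):
        # Take the new code path only on torch >= 2.10.0.dev builds.
        if is_torch_equal_or_newer("2.10.0.dev"):
            new_api_path(arg, dims)   # hypothetical: 2.10+ behavior
        else:
            legacy_path(arg, dims)    # hypothetical: pre-2.10 fallback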
@@ -316,6 +316,12 @@ def _support_torch_compile(
     def _mark_dynamic_inputs(mod, type, *args, **kwargs):
         def mark_dynamic(arg, dims):
             if type == DynamicShapesType.UNBACKED:
-                torch._dynamo.decorators.mark_unbacked(arg, dims)
+                if is_torch_equal_or_newer("2.10.0.dev"):
+                    for dim in dims:
+                        torch._dynamo.decorators.mark_unbacked(
+                            arg, dim, hint_override=arg.size()[dim]
+                        )
+                else:
+                    torch._dynamo.decorators.mark_unbacked(arg, dims)
             else:
                 torch._dynamo.mark_dynamic(arg, dims)
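The intent of the new branch: `mark_unbacked` tells Dynamo to treat a dimension's size as an unbacked symbol (no guard on its concrete value), and the `hint_override` argument passes the tensor's current size as a compile-time hint so inductor can still make reasonable kernel and indexing choices. A standalone sketch, assuming a torch 2.10 nightly where `mark_unbacked` accepts `hint_override`:

    import torch

    x = torch.randn(8, 16)

    # Dim 0 becomes an unbacked symbol: the compiled graph must not guard
    # on its concrete value, but the hint (8 here) still guides codegen.
    torch._dynamo.decorators.mark_unbacked(x, 0, hint_override=x.size()[0])

    @torch.compile(fullgraph=True)
    def double(t):
        return t * 2

    double(x)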
@@ -350,6 +356,12 @@ def _support_torch_compile(
             if isinstance(arg, torch.Tensor):
                 # In case dims is specified with negative indexing
                 dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
-                torch._dynamo.decorators.mark_unbacked(arg, dims)
+                if is_torch_equal_or_newer("2.10.0.dev"):
+                    for dim in dims:
+                        torch._dynamo.decorators.mark_unbacked(
+                            arg, dim, hint_override=arg.size()[dim]
+                        )
+                else:
+                    torch._dynamo.decorators.mark_unbacked(arg, dims)
 
     def __call__(self, *args, **kwargs):
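One detail worth noting on the context line above: dims may be given with Python-style negative indexing, so they are normalized to nonnegative indices before the per-dim loop uses them. A quick illustration of that same comprehension:

    import torch

    arg = torch.randn(4, 8, 32)  # arg.ndim == 3
    dims = [0, -1]

    # Map negative dims to their positive equivalents: -1 -> 3 + (-1) = 2.
    dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
    assert dims == [0, 2]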
@@ -488,6 +500,12 @@ def _support_torch_compile(
         if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS:
             fx_config_patches["backed_size_oblivious"] = True
 
+        # Prepare inductor config patches
+        # assume_32bit_indexing is only available in torch 2.10.0.dev+
+        inductor_config_patches = {}
+        if is_torch_equal_or_newer("2.10.0.dev"):
+            inductor_config_patches["assume_32bit_indexing"] = True
+
         with (
             patch.object(
                 InliningInstructionTranslator, "inline_call_", patched_inline_call
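`torch._inductor.config.patch` returns a context manager (it can also be used as a decorator), so the setting is scoped to compilations that happen inside the `with` block rather than mutated globally. A minimal sketch, assuming a torch 2.10 nightly where the `assume_32bit_indexing` knob exists:

    import torch

    @torch.compile
    def add_one(x):
        return x + 1

    # Inside the patch, inductor may emit 32-bit index arithmetic on the
    # assumption that tensors stay below 2**31 elements; on exit, the
    # config reverts to its previous value.
    with torch._inductor.config.patch(assume_32bit_indexing=True):
        add_one(torch.randn(4))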
@@ -496,6 +514,7 @@ def _support_torch_compile(
             maybe_use_cudagraph_partition_wrapper(self.vllm_config),
             torch.fx.experimental._config.patch(**fx_config_patches),
             _torch27_patch_tensor_subclasses(),
+            torch._inductor.config.patch(**inductor_config_patches),
         ):
             if envs.VLLM_USE_AOT_COMPILE:
                 self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
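Note that on torch < 2.10 the dict stays empty, so the added line reduces to `torch._inductor.config.patch()`, which installs no overrides; this is what makes it safe to add the context manager unconditionally to the `with` stack. A sketch of that no-op behavior:

    import torch

    inductor_config_patches = {}  # stays empty on torch < 2.10.0.dev

    # patch(**{}) changes nothing, so this is a harmless no-op that keeps
    # the `with (...)` block identical across torch versions.
    with torch._inductor.config.patch(**inductor_config_patches):
        pass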