set assume_32bit_indexing and pass unbacked hints (#30459)

Signed-off-by: Laith Sakka <lsakka@meta.com>
This commit is contained in:
Laith Sakka
2025-12-13 18:36:53 +03:00
committed by GitHub
parent 39cefbdf17
commit 763963aa73

View File

@@ -28,7 +28,7 @@ from vllm.config.compilation import DynamicShapesType
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import supports_dynamo from vllm.utils.torch_utils import is_torch_equal_or_newer, supports_dynamo
from .monitor import start_monitoring_torch_compile from .monitor import start_monitoring_torch_compile
@@ -316,6 +316,12 @@ def _support_torch_compile(
def _mark_dynamic_inputs(mod, type, *args, **kwargs): def _mark_dynamic_inputs(mod, type, *args, **kwargs):
def mark_dynamic(arg, dims): def mark_dynamic(arg, dims):
if type == DynamicShapesType.UNBACKED: if type == DynamicShapesType.UNBACKED:
if is_torch_equal_or_newer("2.10.0.dev"):
for dim in dims:
torch._dynamo.decorators.mark_unbacked(
arg, dim, hint_override=arg.size()[dim]
)
else:
torch._dynamo.decorators.mark_unbacked(arg, dims) torch._dynamo.decorators.mark_unbacked(arg, dims)
else: else:
torch._dynamo.mark_dynamic(arg, dims) torch._dynamo.mark_dynamic(arg, dims)
@@ -350,6 +356,12 @@ def _support_torch_compile(
if isinstance(arg, torch.Tensor): if isinstance(arg, torch.Tensor):
# In case dims is specified with negative indexing # In case dims is specified with negative indexing
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims] dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
if is_torch_equal_or_newer("2.10.0.dev"):
for dim in dims:
torch._dynamo.decorators.mark_unbacked(
arg, dim, hint_override=arg.size()[dim]
)
else:
torch._dynamo.decorators.mark_unbacked(arg, dims) torch._dynamo.decorators.mark_unbacked(arg, dims)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
@@ -488,6 +500,12 @@ def _support_torch_compile(
if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS: if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS:
fx_config_patches["backed_size_oblivious"] = True fx_config_patches["backed_size_oblivious"] = True
# Prepare inductor config patches
# assume_32bit_indexing is only available in torch 2.10.0.dev+
inductor_config_patches = {}
if is_torch_equal_or_newer("2.10.0.dev"):
inductor_config_patches["assume_32bit_indexing"] = True
with ( with (
patch.object( patch.object(
InliningInstructionTranslator, "inline_call_", patched_inline_call InliningInstructionTranslator, "inline_call_", patched_inline_call
@@ -496,6 +514,7 @@ def _support_torch_compile(
maybe_use_cudagraph_partition_wrapper(self.vllm_config), maybe_use_cudagraph_partition_wrapper(self.vllm_config),
torch.fx.experimental._config.patch(**fx_config_patches), torch.fx.experimental._config.patch(**fx_config_patches),
_torch27_patch_tensor_subclasses(), _torch27_patch_tensor_subclasses(),
torch._inductor.config.patch(**inductor_config_patches),
): ):
if envs.VLLM_USE_AOT_COMPILE: if envs.VLLM_USE_AOT_COMPILE:
self.aot_compiled_fn = self.aot_compile(*args, **kwargs) self.aot_compiled_fn = self.aot_compile(*args, **kwargs)