diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py
index 943220244..8e2e0c4ab 100644
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -522,7 +522,9 @@ def _support_torch_compile(
         # assume_32bit_indexing is only available in torch 2.10.0.dev+
         inductor_config_patches = {}
         if is_torch_equal_or_newer("2.10.0.dev"):
-            inductor_config_patches["assume_32bit_indexing"] = True
+            inductor_config_patches["assume_32bit_indexing"] = (
+                self.compilation_config.dynamic_shapes_config.assume_32_bit_indexing
+            )

         with (
             patch.object(
diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py
index ed50e0d49..7bffa53bd 100644
--- a/vllm/config/compilation.py
+++ b/vllm/config/compilation.py
@@ -278,7 +278,11 @@ class DynamicShapesConfig:
     artifacts also. When type is backed, aot_compile must be disabled for
     this mode to work. until this change picked up
     https://github.com/pytorch/pytorch/pull/169239.
+    """
+
+    assume_32_bit_indexing: bool = True
+    """
+    Whether to assume that all tensor sizes fit within 32-bit indexing.
     """

     def compute_hash(self) -> str:
@@ -640,6 +644,7 @@ class CompilationConfig:
             "compilation_time",
             "static_forward_context",
             "pass_config",  # handled separately below
+            "dynamic_shapes_config",  # handled separately below
         }

         from vllm.config.utils import get_hash_factors, hash_factors
@@ -647,6 +652,7 @@ class CompilationConfig:
         factors = get_hash_factors(self, ignored_factors)

         factors["pass_config"] = self.pass_config.compute_hash()
+        factors["dynamic_shapes_config"] = self.dynamic_shapes_config.compute_hash()
         return hash_factors(factors)

     def __repr__(self) -> str:
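
---

For illustration only, not part of the patch: a minimal sketch of how the new
knob could be set from user code. It assumes CompilationConfig and
DynamicShapesConfig are importable from vllm.config.compilation and that
CompilationConfig accepts dynamic_shapes_config as a constructor argument, as
the diff suggests.

    from vllm.config.compilation import CompilationConfig, DynamicShapesConfig

    # Opt out of the 32-bit indexing assumption for workloads whose tensors
    # may exceed 2**31 - 1 elements; the new field defaults to True.
    config = CompilationConfig(
        dynamic_shapes_config=DynamicShapesConfig(assume_32_bit_indexing=False),
    )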
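Because compute_hash() now folds dynamic_shapes_config into the hash factors,
flipping the flag should produce a different compilation cache key, so
artifacts compiled under one indexing assumption are not reused under the
other. A hypothetical sanity check under the same assumptions as above:

    from vllm.config.compilation import CompilationConfig, DynamicShapesConfig

    a = CompilationConfig(
        dynamic_shapes_config=DynamicShapesConfig(assume_32_bit_indexing=True),
    )
    b = CompilationConfig(
        dynamic_shapes_config=DynamicShapesConfig(assume_32_bit_indexing=False),
    )
    # Different flag values must yield different hashes; otherwise a cached
    # artifact compiled with 32-bit indexing could be loaded by a run that
    # disabled it.
    assert a.compute_hash() != b.compute_hash()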