[torch.compile] fix sym_tensor_indices (#12191)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2025-01-20 11:37:50 +08:00
committed by GitHub
parent df450aa567
commit 51ef828f10

View File

@@ -624,9 +624,13 @@ class VllmBackend:
]
# index of tensors that have symbolic shapes (batch size)
# for weights and static buffers, they will have concrete shapes.
# symbolic shape only happens for input tensors.
from torch.fx.experimental.symbolic_shapes import is_symbolic
self.sym_tensor_indices = [
i for i, x in enumerate(fake_args)
if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
if isinstance(x, torch._subclasses.fake_tensor.FakeTensor) and \
any(is_symbolic(d) for d in x.size())
]
# compiler managed cudagraph input buffers