Add option to use unbacked, and backed size obl dynamic shapes for more sounds compilation. (#26199)

Signed-off-by: Laith Sakka <lsakka@meta.com>
This commit is contained in:
Laith Sakka
2025-11-24 07:12:41 -08:00
committed by GitHub
parent f716a15372
commit 7a228b5305
8 changed files with 442 additions and 15 deletions

View File

@@ -151,6 +151,76 @@ To avoid this, please either:
2. wrap the branching logic into a custom operator. TorchDynamo does not
trace into custom operators.
## Debugging constraint violations and dynamic shapes guards issues
Dynamic-shape guards are a specific category of Dynamo guards. They are constraints that `torch.compile`
attaches to dynamic dimensions (e.g., `seq_len`) to ensure the compiled artifact remains valid.
These guards typically appear when framework code, custom passes, or user code branches based on
dynamic shape values.
**Example:**
```python
if x > 10:
# path A
else:
# path B
```
This creates a guard `x > 10` or `x <= 10` depending on which path was traced.
**vLLM's Assumption:**
vLLM assumes that all guards added by torch.compile are safe to drop and will not
constrain the compiled graph to specific input shapes. When this assumption is violated,
it can cause issues that users need to debug.
Some side effects that indicates this assumption is violated are runtime errors
or `ConstraintViolationErrors`.
A `ConstraintViolationErrors` will be thrown if a dynamic shape gets constrained to
a single value. If you encounter a constraint violation error or suspect that a dynamic
shapes guard is being added incorrectly, you can use stricter dynamic shape modes to
help debug the issue:
```sh
# Online - using unbacked mode
vllm serve meta-llama/Llama-3.2-1B -O.dynamic_shapes_config.type=unbacked
# Online - using backed_size_oblivious mode
vllm serve meta-llama/Llama-3.2-1B -O.dynamic_shapes_config.type=backed_size_oblivious
```
```py
# Offline - using unbacked mode
from vllm.config.compilation import CompilationConfig, DynamicShapesConfig, DynamicShapesType
LLM(model, compilation_config=CompilationConfig(
dynamic_shapes_config=DynamicShapesConfig(type=DynamicShapesType.UNBACKED)
))
# Offline - using backed_size_oblivious mode
from vllm.config.compilation import CompilationConfig, DynamicShapesConfig, DynamicShapesType
LLM(model, compilation_config=CompilationConfig(
dynamic_shapes_config=DynamicShapesConfig(type=DynamicShapesType.BACKED_SIZE_OBLIVIOUS)
))
```
These modes are stricter and reduce or eliminate the need of dynamic shapes guarding, which can help isolate issues:
- `unbacked`: Uses unbacked symints which don't allow guards, making it easier to identify where guards are being incorrectly added
- `backed_size_oblivious`: Uses a mode that is more strict about guarding.
For more details on dynamic shapes modes, see [Dynamic shapes and vLLM guard dropping](torch_compile.md#dynamic-shapes-and-vllm-guard-dropping).
### Printing guards
To see all guards that are being added during compilation, you can use `TORCH_LOGS=+dynamic`:
```sh
TORCH_LOGS=+dynamic vllm serve meta-llama/Llama-3.2-1B
```
Look for `[guard added]` in the logs to see where guards are being added. This can help you identify which operations are
causing guards to be added incorrectly.
## Debugging TorchInductor
TorchInductor takes a captured graph and then compiles it down to some Python code