[torch.compile][BE][Multimodal] Remove requirement to set_model_tag to avoid cache conflict (#37345)

Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
This commit is contained in:
Lucas Kabela
2026-03-19 10:26:12 -07:00
committed by GitHub
parent 2f9f946b22
commit 7769b58307
7 changed files with 86 additions and 69 deletions

View File

@@ -29,10 +29,9 @@ To compile a multimodal component such as an encoder, we follow the same mechani
1. The `@support_torch_compile` decorator should include `enable_if=should_torch_compile_mm_encoder`. This will gate the compilation behind our
`compile_mm_encoder` configuration
2. `with set_model_tag("<component_name>", is_encoder=True)` context manager should be used around the nn.Module's instantiation. Since torch.compile
relies on caching artifacts to reduce start time, we must properly propagate the `<component_name>` information to the cache in order to avoid collisions
with the LLM text-backbone, or other instances of the same artifact (as is the case with vision block). `is_encoder=True` is also needed for encoder
components (see Compile Range Integration).
2. The `@support_torch_compile` decorator should include `is_encoder=True` for encoder components. This is needed for compile range integration
(see Compile Range Integration). The decorator automatically uses the class name as the cache directory prefix, avoiding collisions between
independently compiled sub-modules (e.g. vision encoder components vs the text backbone).
### CompilationConfig
@@ -57,8 +56,8 @@ tradeoff
### Compile ranges
The torch.compile integration will try to rely on max_batch_size to infer compilation ranges for dynamic shapes; however, for modules used in the encoder, this
shape can be difficult to infer due to the unspecified range of shapes the encoder may see as input. Therefore, we rely on `is_encoder=True` in the `set_model_tag`
to alert torch.compile to the fact that this range cannot be inferred, and we default to the range (1, MAX_INT).
shape can be difficult to infer due to the unspecified range of shapes the encoder may see as input. Therefore, we rely on `is_encoder=True` in the
`@support_torch_compile` decorator to alert torch.compile to the fact that this range cannot be inferred, and we default to the range (1, MAX_INT).
!!! note
We may seek to tighten this range for better performance in the future