[torch.compile][BE][Multimodal] Remove requirement to set_model_tag to avoid cache conflict (#37345)
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
@@ -29,10 +29,9 @@ To compile a multimodal component such as an encoder, we follow the same mechani
 1. The `@support_torch_compile` decorator should include `enable_if=should_torch_compile_mm_encoder`. This will gate the compilation behind our
 `compile_mm_encoder` configuration
 
-2. `with set_model_tag("<component_name>", is_encoder=True)` context manager should be used around the nn.Module's instantiation. Since torch.compile
-relies on caching artifacts to reduce start time, we must properly propagate the `<component_name>` information to the cache in order to avoid collisions
-with the LLM text-backbone, or other instances of the same artifact (as is the case with vision block). `is_encoder=True` is also needed for encoder
-components (see Compile Range Integration).
+2. The `@support_torch_compile` decorator should include `is_encoder=True` for encoder components. This is needed for compile range integration
+(see Compile Range Integration). The decorator automatically uses the class name as the cache directory prefix, avoiding collisions between
+independently compiled sub-modules (e.g. vision encoder components vs the text backbone).
 
 ### CompilationConfig
 
@@ -57,8 +56,8 @@ tradeoff
 ### Compile ranges
 
 The torch.compile integration will try to rely on max_batch_size to infer compilation ranges for dynamic shapes; however, for modules used in the encoder, this
-shape can be difficult to infer due to the unspecified range of shapes the encoder may see as input. Therefore, we rely on `is_encoder=True` in the `set_model_tag`
-to alert torch.compile to the fact that this range cannot be inferred, and we default to the range (1, MAX_INT).
+shape can be difficult to infer due to the unspecified range of shapes the encoder may see as input. Therefore, we rely on `is_encoder=True` in the
+`@support_torch_compile` decorator to alert torch.compile to the fact that this range cannot be inferred, and we default to the range (1, MAX_INT).
 
 !!! note
 We may seek to tighten this range for better performance in the future
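For readers of this change, the behaviour described by the updated docs can be modelled with a small, self-contained sketch. It does not import vLLM and is not the project's implementation: `support_torch_compile_sketch`, `compile_meta`, `should_compile`, and `compile_range` are hypothetical names, the config is a plain dict, the body of `should_torch_compile_mm_encoder` is a guess keyed on `compile_mm_encoder`, and the non-encoder range `(1, max_batch_size)` is a simplifying assumption. The sketch only illustrates three ideas from the diff: gating via `enable_if`, the class name acting as the cache prefix, and `is_encoder=True` selecting the default range `(1, MAX_INT)`.

```python
import sys
from typing import Callable, Optional, Tuple

# Stand-in for the "MAX_INT" upper bound mentioned in the docs (assumption).
MAX_INT = sys.maxsize


def support_torch_compile_sketch(
    *,
    enable_if: Optional[Callable[[dict], bool]] = None,
    is_encoder: bool = False,
):
    """Toy stand-in for a compile decorator. NOT vLLM's implementation."""

    def decorate(cls):
        cls.compile_meta = {
            # The class name doubles as the cache prefix, so independently
            # compiled sub-modules do not share cache entries.
            "cache_prefix": cls.__name__,
            "is_encoder": is_encoder,
            # With no gate supplied, compilation is always enabled.
            "enable_if": enable_if or (lambda config: True),
        }
        return cls

    return decorate


def should_compile(cls, config: dict) -> bool:
    """Evaluate the class's gate against a (hypothetical) dict config."""
    return cls.compile_meta["enable_if"](config)


def compile_range(cls, max_batch_size: int) -> Tuple[int, int]:
    """Encoders cannot infer a range, so they fall back to (1, MAX_INT)."""
    if cls.compile_meta["is_encoder"]:
        return (1, MAX_INT)
    # Assumption: non-encoder modules derive the bound from max_batch_size.
    return (1, max_batch_size)


def should_torch_compile_mm_encoder(config: dict) -> bool:
    # Hypothetical gate: true only when compile_mm_encoder is switched on.
    return bool(config.get("compile_mm_encoder", False))


@support_torch_compile_sketch(
    enable_if=should_torch_compile_mm_encoder, is_encoder=True
)
class VisionBlock:
    pass


@support_torch_compile_sketch()
class TextBackbone:
    pass
```

In this model no `set_model_tag`-style context manager is needed at instantiation time, because both the cache prefix and the encoder flag are fixed when the class is decorated.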