[torch.compile][BE][Multimodal] Remove requirement to set_model_tag to avoid cache conflict (#37345)

Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
This commit is contained in:
Lucas Kabela
2026-03-19 10:26:12 -07:00
committed by GitHub
parent 2f9f946b22
commit 7769b58307
7 changed files with 86 additions and 69 deletions

View File

@@ -29,10 +29,9 @@ To compile a multimodal component such as an encoder, we follow the same mechani
1. The `@support_torch_compile` decorator should include `enable_if=should_torch_compile_mm_encoder`. This will gate the compilation behind our
`compile_mm_encoder` configuration
2. `with set_model_tag("<component_name>", is_encoder=True)` context manager should be used around the nn.Module's instantiation. Since torch.compile
relies on caching artifacts to reduce start time, we must properly propagate the `<component_name>` information to the cache in order to avoid collisions
with the LLM text-backbone, or other instances of the same artifact (as is the case with vision block). `is_encoder=True` is also needed for encoder
components (see Compile Range Integration).
2. The `@support_torch_compile` decorator should include `is_encoder=True` for encoder components. This is needed for compile range integration
(see Compile Range Integration). The decorator automatically uses the class name as the cache directory prefix, avoiding collisions between
independently compiled sub-modules (e.g. vision encoder components vs the text backbone).
### CompilationConfig
@@ -57,8 +56,8 @@ tradeoff
### Compile ranges
The torch.compile integration will try to rely on max_batch_size to infer compilation ranges for dynamic shapes; however, for modules used in the encoder, this
shape can be difficult to infer due to the unspecified range of shapes the encoder may see as input. Therefore, we rely on `is_encoder=True` in the `set_model_tag`
to alert torch.compile to the fact that this range cannot be inferred, and we default to the range (1, MAX_INT).
shape can be difficult to infer due to the unspecified range of shapes the encoder may see as input. Therefore, we rely on `is_encoder=True` in the
`@support_torch_compile` decorator to alert torch.compile to the fact that this range cannot be inferred, and we default to the range (1, MAX_INT).
!!! note
We may seek to tighten this range for better performance in the future

View File

@@ -118,6 +118,7 @@ def support_torch_compile(
dynamic_arg_dims: dict[str, int | list[int]] | None = None,
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
enable_if: Callable[[VllmConfig], bool] | None = None,
is_encoder: bool = False,
shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> Callable[[type[_T]], type[_T]] | type[_T]:
"""
@@ -177,6 +178,11 @@ def support_torch_compile(
enforce that dynamo does not specialize on 0/1 values in the case of dummy input
such as for vision model compilation
`is_encoder` marks this module as a portion of a multimodal encoder.
When True, the compile range upper bound is set to MAX_INT32 instead of
max_num_batched_tokens, since encoder input shapes are unpredictable.
This is typically used for vision encoder sub-modules in multimodal models.
`shape_invariants` is a function that gets compiled right before forward.
The function should have the torch._check calls that are needed to set
the relationships between different input sizes. For example:
@@ -226,6 +232,7 @@ def support_torch_compile(
inferred_dynamic_arg_dims,
mark_unbacked_dims,
enable_if,
is_encoder,
shape_invariants,
)
@@ -316,6 +323,7 @@ def _support_torch_compile(
dynamic_arg_dims: dict[str, int | list[int]],
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
enable_if: Callable[[VllmConfig], bool] | None = None,
is_encoder: bool = False,
shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> type[_T]:
"""
@@ -345,8 +353,7 @@ def _support_torch_compile(
vllm_config = get_current_vllm_config()
# NOTE: to support multimodal models (such as encoder),
# we may not have vllm_config so we may need to patch
# it
# we may not have vllm_config so we may need to patch it
sig = inspect.signature(old_init)
if "vllm_config" in sig.parameters:
kwargs["vllm_config"] = vllm_config
@@ -374,7 +381,11 @@ def _support_torch_compile(
self.compiled = False
# Handled by monkeypatching `TorchCompileWithNoGuardsWrapper` into base class
TorchCompileWithNoGuardsWrapper.__init__(self)
TorchCompileWithNoGuardsWrapper.__init__(
self,
compile_prefix=cls.__name__ if is_encoder else "",
is_encoder=is_encoder,
)
cls.__init__ = __init__

View File

@@ -75,8 +75,14 @@ class TorchCompileWithNoGuardsWrapper:
return ctx.result
return callable_fn(*args, **kwargs)
def __init__(self) -> None:
def __init__(
self,
compile_prefix: str = "",
is_encoder: bool = False,
) -> None:
self.compiled = False
self._compile_prefix = compile_prefix
self._is_encoder = is_encoder
vllm_config = get_current_vllm_config()
self.vllm_config = vllm_config
@@ -87,7 +93,9 @@ class TorchCompileWithNoGuardsWrapper:
if mode is None:
raise RuntimeError("Compilation mode cannot be NO_COMPILATION")
backend = vllm_config.compilation_config.init_backend(vllm_config)
backend = vllm_config.compilation_config.init_backend(
vllm_config, prefix=compile_prefix, is_encoder=is_encoder
)
options = {}
if isinstance(backend, str) and backend == "inductor":
@@ -332,4 +340,8 @@ def reset_compile_wrapper(model: torch.nn.Module) -> None:
compilation_config.local_cache_dir = ""
model.__class__.forward.__code__ = model.original_code_object()
TorchCompileWithNoGuardsWrapper.__init__(model)
TorchCompileWithNoGuardsWrapper.__init__(
model,
compile_prefix=model._compile_prefix,
is_encoder=model._is_encoder,
)

View File

@@ -909,11 +909,19 @@ class CompilationConfig:
if self.backend == "":
self.backend = current_platform.get_compile_backend()
def init_backend(self, vllm_config: "VllmConfig") -> str | Callable:
def init_backend(
self,
vllm_config: "VllmConfig",
prefix: str = "",
is_encoder: bool = False,
) -> str | Callable:
"""
Initialize the backend for the compilation config from a vllm config.
Arguments:
vllm_config: The vllm config to initialize the backend from.
prefix: Cache directory prefix for this compiled module.
is_encoder: Whether this module is used in an encoder (as
opposed to a text backbone).
Returns:
The backend for the compilation config.
"""
@@ -943,9 +951,7 @@ class CompilationConfig:
from vllm.compilation.backends import VllmBackend
# TODO[@lucaskabela]: See if we can forward prefix
# https://github.com/vllm-project/vllm/issues/27045
return VllmBackend(vllm_config)
return VllmBackend(vllm_config, prefix=prefix, is_encoder=is_encoder)
def post_init_cudagraph_sizes(self) -> None:
"""To complete the initialization after cudagraph related

View File

@@ -272,6 +272,7 @@ class Siglip2MLP(nn.Module):
@support_torch_compile(
dynamic_arg_dims={"hidden_states": [0, 1], "cu_seqlens": 0},
enable_if=should_torch_compile_mm_encoder,
is_encoder=True,
)
class Siglip2EncoderLayer(nn.Module):
def __init__(
@@ -395,16 +396,12 @@ class Siglip2VisionTransformer(nn.Module):
embed_dim = config.hidden_size
self.config = config
self.embeddings = Siglip2VisionEmbeddings(config)
# Keep the import local to avoid circular dependencies during model init.
from vllm.compilation.backends import set_model_tag
with set_model_tag("Siglip2Encoder", is_encoder=True):
self.encoder = Siglip2Encoder(
config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
)
self.encoder = Siglip2Encoder(
config,
quant_config=quant_config,
num_hidden_layers_override=num_hidden_layers_override,
prefix=f"{prefix}.encoder",
)
num_hidden_layers = config.num_hidden_layers
if len(self.encoder.layers) > config.num_hidden_layers:
raise ValueError(

View File

@@ -453,7 +453,9 @@ class Llama4UnfoldConvolution(nn.Module):
@support_torch_compile(
dynamic_arg_dims={"images_flattened": 0}, enable_if=should_torch_compile_mm_encoder
dynamic_arg_dims={"images_flattened": 0},
enable_if=should_torch_compile_mm_encoder,
is_encoder=True,
)
class Llama4VisionModel(nn.Module):
def __init__(
@@ -754,12 +756,7 @@ class Llama4ForConditionalGeneration(
self.multimodal_config = multimodal_config
with self._mark_tower_model(vllm_config, "image"):
from vllm.compilation.backends import set_model_tag
with (
set_current_vllm_config(vllm_config),
set_model_tag("Llama4VisionModel", is_encoder=True),
):
with set_current_vllm_config(vllm_config):
self.vision_model = Llama4VisionModel(
config=config.vision_config,
quant_config=None,

View File

@@ -427,6 +427,7 @@ class Qwen2_5_VisionAttention(nn.Module):
"rotary_pos_emb_sin": 0,
},
enable_if=should_torch_compile_mm_encoder,
is_encoder=True,
)
class Qwen2_5_VisionBlock(nn.Module):
def __init__(
@@ -486,6 +487,7 @@ class Qwen2_5_VisionBlock(nn.Module):
"x": 0,
},
enable_if=should_torch_compile_mm_encoder,
is_encoder=True,
)
class Qwen2_5_VisionPatchEmbed(nn.Module):
def __init__(
@@ -521,6 +523,7 @@ class Qwen2_5_VisionPatchEmbed(nn.Module):
"x": 0,
},
enable_if=should_torch_compile_mm_encoder,
is_encoder=True,
)
class Qwen2_5_VisionPatchMerger(nn.Module):
def __init__(
@@ -592,18 +595,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
self.spatial_merge_size = vision_config.spatial_merge_size
self.fullatt_block_indexes = vision_config.fullatt_block_indexes
self.spatial_merge_unit = self.spatial_merge_size**2
# TODO[@lucaskabela]: Investigate fixing this usage
# see https://github.com/vllm-project/vllm/issues/27044
# DO NOT MOVE THIS IMPORT
from vllm.compilation.backends import set_model_tag
with set_model_tag("Qwen2_5_VisionPatchEmbed", is_encoder=True):
self.patch_embed = Qwen2_5_VisionPatchEmbed(
patch_size=patch_size,
temporal_patch_size=temporal_patch_size,
in_channels=in_channels,
hidden_size=self.hidden_size,
)
self.patch_embed = Qwen2_5_VisionPatchEmbed(
patch_size=patch_size,
temporal_patch_size=temporal_patch_size,
in_channels=in_channels,
hidden_size=self.hidden_size,
)
norm_layer = partial(RMSNorm, eps=norm_eps)
head_dim = self.hidden_size // self.num_heads
@@ -619,31 +616,29 @@ class Qwen2_5_VisionTransformer(nn.Module):
dtype=torch.get_default_dtype(),
)
with set_model_tag("Qwen2_5_VisionBlock", is_encoder=True):
self.blocks = nn.ModuleList(
[
Qwen2_5_VisionBlock(
dim=self.hidden_size,
num_heads=self.num_heads,
mlp_hidden_dim=vision_config.intermediate_size,
act_fn=get_act_and_mul_fn(vision_config.hidden_act),
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
)
for layer_idx in range(depth)
]
)
self.blocks = nn.ModuleList(
[
Qwen2_5_VisionBlock(
dim=self.hidden_size,
num_heads=self.num_heads,
mlp_hidden_dim=vision_config.intermediate_size,
act_fn=get_act_and_mul_fn(vision_config.hidden_act),
norm_layer=norm_layer,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
)
for layer_idx in range(depth)
]
)
with set_model_tag("Qwen2_5_VisionPatchMerger", is_encoder=True):
self.merger = Qwen2_5_VisionPatchMerger(
d_model=vision_config.out_hidden_size,
context_dim=self.hidden_size,
norm_layer=norm_layer,
spatial_merge_size=self.spatial_merge_size,
quant_config=quant_config,
prefix=f"{prefix}.merger",
)
self.merger = Qwen2_5_VisionPatchMerger(
d_model=vision_config.out_hidden_size,
context_dim=self.hidden_size,
norm_layer=norm_layer,
spatial_merge_size=self.spatial_merge_size,
quant_config=quant_config,
prefix=f"{prefix}.merger",
)
@property
def dtype(self) -> torch.dtype: