[MM][Core] Decouple ViT backend from LM backend (#27061)

Signed-off-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Roger Wang
2025-10-21 00:30:10 -07:00
committed by GitHub
parent 72f431e709
commit c3a2c6ac5f
16 changed files with 230 additions and 17 deletions

View File

@@ -320,6 +320,7 @@ class Qwen2VisionAttention(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
) -> None:
super().__init__()
# Per attention head and per partition values.
@@ -355,6 +356,7 @@ class Qwen2VisionAttention(nn.Module):
self.attn_backend = get_vit_attn_backend(
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
self.use_upstream_fa = False
@@ -497,6 +499,7 @@ class Qwen2VisionBlock(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
) -> None:
super().__init__()
if norm_layer is None:
@@ -512,6 +515,7 @@ class Qwen2VisionBlock(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
self.mlp = Qwen2VisionMLP(
dim,
@@ -662,6 +666,7 @@ class Qwen2VisionTransformer(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend_override: _Backend | None = None,
) -> None:
super().__init__()
@@ -703,6 +708,7 @@ class Qwen2VisionTransformer(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.blocks.{layer_idx}",
use_data_parallel=use_data_parallel,
attn_backend_override=attn_backend_override,
)
for layer_idx in range(depth)
]
@@ -716,7 +722,9 @@ class Qwen2VisionTransformer(nn.Module):
use_data_parallel=use_data_parallel,
)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim, dtype=torch.get_default_dtype()
head_size=head_dim,
dtype=torch.get_default_dtype(),
attn_backend_override=attn_backend_override,
)
if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
torch.get_default_dtype()
@@ -1356,12 +1364,18 @@ class Qwen2VLForConditionalGeneration(
if multimodal_config.get_limit_per_prompt(
"image"
) or multimodal_config.get_limit_per_prompt("video"):
attn_backend_override = (
multimodal_config.mm_encoder_attn_backend
if multimodal_config is not None
else None
)
self.visual = Qwen2VisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "visual"),
use_data_parallel=self.use_data_parallel,
attn_backend_override=attn_backend_override,
)
else:
self.visual = None