[MM][Core] Decouple ViT backend from LM backend (#27061)
Signed-off-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
@@ -300,6 +300,7 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
use_data_parallel: bool = False,
|
||||
attn_backend_override: _Backend | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = vision_config.hidden_size
|
||||
@@ -359,7 +360,9 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
)
|
||||
|
||||
self.attn_backend = get_vit_attn_backend(
|
||||
head_size=head_dim, dtype=torch.get_default_dtype()
|
||||
head_size=head_dim,
|
||||
dtype=torch.get_default_dtype(),
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
use_upstream_fa = False
|
||||
if (
|
||||
@@ -379,7 +382,6 @@ class Qwen3_VisionTransformer(nn.Module):
|
||||
raise RuntimeError(
|
||||
f"Qwen3-VL does not support {self.attn_backend} backend now."
|
||||
)
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Qwen3_VisionBlock(
|
||||
@@ -1214,12 +1216,18 @@ class Qwen3VLForConditionalGeneration(
|
||||
) and not multimodal_config.get_limit_per_prompt("video"):
|
||||
self.visual = None
|
||||
else:
|
||||
attn_backend_override = (
|
||||
multimodal_config.mm_encoder_attn_backend
|
||||
if multimodal_config is not None
|
||||
else None
|
||||
)
|
||||
self.visual = Qwen3_VisionTransformer(
|
||||
config.vision_config,
|
||||
norm_eps=getattr(config, "rms_norm_eps", 1e-6),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "visual"),
|
||||
use_data_parallel=self.use_data_parallel,
|
||||
attn_backend_override=attn_backend_override,
|
||||
)
|
||||
|
||||
self.language_model = Qwen3LLMForCausalLM(
|
||||
|
||||
Reference in New Issue
Block a user