[MM][Core] Decouple ViT backend from LM backend (#27061)
Signed-off-by: Roger Wang <hey@rogerw.io>
@@ -247,6 +247,7 @@ class Glm4vVisionAttention(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
@@ -287,6 +288,7 @@ class Glm4vVisionAttention(nn.Module):
         self.attn_backend = get_vit_attn_backend(
             head_size=self.hidden_size_per_attention_head,
             dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         self.use_upstream_fa = False
 
@@ -417,6 +419,7 @@ class Glm4vVisionBlock(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -430,6 +433,7 @@
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
             use_data_parallel=use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         self.mlp = Glm4vVisionMLP(
             dim,
@@ -696,6 +700,7 @@ class Glm4vVisionTransformer(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ) -> None:
         super().__init__()
 
@@ -731,6 +736,7 @@
                     quant_config=quant_config,
                     prefix=f"{prefix}.blocks.{layer_idx}",
                     use_data_parallel=self.use_data_parallel,
+                    attn_backend_override=attn_backend_override,
                 )
                 for layer_idx in range(depth)
             ]
@@ -759,7 +765,9 @@
         )
 
         self.attn_backend = get_vit_attn_backend(
-            head_size=head_dim, dtype=torch.get_default_dtype()
+            head_size=head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
             torch.get_default_dtype()
@@ -1437,12 +1445,18 @@ class Glm4vForConditionalGeneration(
         self.multimodal_config = multimodal_config
         self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
 
+        attn_backend_override = (
+            multimodal_config.mm_encoder_attn_backend
+            if multimodal_config is not None
+            else None
+        )
         self.visual = Glm4vVisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-5),
             quant_config=quant_config,
             prefix=maybe_prefix(prefix, "visual"),
             use_data_parallel=self.use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
 
         if config.model_type == "glm4v":
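Taken together, the hunks thread one new keyword argument, attn_backend_override, from Glm4vForConditionalGeneration (where it is read off multimodal_config.mm_encoder_attn_backend) down through Glm4vVisionTransformer and Glm4vVisionBlock into Glm4vVisionAttention's call to get_vit_attn_backend, so the vision encoder's attention backend can be selected independently of the language model's. Below is a minimal standalone sketch of that override-wins dispatch pattern; the Backend enum, its members, the resolve_vit_attn_backend helper, and the default fallback are illustrative stand-ins, not vLLM's actual definitions.

from enum import Enum, auto


class Backend(Enum):
    """Stand-in for vLLM's _Backend enum; member names are illustrative."""

    FLASH_ATTN = auto()
    TORCH_SDPA = auto()


def resolve_vit_attn_backend(
    head_size: int, attn_backend_override: Backend | None = None
) -> Backend:
    # An explicit override (plumbed down from the multimodal config) takes
    # priority; otherwise fall back to a default, as the real
    # get_vit_attn_backend would based on platform, head size, and dtype.
    if attn_backend_override is not None:
        return attn_backend_override
    return Backend.TORCH_SDPA  # placeholder default


# The ViT encoder can now run FLASH_ATTN even when the LM decoder's
# backend selection would have picked something else:
assert resolve_vit_attn_backend(64, Backend.FLASH_ATTN) is Backend.FLASH_ATTN
assert resolve_vit_attn_backend(64) is Backend.TORCH_SDPA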