[MM][Core] Decouple ViT backend from LM backend (#27061)
Signed-off-by: Roger Wang <hey@rogerw.io>
@@ -208,6 +208,7 @@ class Siglip2Attention(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.config = config
@@ -248,7 +249,9 @@ class Siglip2Attention(nn.Module):
 
         # Detect attention implementation.
         self.attn_backend = get_vit_attn_backend(
-            head_size=self.head_dim, dtype=torch.get_default_dtype()
+            head_size=self.head_dim,
+            dtype=torch.get_default_dtype(),
+            attn_backend_override=attn_backend_override,
         )
         self.use_upstream_fa = False
 
@@ -372,6 +375,7 @@ class Siglip2EncoderLayer(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.embed_dim = config.hidden_size
@@ -381,6 +385,7 @@ class Siglip2EncoderLayer(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
             use_data_parallel=use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
         self.mlp = Siglip2MLP(
@@ -434,6 +439,7 @@ class Siglip2Encoder(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.config = config
@@ -444,6 +450,7 @@ class Siglip2Encoder(nn.Module):
                     quant_config=quant_config,
                     prefix=f"{prefix}.layers.{idx}",
                     use_data_parallel=use_data_parallel,
+                    attn_backend_override=attn_backend_override,
                 )
                 for idx in range(config.num_hidden_layers)
             ]
@@ -618,6 +625,7 @@ class Siglip2VisionTransformer(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
         self.config = config
@@ -629,6 +637,7 @@ class Siglip2VisionTransformer(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.encoder",
             use_data_parallel=use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
 
@@ -657,6 +666,7 @@ class Siglip2NavitModel(torch.nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
+        attn_backend_override: _Backend | None = None,
     ):
         super().__init__()
 
@@ -665,6 +675,7 @@ class Siglip2NavitModel(torch.nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.vision_model",
             use_data_parallel=use_data_parallel,
+            attn_backend_override=attn_backend_override,
         )
 
     def forward(
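The behavioral change is confined to the backend-selection call site: get_vit_attn_backend() now accepts the override, and every Siglip2 module above it merely forwards the kwarg downward. A minimal sketch of the contract this diff relies on; the real helper body lives elsewhere in vLLM, the import path is assumed, and _detect_vit_backend_for_platform is a hypothetical stand-in for the existing autodetection logic:

    import torch
    from vllm.attention.backends.registry import _Backend  # import path assumed

    def get_vit_attn_backend(
        head_size: int,
        dtype: torch.dtype,
        attn_backend_override: _Backend | None = None,
    ) -> _Backend:
        # A non-None override pins the ViT attention backend directly,
        # decoupling it from whatever backend the LM side would choose.
        if attn_backend_override is not None:
            return attn_backend_override
        # Otherwise fall back to platform autodetection, as before
        # (hypothetical helper name).
        return _detect_vit_backend_for_platform(head_size, dtype)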
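With the parameter threaded through every constructor, a caller can pin the vision tower's attention backend independently of the language model. A hedged usage sketch against the signatures shown in this diff; the import path and the enum member name are assumptions:

    from vllm.attention.backends.registry import _Backend  # import path assumed

    # Default: attn_backend_override stays None at every level, so
    # get_vit_attn_backend() autodetects exactly as before this change.
    vision_model = Siglip2NavitModel(config)

    # Explicit: the override is forwarded unchanged through
    # Siglip2VisionTransformer -> Siglip2Encoder -> Siglip2EncoderLayer
    # -> Siglip2Attention, where it reaches get_vit_attn_backend().
    vision_model = Siglip2NavitModel(
        config,
        attn_backend_override=_Backend.TORCH_SDPA,  # member name assumed
    )

Since the kwarg defaults to None at every level, existing call sites keep their behavior without modification.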