[VLM] Support caching in merged multi-modal processor (#11396)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2024-12-28 01:22:48 +08:00
Committed by: GitHub
Parent: 5ce4627a7e
Commit: 101418096f
20 changed files with 1459 additions and 452 deletions


@@ -225,7 +225,7 @@ class VisualAttentionBlock(nn.Module):
         d_model: int,
         n_head: int,
         mlp_ratio: float = 4.0,
-        norm_layer: Callable = nn.LayerNorm,
+        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
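
The tightened annotation documents that `norm_layer` is a factory: it is called with the hidden size and must return an `nn.Module`, a contract that `nn.LayerNorm` itself satisfies. A minimal standalone sketch (not vLLM code; `build_norm` is a hypothetical helper) of what the annotation admits:

```python
from functools import partial
from typing import Callable

import torch.nn as nn


def build_norm(norm_layer: Callable[[int], nn.Module], d_model: int) -> nn.Module:
    # The factory is called with the hidden size and returns a fresh module.
    return norm_layer(d_model)


# nn.LayerNorm(d_model) matches Callable[[int], nn.Module] directly.
ln = build_norm(nn.LayerNorm, d_model=1024)

# A partial that pre-binds extra keyword arguments still satisfies
# the annotation, since it remains a one-argument int -> Module callable.
ln_eps = build_norm(partial(nn.LayerNorm, eps=1e-6), d_model=1024)
```

Compared with the bare `Callable`, this lets a type checker flag callables with the wrong arity or return type at the call site instead of failing at runtime.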
@@ -266,7 +266,7 @@ class TransformerBlock(nn.Module):
         layers: int,
         heads: int,
         mlp_ratio: float = 4.0,
-        norm_layer: Callable = nn.LayerNorm,
+        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
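
For context, a hedged sketch of how a block like this typically consumes the factory; the class and attribute names below are illustrative assumptions, not necessarily what this file uses:

```python
from typing import Callable

import torch.nn as nn


class BlockSketch(nn.Module):
    # Illustrative only; mirrors the annotated signature above.
    def __init__(
        self,
        d_model: int,
        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()
        # Calling the factory once per normalization site gives each
        # sub-layer its own independently parameterized norm module.
        self.ln_1 = norm_layer(d_model)  # hypothetical attribute name
        self.ln_2 = norm_layer(d_model)  # hypothetical attribute name
```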