[Model] Add Granite Speech Support (#16246)
Signed-off-by: Alex-Brooks <Alex.brooks@ibm.com>
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
This commit is contained in:
@@ -60,6 +60,7 @@ class Blip2QFormerMultiHeadAttention(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
cache_config: Optional[CacheConfig],
|
||||
is_cross_attention: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -139,7 +140,7 @@ class Blip2QFormerMultiHeadAttention(nn.Module):
|
||||
|
||||
class Blip2QFormerSelfOutput(nn.Module):
|
||||
|
||||
def __init__(self, config: Blip2QFormerConfig) -> None:
|
||||
def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
@@ -167,6 +168,7 @@ class Blip2QFormerAttention(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
cache_config: Optional[CacheConfig],
|
||||
is_cross_attention: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -175,9 +177,10 @@ class Blip2QFormerAttention(nn.Module):
|
||||
quant_config=quant_config,
|
||||
cache_config=cache_config,
|
||||
is_cross_attention=is_cross_attention,
|
||||
prefix=f"{prefix}.attention",
|
||||
)
|
||||
|
||||
self.output = Blip2QFormerSelfOutput(config)
|
||||
self.output = Blip2QFormerSelfOutput(config, prefix=f"{prefix}.output")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -195,7 +198,7 @@ class Blip2QFormerAttention(nn.Module):
|
||||
|
||||
class Blip2QFormerIntermediate(nn.Module):
|
||||
|
||||
def __init__(self, config: Blip2QFormerConfig) -> None:
|
||||
def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
@@ -209,7 +212,7 @@ class Blip2QFormerIntermediate(nn.Module):
|
||||
|
||||
class Blip2QFormerOutput(nn.Module):
|
||||
|
||||
def __init__(self, config: Blip2QFormerConfig) -> None:
|
||||
def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
@@ -237,6 +240,7 @@ class Blip2QFormerLayer(nn.Module):
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
cache_config: Optional[CacheConfig],
|
||||
layer_idx: int,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -244,7 +248,8 @@ class Blip2QFormerLayer(nn.Module):
|
||||
self.seq_len_dim = 1
|
||||
self.attention = Blip2QFormerAttention(config,
|
||||
quant_config=quant_config,
|
||||
cache_config=cache_config)
|
||||
cache_config=cache_config,
|
||||
prefix=f"{prefix}.attention")
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
@@ -253,13 +258,16 @@ class Blip2QFormerLayer(nn.Module):
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
cache_config=cache_config,
|
||||
is_cross_attention=True)
|
||||
is_cross_attention=True,
|
||||
prefix=f"{prefix}.crossattention")
|
||||
self.has_cross_attention = True
|
||||
else:
|
||||
self.has_cross_attention = False
|
||||
|
||||
self.intermediate_query = Blip2QFormerIntermediate(config)
|
||||
self.output_query = Blip2QFormerOutput(config)
|
||||
self.intermediate_query = Blip2QFormerIntermediate(
|
||||
config, prefix=f"{prefix}.intermediate_query")
|
||||
self.output_query = Blip2QFormerOutput(config,
|
||||
prefix=f"{prefix}.output_query")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -325,6 +333,7 @@ class Blip2QFormerEncoder(nn.Module):
|
||||
*,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
cache_config: Optional[CacheConfig],
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -334,7 +343,8 @@ class Blip2QFormerEncoder(nn.Module):
|
||||
Blip2QFormerLayer(config,
|
||||
quant_config=quant_config,
|
||||
cache_config=cache_config,
|
||||
layer_idx=layer_idx)
|
||||
layer_idx=layer_idx,
|
||||
prefix=f"{prefix}.layer.{layer_idx}")
|
||||
for layer_idx in range(config.num_hidden_layers)
|
||||
])
|
||||
|
||||
@@ -365,6 +375,7 @@ class Blip2QFormerModel(nn.Module):
|
||||
*,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
cache_config: Optional[CacheConfig],
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -376,7 +387,8 @@ class Blip2QFormerModel(nn.Module):
|
||||
|
||||
self.encoder = Blip2QFormerEncoder(config,
|
||||
quant_config=quant_config,
|
||||
cache_config=cache_config)
|
||||
cache_config=cache_config,
|
||||
prefix=f"{prefix}.encoder")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -511,7 +523,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
|
||||
self.qformer = Blip2QFormerModel(config.qformer_config,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qformer")
|
||||
|
||||
self.language_projection = nn.Linear(
|
||||
config.qformer_config.hidden_size,
|
||||
|
||||
Reference in New Issue
Block a user