[Model] Add Granite Speech Support (#16246)

Signed-off-by: Alex-Brooks <Alex.brooks@ibm.com>
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
Author: Alex Brooks
Committed by: GitHub
Date: 2025-04-28 04:05:00 -06:00
Commit: fa93cd9f60 (parent: aec9674dbe)
11 changed files with 1025 additions and 28 deletions


@@ -60,6 +60,7 @@ class Blip2QFormerMultiHeadAttention(nn.Module):
         quant_config: Optional[QuantizationConfig],
         cache_config: Optional[CacheConfig],
         is_cross_attention: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
@@ -139,7 +140,7 @@ class Blip2QFormerMultiHeadAttention(nn.Module):
 class Blip2QFormerSelfOutput(nn.Module):
-    def __init__(self, config: Blip2QFormerConfig) -> None:
+    def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None:
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
@@ -167,6 +168,7 @@ class Blip2QFormerAttention(nn.Module):
         quant_config: Optional[QuantizationConfig],
         cache_config: Optional[CacheConfig],
         is_cross_attention: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
@@ -175,9 +177,10 @@ class Blip2QFormerAttention(nn.Module):
             quant_config=quant_config,
             cache_config=cache_config,
             is_cross_attention=is_cross_attention,
+            prefix=f"{prefix}.attention",
         )
-        self.output = Blip2QFormerSelfOutput(config)
+        self.output = Blip2QFormerSelfOutput(config, prefix=f"{prefix}.output")
     def forward(
         self,
@@ -195,7 +198,7 @@ class Blip2QFormerAttention(nn.Module):
 class Blip2QFormerIntermediate(nn.Module):
-    def __init__(self, config: Blip2QFormerConfig) -> None:
+    def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None:
         super().__init__()
         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
@@ -209,7 +212,7 @@ class Blip2QFormerIntermediate(nn.Module):
 class Blip2QFormerOutput(nn.Module):
-    def __init__(self, config: Blip2QFormerConfig) -> None:
+    def __init__(self, config: Blip2QFormerConfig, prefix: str = "") -> None:
         super().__init__()
         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
@@ -237,6 +240,7 @@ class Blip2QFormerLayer(nn.Module):
         quant_config: Optional[QuantizationConfig],
         cache_config: Optional[CacheConfig],
         layer_idx: int,
+        prefix: str = "",
     ) -> None:
         super().__init__()
@@ -244,7 +248,8 @@ class Blip2QFormerLayer(nn.Module):
         self.seq_len_dim = 1
         self.attention = Blip2QFormerAttention(config,
                                                quant_config=quant_config,
-                                               cache_config=cache_config)
+                                               cache_config=cache_config,
+                                               prefix=f"{prefix}.attention")
         self.layer_idx = layer_idx
@@ -253,13 +258,16 @@ class Blip2QFormerLayer(nn.Module):
                 config,
                 quant_config=quant_config,
                 cache_config=cache_config,
-                is_cross_attention=True)
+                is_cross_attention=True,
+                prefix=f"{prefix}.crossattention")
             self.has_cross_attention = True
         else:
             self.has_cross_attention = False
-        self.intermediate_query = Blip2QFormerIntermediate(config)
-        self.output_query = Blip2QFormerOutput(config)
+        self.intermediate_query = Blip2QFormerIntermediate(
+            config, prefix=f"{prefix}.intermediate_query")
+        self.output_query = Blip2QFormerOutput(config,
+                                               prefix=f"{prefix}.output_query")
     def forward(
         self,
@@ -325,6 +333,7 @@ class Blip2QFormerEncoder(nn.Module):
         *,
         quant_config: Optional[QuantizationConfig],
         cache_config: Optional[CacheConfig],
+        prefix: str = "",
     ) -> None:
         super().__init__()
@@ -334,7 +343,8 @@ class Blip2QFormerEncoder(nn.Module):
             Blip2QFormerLayer(config,
                               quant_config=quant_config,
                               cache_config=cache_config,
-                              layer_idx=layer_idx)
+                              layer_idx=layer_idx,
+                              prefix=f"{prefix}.layer.{layer_idx}")
             for layer_idx in range(config.num_hidden_layers)
         ])
@@ -365,6 +375,7 @@ class Blip2QFormerModel(nn.Module):
         *,
         quant_config: Optional[QuantizationConfig],
         cache_config: Optional[CacheConfig],
+        prefix: str = "",
     ) -> None:
         super().__init__()
@@ -376,7 +387,8 @@ class Blip2QFormerModel(nn.Module):
         self.encoder = Blip2QFormerEncoder(config,
                                            quant_config=quant_config,
-                                           cache_config=cache_config)
+                                           cache_config=cache_config,
+                                           prefix=f"{prefix}.encoder")
     def forward(
         self,
@@ -511,7 +523,8 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         self.qformer = Blip2QFormerModel(config.qformer_config,
                                          cache_config=cache_config,
-                                         quant_config=quant_config)
+                                         quant_config=quant_config,
+                                         prefix=f"{prefix}.qformer")
         self.language_projection = nn.Linear(
             config.qformer_config.hidden_size,
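
Taken together, these hunks thread a `prefix` string through every Q-Former submodule: each constructor extends the prefix it receives with its own attribute name (`.attention`, `.output`, `.layer.{layer_idx}`, `.encoder`, `.qformer`, and so on) before handing it to its children, so every module ends up knowing its fully qualified dotted name. Below is a minimal sketch of how such prefixes compose; the `Toy*` classes are purely illustrative stand-ins, not the real vLLM modules.

# Minimal sketch (illustrative only, not vLLM code): how a threaded `prefix`
# composes into fully qualified, checkpoint-style module names.

class ToyOutput:
    def __init__(self, prefix: str = "") -> None:
        # Leaf module: simply remember the fully qualified name it was given.
        self.name = prefix


class ToyLayer:
    def __init__(self, prefix: str = "") -> None:
        # Each level appends its own attribute name before passing it down.
        self.output = ToyOutput(prefix=f"{prefix}.output")


class ToyEncoder:
    def __init__(self, num_layers: int, prefix: str = "") -> None:
        self.layer = [
            ToyLayer(prefix=f"{prefix}.layer.{i}") for i in range(num_layers)
        ]


encoder = ToyEncoder(num_layers=2, prefix="qformer.encoder")
print(encoder.layer[0].output.name)  # qformer.encoder.layer.0.output

Qualified names of this form are what name-based machinery (for example per-module quantization overrides or weight-name matching) can key off, which is presumably why the prefix is now threaded through the whole Q-Former stack.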