Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -23,6 +23,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only IBM Granite speech model."""
|
||||
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping
|
||||
from typing import Annotated, Optional, Union
|
||||
@@ -34,25 +35,37 @@ from transformers import BatchFeature, PretrainedConfig
|
||||
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalKwargsItems)
|
||||
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
|
||||
MultiModalDataParser)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptUpdate)
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalDataDict,
|
||||
MultiModalFieldConfig,
|
||||
MultiModalKwargsItems,
|
||||
)
|
||||
from vllm.multimodal.parse import (
|
||||
AudioProcessorItems,
|
||||
MultiModalDataItems,
|
||||
MultiModalDataParser,
|
||||
)
|
||||
from vllm.multimodal.processing import (
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
PromptReplacement,
|
||||
PromptUpdate,
|
||||
)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .blip2 import Blip2QFormerModel
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsLoRA,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
)
|
||||
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
|
||||
|
||||
|
||||
@@ -60,7 +73,7 @@ from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
|
||||
class GraniteSpeechAudioInputs(TensorSchema):
|
||||
"""
|
||||
Audio input features for Granite Speech model.
|
||||
|
||||
|
||||
Dimensions:
|
||||
- b: Batch size
|
||||
- fi: Number of input features from the Mel spectrogram.
|
||||
@@ -79,7 +92,6 @@ class GraniteSpeechAudioInputs(TensorSchema):
|
||||
|
||||
|
||||
class GraniteSpeechMultiModalProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"audio": 1}
|
||||
|
||||
@@ -96,8 +108,8 @@ class GraniteSpeechMultiModalProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
### Input Processing & Multimodal utils
|
||||
class GraniteSpeechMultiModalProcessor(
|
||||
BaseMultiModalProcessor[GraniteSpeechMultiModalProcessingInfo]):
|
||||
|
||||
BaseMultiModalProcessor[GraniteSpeechMultiModalProcessingInfo]
|
||||
):
|
||||
def _get_data_parser(self) -> MultiModalDataParser:
|
||||
feature_extractor = self.info.get_hf_processor().audio_processor
|
||||
sampling_rate = feature_extractor.melspec_kwargs["sample_rate"]
|
||||
@@ -133,7 +145,8 @@ class GraniteSpeechMultiModalProcessor(
|
||||
audio = audios.get(item_idx)
|
||||
audio_length = audio.shape[-1]
|
||||
num_projector_features = feature_extractor._get_num_audio_features(
|
||||
[audio_length])[0]
|
||||
[audio_length]
|
||||
)[0]
|
||||
return [audio_token_id] * num_projector_features
|
||||
|
||||
return [
|
||||
@@ -170,14 +183,15 @@ class GraniteSpeechMultiModalProcessor(
|
||||
# This is used to split the batch back out after padding.
|
||||
audio_token_index = self.info.get_hf_config().audio_token_index
|
||||
processed_outputs["audio_embed_sizes"] = (
|
||||
processed_outputs["input_ids"] == audio_token_index).sum(-1)
|
||||
processed_outputs["input_ids"] == audio_token_index
|
||||
).sum(-1)
|
||||
|
||||
return processed_outputs
|
||||
|
||||
|
||||
class GraniteSpeechDummyInputsBuilder(
|
||||
BaseDummyInputsBuilder[GraniteSpeechMultiModalProcessingInfo]):
|
||||
|
||||
BaseDummyInputsBuilder[GraniteSpeechMultiModalProcessingInfo]
|
||||
):
|
||||
def get_dummy_mm_data(
|
||||
self,
|
||||
seq_len: int,
|
||||
@@ -188,8 +202,7 @@ class GraniteSpeechDummyInputsBuilder(
|
||||
audio_overrides = mm_options.get("audio") if mm_options else None
|
||||
|
||||
return {
|
||||
"audio":
|
||||
self._get_dummy_audios(
|
||||
"audio": self._get_dummy_audios(
|
||||
length=self.info.get_max_audio_len(),
|
||||
num_audios=num_audios,
|
||||
overrides=audio_overrides,
|
||||
@@ -205,7 +218,6 @@ class GraniteSpeechDummyInputsBuilder(
|
||||
|
||||
### QFormer Projector
|
||||
class GraniteSpeechEncoderProjector(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
@@ -220,8 +232,8 @@ class GraniteSpeechEncoderProjector(nn.Module):
|
||||
self.num_queries = config.window_size // config.downsample_rate
|
||||
|
||||
self.query = nn.Parameter(
|
||||
torch.zeros(1, self.num_queries,
|
||||
config.projector_config.hidden_size))
|
||||
torch.zeros(1, self.num_queries, config.projector_config.hidden_size)
|
||||
)
|
||||
|
||||
# NOTE - this is implemented generically in transformers,
|
||||
# but for now we create the QFormer model directly since
|
||||
@@ -232,17 +244,16 @@ class GraniteSpeechEncoderProjector(nn.Module):
|
||||
cache_config=cache_config,
|
||||
prefix=f"{prefix}.qformer",
|
||||
)
|
||||
self.linear = nn.Linear(config.projector_config.hidden_size,
|
||||
config.text_config.hidden_size)
|
||||
self.linear = nn.Linear(
|
||||
config.projector_config.hidden_size, config.text_config.hidden_size
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, seq_len, dim = hidden_states.size()
|
||||
nblocks = math.ceil(seq_len / self.window_size)
|
||||
pad = nblocks * self.window_size - seq_len
|
||||
hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad),
|
||||
"constant", 0)
|
||||
hidden_states = hidden_states.view(batch_size * nblocks,
|
||||
self.window_size, dim)
|
||||
hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad), "constant", 0)
|
||||
hidden_states = hidden_states.view(batch_size * nblocks, self.window_size, dim)
|
||||
|
||||
last_hidden_state = self.qformer(
|
||||
query_embeds=self.query.data,
|
||||
@@ -254,7 +265,8 @@ class GraniteSpeechEncoderProjector(nn.Module):
|
||||
batch_size,
|
||||
nblocks * self.window_size // self.downsample_rate,
|
||||
-1,
|
||||
))
|
||||
)
|
||||
)
|
||||
return query_proj
|
||||
|
||||
|
||||
@@ -264,10 +276,12 @@ class GraniteSpeechEncoderProjector(nn.Module):
|
||||
class GraniteSpeechConformerFeedForward(nn.Module):
|
||||
"""Feedforward module for conformer encoder blocks."""
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.pre_norm = nn.LayerNorm(config.hidden_dim)
|
||||
|
||||
@@ -313,16 +327,16 @@ class GraniteSpeechConformerAttention(nn.Module):
|
||||
self.to_q = nn.Linear(config.hidden_dim, inner_dim, bias=False)
|
||||
self.to_kv = nn.Linear(config.hidden_dim, inner_dim * 2, bias=False)
|
||||
self.to_out = nn.Linear(inner_dim, config.hidden_dim)
|
||||
self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1,
|
||||
self.dim_head)
|
||||
self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1, self.dim_head)
|
||||
|
||||
if self.context_size <= 0 or self.context_size > self.max_pos_emb:
|
||||
raise ValueError(
|
||||
"Context size is either less than 0 or exceeds the max_pos_emb"
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor,
|
||||
attention_dists: torch.Tensor) -> torch.Tensor:
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, attention_dists: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.pre_norm(hidden_states)
|
||||
bsz, num_features, _ = hidden_states.shape
|
||||
|
||||
@@ -331,47 +345,53 @@ class GraniteSpeechConformerAttention(nn.Module):
|
||||
if remainder > 0:
|
||||
# right padding to reach block size
|
||||
hidden_states = torch.nn.functional.pad(
|
||||
hidden_states, (0, 0, 0, self.context_size - remainder))
|
||||
hidden_states, (0, 0, 0, self.context_size - remainder)
|
||||
)
|
||||
|
||||
# NOTE: would be nice to try to use qkvparallellinear
|
||||
# here for this block attention implementation if possible
|
||||
query_states = self.to_q(hidden_states)
|
||||
key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1)
|
||||
|
||||
query_states = query_states.reshape(bsz, num_blocks, self.context_size,
|
||||
self.num_heads,
|
||||
-1).transpose(2, 3)
|
||||
key_states = key_states.reshape(bsz, num_blocks, self.context_size,
|
||||
self.num_heads, -1).transpose(2, 3)
|
||||
value_states = value_states.reshape(bsz, num_blocks, self.context_size,
|
||||
self.num_heads,
|
||||
-1).transpose(2, 3)
|
||||
query_states = query_states.reshape(
|
||||
bsz, num_blocks, self.context_size, self.num_heads, -1
|
||||
).transpose(2, 3)
|
||||
key_states = key_states.reshape(
|
||||
bsz, num_blocks, self.context_size, self.num_heads, -1
|
||||
).transpose(2, 3)
|
||||
value_states = value_states.reshape(
|
||||
bsz, num_blocks, self.context_size, self.num_heads, -1
|
||||
).transpose(2, 3)
|
||||
|
||||
# shaw's relative positional embedding
|
||||
dist = attention_dists.to(hidden_states.device)
|
||||
rel_pos_emb = self.rel_pos_emb(dist)
|
||||
rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] +
|
||||
list(rel_pos_emb.shape))
|
||||
pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded,
|
||||
dim=-1) * self.scale
|
||||
rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
|
||||
pos_attn = (
|
||||
torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1)
|
||||
* self.scale
|
||||
)
|
||||
|
||||
if remainder > 0:
|
||||
# masked attention in the extended block
|
||||
mask = torch.ones(self.context_size,
|
||||
self.context_size,
|
||||
dtype=bool,
|
||||
device=hidden_states.device)
|
||||
mask = torch.ones(
|
||||
self.context_size,
|
||||
self.context_size,
|
||||
dtype=bool,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
mask[:remainder, :remainder] = 0
|
||||
mask_value = -torch.finfo(pos_attn.dtype).max
|
||||
pos_attn[:, -1, :].masked_fill_(mask, mask_value)
|
||||
|
||||
with torch.nn.attention.sdpa_kernel(
|
||||
torch.nn.attention.SDPBackend.MATH):
|
||||
out = F.scaled_dot_product_attention(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=pos_attn,
|
||||
scale=self.scale)
|
||||
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
|
||||
out = F.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=pos_attn,
|
||||
scale=self.scale,
|
||||
)
|
||||
out = out.transpose(2, 3).reshape(bsz, hidden_states.shape[1], -1)
|
||||
return self.to_out(out[:, :num_features, :])
|
||||
|
||||
@@ -379,22 +399,16 @@ class GraniteSpeechConformerAttention(nn.Module):
|
||||
class GraniteSpeechConformerDepthWiseConv1d(nn.Module):
|
||||
"""Wrapper for padded 1D pointwise convolution."""
|
||||
|
||||
def __init__(self,
|
||||
chan_in: int,
|
||||
chan_out: int,
|
||||
kernel_size: int,
|
||||
prefix: str = ""):
|
||||
def __init__(self, chan_in: int, chan_out: int, kernel_size: int, prefix: str = ""):
|
||||
super().__init__()
|
||||
# Padding for the 1D conv is symmetric or close (i.e., offset by one).
|
||||
pad = kernel_size // 2
|
||||
pad_offset = (kernel_size + 1) % 2
|
||||
self.padding = (pad, pad - pad_offset)
|
||||
|
||||
self.conv = nn.Conv1d(chan_in,
|
||||
chan_out,
|
||||
kernel_size,
|
||||
groups=chan_in,
|
||||
bias=False)
|
||||
self.conv = nn.Conv1d(
|
||||
chan_in, chan_out, kernel_size, groups=chan_in, bias=False
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = F.pad(hidden_states, self.padding)
|
||||
@@ -439,21 +453,19 @@ class GraniteSpeechConformerBlock(nn.Module):
|
||||
|
||||
def __init__(self, config: PretrainedConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.ff1 = GraniteSpeechConformerFeedForward(config,
|
||||
prefix=f"{prefix}.ff1")
|
||||
self.attn = GraniteSpeechConformerAttention(config,
|
||||
prefix=f"{prefix}.attn")
|
||||
self.conv = GraniteSpeechConformerConvModule(config,
|
||||
prefix=f"{prefix}.conv")
|
||||
self.ff2 = GraniteSpeechConformerFeedForward(config,
|
||||
prefix=f"{prefix}.ff2")
|
||||
self.ff1 = GraniteSpeechConformerFeedForward(config, prefix=f"{prefix}.ff1")
|
||||
self.attn = GraniteSpeechConformerAttention(config, prefix=f"{prefix}.attn")
|
||||
self.conv = GraniteSpeechConformerConvModule(config, prefix=f"{prefix}.conv")
|
||||
self.ff2 = GraniteSpeechConformerFeedForward(config, prefix=f"{prefix}.ff2")
|
||||
self.post_norm = nn.LayerNorm(config.hidden_dim)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor,
|
||||
attention_dists: torch.Tensor) -> torch.Tensor:
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, attention_dists: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
hidden_states = 0.5 * self.ff1(hidden_states) + hidden_states
|
||||
hidden_states = self.attn(
|
||||
hidden_states, attention_dists=attention_dists) + hidden_states
|
||||
hidden_states = (
|
||||
self.attn(hidden_states, attention_dists=attention_dists) + hidden_states
|
||||
)
|
||||
hidden_states = self.conv(hidden_states) + hidden_states
|
||||
hidden_states = 0.5 * self.ff2(hidden_states) + hidden_states
|
||||
hidden_states = self.post_norm(hidden_states)
|
||||
@@ -463,29 +475,33 @@ class GraniteSpeechConformerBlock(nn.Module):
|
||||
class GraniteSpeechCTCEncoder(nn.Module):
|
||||
"""CTC Encoder comprising conformer blocks and additional linear layers."""
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
prefix: str,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
prefix: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
# Precompute clamped relative positional encoding distances
|
||||
seq = torch.arange(config.context_size)
|
||||
relpos_dist = seq.view(-1, 1) - seq.view(1, -1)
|
||||
self.attention_dists = torch.clamp(
|
||||
relpos_dist, -config.context_size,
|
||||
config.context_size) + config.max_pos_emb
|
||||
self.attention_dists = (
|
||||
torch.clamp(relpos_dist, -config.context_size, config.context_size)
|
||||
+ config.max_pos_emb
|
||||
)
|
||||
|
||||
self.input_linear = nn.Linear(config.input_dim,
|
||||
config.hidden_dim,
|
||||
bias=True)
|
||||
self.layers = nn.ModuleList([
|
||||
GraniteSpeechConformerBlock(
|
||||
config,
|
||||
prefix=f"{prefix}.layers.{idx}",
|
||||
) for idx in range(config.num_layers)
|
||||
])
|
||||
self.input_linear = nn.Linear(config.input_dim, config.hidden_dim, bias=True)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
GraniteSpeechConformerBlock(
|
||||
config,
|
||||
prefix=f"{prefix}.layers.{idx}",
|
||||
)
|
||||
for idx in range(config.num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.out = ColumnParallelLinear(
|
||||
input_size=config.hidden_dim,
|
||||
@@ -508,8 +524,7 @@ class GraniteSpeechCTCEncoder(nn.Module):
|
||||
def forward(self, hidden_states: torch.Tensor):
|
||||
hidden_states = self.input_linear(hidden_states)
|
||||
for idx, layer in enumerate(self.layers, start=1):
|
||||
hidden_states = layer(hidden_states,
|
||||
attention_dists=self.attention_dists)
|
||||
hidden_states = layer(hidden_states, attention_dists=self.attention_dists)
|
||||
|
||||
if idx == self.num_layers // 2:
|
||||
hidden_states_mid = hidden_states.clone()
|
||||
@@ -523,12 +538,13 @@ class GraniteSpeechCTCEncoder(nn.Module):
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
GraniteSpeechMultiModalProcessor,
|
||||
info=GraniteSpeechMultiModalProcessingInfo,
|
||||
dummy_inputs=GraniteSpeechDummyInputsBuilder)
|
||||
dummy_inputs=GraniteSpeechDummyInputsBuilder,
|
||||
)
|
||||
class GraniteSpeechForConditionalGeneration(
|
||||
nn.Module,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
SupportsLoRA,
|
||||
nn.Module,
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
SupportsLoRA,
|
||||
):
|
||||
merge_by_field_config = True
|
||||
|
||||
@@ -584,7 +600,8 @@ class GraniteSpeechForConditionalGeneration(
|
||||
)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors)
|
||||
self.language_model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def _parse_and_validate_audio_input(
|
||||
self,
|
||||
@@ -602,17 +619,21 @@ class GraniteSpeechForConditionalGeneration(
|
||||
# from the processor, but we handle rebuilding it here since
|
||||
# vLLM generally processes everything independently + batches.
|
||||
if input_features_mask is None:
|
||||
input_features_mask = self._build_input_features_mask(
|
||||
audio_embed_sizes)
|
||||
input_features_mask = self._build_input_features_mask(audio_embed_sizes)
|
||||
|
||||
if not isinstance(input_features, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of audio input features. "
|
||||
f"Got type: {type(input_features)}")
|
||||
raise ValueError(
|
||||
"Incorrect type of audio input features. "
|
||||
f"Got type: {type(input_features)}"
|
||||
)
|
||||
|
||||
if input_features_mask is not None and not isinstance(
|
||||
input_features_mask, torch.Tensor):
|
||||
raise ValueError("Incorrect type of audio input features mask. "
|
||||
f"Got type: {type(input_features_mask)}")
|
||||
input_features_mask, torch.Tensor
|
||||
):
|
||||
raise ValueError(
|
||||
"Incorrect type of audio input features mask. "
|
||||
f"Got type: {type(input_features_mask)}"
|
||||
)
|
||||
|
||||
if isinstance(input_features, torch.Tensor):
|
||||
# Granite speech currently only allows one audio token per instance
|
||||
@@ -625,16 +646,17 @@ class GraniteSpeechForConditionalGeneration(
|
||||
if len(input_features.shape) != 3:
|
||||
raise ValueError(
|
||||
"Squeezed input features should be 3D but are of shape "
|
||||
f"{input_features.shape}")
|
||||
input_features = input_features.to(
|
||||
self.encoder.input_linear.weight.dtype)
|
||||
f"{input_features.shape}"
|
||||
)
|
||||
input_features = input_features.to(self.encoder.input_linear.weight.dtype)
|
||||
|
||||
else:
|
||||
# Otherwise we have a list of tensors, which are almost certainly
|
||||
# differing in their respective numbers of audio features;
|
||||
# stack them into a 3D tensor of size [bsz, most_num_features, 160].
|
||||
input_features = self._pad_and_stack_input_features(
|
||||
input_features, ).to(self.encoder.input_linear.weight.dtype)
|
||||
input_features,
|
||||
).to(self.encoder.input_linear.weight.dtype)
|
||||
|
||||
return GraniteSpeechAudioInputs(
|
||||
input_features=input_features,
|
||||
@@ -706,7 +728,7 @@ class GraniteSpeechForConditionalGeneration(
|
||||
audio_input: GraniteSpeechAudioInputs,
|
||||
) -> tuple[torch.Tensor]:
|
||||
"""Compute the audio features to be merged into the LLM embeddings.
|
||||
|
||||
|
||||
Args:
|
||||
audio_input: GraniteSpeechAudioInputs
|
||||
Audio inputs object containing Mel features, an input features
|
||||
@@ -769,8 +791,9 @@ class GraniteSpeechForConditionalGeneration(
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
|
||||
model_output = self.language_model(input_ids, positions,
|
||||
intermediate_tensors, inputs_embeds)
|
||||
model_output = self.language_model(
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds
|
||||
)
|
||||
return model_output
|
||||
|
||||
def compute_logits(
|
||||
|
||||
Reference in New Issue
Block a user