Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00
Committed by: GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions
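For context, ruff format is a black-compatible formatter and ruff's import-sorting lint rules (the "I" category) cover what isort previously handled, so the diff below is almost entirely mechanical re-wrapping of the same code. A minimal sketch (not taken from this commit) of the wrapping style that changes:

import torch
import torch.nn.functional as F

hidden_states = torch.zeros(2, 5, 8)
pad = 3

# yapf-era style: continuation lines aligned with the opening parenthesis.
padded_old_style = F.pad(hidden_states, (0, 0, 0, pad),
                         "constant", 0)

# ruff format (black-compatible) style: keep the call on one line when it
# fits within the line length, otherwise break with one argument per line
# and a trailing comma.
padded_new_style = F.pad(hidden_states, (0, 0, 0, pad), "constant", 0)

# The two calls are equivalent; only the source layout differs.
assert torch.equal(padded_old_style, padded_new_style)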


@@ -23,6 +23,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only IBM Granite speech model."""
import math
from collections.abc import Iterable, Mapping
from typing import Annotated, Optional, Union
@@ -34,25 +35,37 @@ from transformers import BatchFeature, PretrainedConfig
from vllm.config import CacheConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems)
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
AudioProcessorItems,
MultiModalDataItems,
MultiModalDataParser,
)
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptReplacement,
PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .blip2 import Blip2QFormerModel
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .interfaces import (
MultiModalEmbeddings,
SupportsLoRA,
SupportsMultiModal,
SupportsPP,
)
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
@@ -60,7 +73,7 @@ from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
class GraniteSpeechAudioInputs(TensorSchema):
"""
Audio input features for Granite Speech model.
Dimensions:
- b: Batch size
- fi: Number of input features from the Mel spectrogram.
@@ -79,7 +92,6 @@ class GraniteSpeechAudioInputs(TensorSchema):
class GraniteSpeechMultiModalProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": 1}
@@ -96,8 +108,8 @@ class GraniteSpeechMultiModalProcessingInfo(BaseProcessingInfo):
### Input Processing & Multimodal utils
class GraniteSpeechMultiModalProcessor(
BaseMultiModalProcessor[GraniteSpeechMultiModalProcessingInfo]):
BaseMultiModalProcessor[GraniteSpeechMultiModalProcessingInfo]
):
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.info.get_hf_processor().audio_processor
sampling_rate = feature_extractor.melspec_kwargs["sample_rate"]
@@ -133,7 +145,8 @@ class GraniteSpeechMultiModalProcessor(
audio = audios.get(item_idx)
audio_length = audio.shape[-1]
num_projector_features = feature_extractor._get_num_audio_features(
[audio_length])[0]
[audio_length]
)[0]
return [audio_token_id] * num_projector_features
return [
@@ -170,14 +183,15 @@ class GraniteSpeechMultiModalProcessor(
# This is used to split the batch back out after padding.
audio_token_index = self.info.get_hf_config().audio_token_index
processed_outputs["audio_embed_sizes"] = (
processed_outputs["input_ids"] == audio_token_index).sum(-1)
processed_outputs["input_ids"] == audio_token_index
).sum(-1)
return processed_outputs
class GraniteSpeechDummyInputsBuilder(
BaseDummyInputsBuilder[GraniteSpeechMultiModalProcessingInfo]):
BaseDummyInputsBuilder[GraniteSpeechMultiModalProcessingInfo]
):
def get_dummy_mm_data(
self,
seq_len: int,
@@ -188,8 +202,7 @@ class GraniteSpeechDummyInputsBuilder(
audio_overrides = mm_options.get("audio") if mm_options else None
return {
"audio":
self._get_dummy_audios(
"audio": self._get_dummy_audios(
length=self.info.get_max_audio_len(),
num_audios=num_audios,
overrides=audio_overrides,
@@ -205,7 +218,6 @@ class GraniteSpeechDummyInputsBuilder(
### QFormer Projector
class GraniteSpeechEncoderProjector(nn.Module):
def __init__(
self,
config: PretrainedConfig,
@@ -220,8 +232,8 @@ class GraniteSpeechEncoderProjector(nn.Module):
self.num_queries = config.window_size // config.downsample_rate
self.query = nn.Parameter(
torch.zeros(1, self.num_queries,
config.projector_config.hidden_size))
torch.zeros(1, self.num_queries, config.projector_config.hidden_size)
)
# NOTE - this is implemented generically in transformers,
# but for now we create the QFormer model directly since
@@ -232,17 +244,16 @@ class GraniteSpeechEncoderProjector(nn.Module):
cache_config=cache_config,
prefix=f"{prefix}.qformer",
)
self.linear = nn.Linear(config.projector_config.hidden_size,
config.text_config.hidden_size)
self.linear = nn.Linear(
config.projector_config.hidden_size, config.text_config.hidden_size
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, seq_len, dim = hidden_states.size()
nblocks = math.ceil(seq_len / self.window_size)
pad = nblocks * self.window_size - seq_len
hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad),
"constant", 0)
hidden_states = hidden_states.view(batch_size * nblocks,
self.window_size, dim)
hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad), "constant", 0)
hidden_states = hidden_states.view(batch_size * nblocks, self.window_size, dim)
last_hidden_state = self.qformer(
query_embeds=self.query.data,
@@ -254,7 +265,8 @@ class GraniteSpeechEncoderProjector(nn.Module):
batch_size,
nblocks * self.window_size // self.downsample_rate,
-1,
))
)
)
return query_proj
@@ -264,10 +276,12 @@ class GraniteSpeechEncoderProjector(nn.Module):
class GraniteSpeechConformerFeedForward(nn.Module):
"""Feedforward module for conformer encoder blocks."""
def __init__(self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.pre_norm = nn.LayerNorm(config.hidden_dim)
@@ -313,16 +327,16 @@ class GraniteSpeechConformerAttention(nn.Module):
self.to_q = nn.Linear(config.hidden_dim, inner_dim, bias=False)
self.to_kv = nn.Linear(config.hidden_dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, config.hidden_dim)
self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1,
self.dim_head)
self.rel_pos_emb = nn.Embedding(2 * self.max_pos_emb + 1, self.dim_head)
if self.context_size <= 0 or self.context_size > self.max_pos_emb:
raise ValueError(
"Context size is either less than 0 or exceeds the max_pos_emb"
)
def forward(self, hidden_states: torch.Tensor,
attention_dists: torch.Tensor) -> torch.Tensor:
def forward(
self, hidden_states: torch.Tensor, attention_dists: torch.Tensor
) -> torch.Tensor:
hidden_states = self.pre_norm(hidden_states)
bsz, num_features, _ = hidden_states.shape
@@ -331,47 +345,53 @@ class GraniteSpeechConformerAttention(nn.Module):
if remainder > 0:
# right padding to reach block size
hidden_states = torch.nn.functional.pad(
hidden_states, (0, 0, 0, self.context_size - remainder))
hidden_states, (0, 0, 0, self.context_size - remainder)
)
# NOTE: would be nice to try to use qkvparallellinear
# here for this block attention implementation if possible
query_states = self.to_q(hidden_states)
key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1)
query_states = query_states.reshape(bsz, num_blocks, self.context_size,
self.num_heads,
-1).transpose(2, 3)
key_states = key_states.reshape(bsz, num_blocks, self.context_size,
self.num_heads, -1).transpose(2, 3)
value_states = value_states.reshape(bsz, num_blocks, self.context_size,
self.num_heads,
-1).transpose(2, 3)
query_states = query_states.reshape(
bsz, num_blocks, self.context_size, self.num_heads, -1
).transpose(2, 3)
key_states = key_states.reshape(
bsz, num_blocks, self.context_size, self.num_heads, -1
).transpose(2, 3)
value_states = value_states.reshape(
bsz, num_blocks, self.context_size, self.num_heads, -1
).transpose(2, 3)
# shaw's relative positional embedding
dist = attention_dists.to(hidden_states.device)
rel_pos_emb = self.rel_pos_emb(dist)
rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] +
list(rel_pos_emb.shape))
pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded,
dim=-1) * self.scale
rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
pos_attn = (
torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1)
* self.scale
)
if remainder > 0:
# masked attention in the extended block
mask = torch.ones(self.context_size,
self.context_size,
dtype=bool,
device=hidden_states.device)
mask = torch.ones(
self.context_size,
self.context_size,
dtype=bool,
device=hidden_states.device,
)
mask[:remainder, :remainder] = 0
mask_value = -torch.finfo(pos_attn.dtype).max
pos_attn[:, -1, :].masked_fill_(mask, mask_value)
with torch.nn.attention.sdpa_kernel(
torch.nn.attention.SDPBackend.MATH):
out = F.scaled_dot_product_attention(query_states,
key_states,
value_states,
attn_mask=pos_attn,
scale=self.scale)
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
out = F.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=pos_attn,
scale=self.scale,
)
out = out.transpose(2, 3).reshape(bsz, hidden_states.shape[1], -1)
return self.to_out(out[:, :num_features, :])
@@ -379,22 +399,16 @@ class GraniteSpeechConformerAttention(nn.Module):
class GraniteSpeechConformerDepthWiseConv1d(nn.Module):
"""Wrapper for padded 1D pointwise convolution."""
def __init__(self,
chan_in: int,
chan_out: int,
kernel_size: int,
prefix: str = ""):
def __init__(self, chan_in: int, chan_out: int, kernel_size: int, prefix: str = ""):
super().__init__()
# Padding for the 1D conv is symmetric or close (i.e., offset by one).
pad = kernel_size // 2
pad_offset = (kernel_size + 1) % 2
self.padding = (pad, pad - pad_offset)
self.conv = nn.Conv1d(chan_in,
chan_out,
kernel_size,
groups=chan_in,
bias=False)
self.conv = nn.Conv1d(
chan_in, chan_out, kernel_size, groups=chan_in, bias=False
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = F.pad(hidden_states, self.padding)
@@ -439,21 +453,19 @@ class GraniteSpeechConformerBlock(nn.Module):
def __init__(self, config: PretrainedConfig, prefix: str = ""):
super().__init__()
self.ff1 = GraniteSpeechConformerFeedForward(config,
prefix=f"{prefix}.ff1")
self.attn = GraniteSpeechConformerAttention(config,
prefix=f"{prefix}.attn")
self.conv = GraniteSpeechConformerConvModule(config,
prefix=f"{prefix}.conv")
self.ff2 = GraniteSpeechConformerFeedForward(config,
prefix=f"{prefix}.ff2")
self.ff1 = GraniteSpeechConformerFeedForward(config, prefix=f"{prefix}.ff1")
self.attn = GraniteSpeechConformerAttention(config, prefix=f"{prefix}.attn")
self.conv = GraniteSpeechConformerConvModule(config, prefix=f"{prefix}.conv")
self.ff2 = GraniteSpeechConformerFeedForward(config, prefix=f"{prefix}.ff2")
self.post_norm = nn.LayerNorm(config.hidden_dim)
def forward(self, hidden_states: torch.Tensor,
attention_dists: torch.Tensor) -> torch.Tensor:
def forward(
self, hidden_states: torch.Tensor, attention_dists: torch.Tensor
) -> torch.Tensor:
hidden_states = 0.5 * self.ff1(hidden_states) + hidden_states
hidden_states = self.attn(
hidden_states, attention_dists=attention_dists) + hidden_states
hidden_states = (
self.attn(hidden_states, attention_dists=attention_dists) + hidden_states
)
hidden_states = self.conv(hidden_states) + hidden_states
hidden_states = 0.5 * self.ff2(hidden_states) + hidden_states
hidden_states = self.post_norm(hidden_states)
@@ -463,29 +475,33 @@ class GraniteSpeechConformerBlock(nn.Module):
class GraniteSpeechCTCEncoder(nn.Module):
"""CTC Encoder comprising conformer blocks and additional linear layers."""
def __init__(self,
config: PretrainedConfig,
prefix: str,
quant_config: Optional[QuantizationConfig] = None):
def __init__(
self,
config: PretrainedConfig,
prefix: str,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
# Precompute clamped relative positional encoding distances
seq = torch.arange(config.context_size)
relpos_dist = seq.view(-1, 1) - seq.view(1, -1)
self.attention_dists = torch.clamp(
relpos_dist, -config.context_size,
config.context_size) + config.max_pos_emb
self.attention_dists = (
torch.clamp(relpos_dist, -config.context_size, config.context_size)
+ config.max_pos_emb
)
self.input_linear = nn.Linear(config.input_dim,
config.hidden_dim,
bias=True)
self.layers = nn.ModuleList([
GraniteSpeechConformerBlock(
config,
prefix=f"{prefix}.layers.{idx}",
) for idx in range(config.num_layers)
])
self.input_linear = nn.Linear(config.input_dim, config.hidden_dim, bias=True)
self.layers = nn.ModuleList(
[
GraniteSpeechConformerBlock(
config,
prefix=f"{prefix}.layers.{idx}",
)
for idx in range(config.num_layers)
]
)
self.out = ColumnParallelLinear(
input_size=config.hidden_dim,
@@ -508,8 +524,7 @@ class GraniteSpeechCTCEncoder(nn.Module):
def forward(self, hidden_states: torch.Tensor):
hidden_states = self.input_linear(hidden_states)
for idx, layer in enumerate(self.layers, start=1):
hidden_states = layer(hidden_states,
attention_dists=self.attention_dists)
hidden_states = layer(hidden_states, attention_dists=self.attention_dists)
if idx == self.num_layers // 2:
hidden_states_mid = hidden_states.clone()
@@ -523,12 +538,13 @@ class GraniteSpeechCTCEncoder(nn.Module):
@MULTIMODAL_REGISTRY.register_processor(
GraniteSpeechMultiModalProcessor,
info=GraniteSpeechMultiModalProcessingInfo,
dummy_inputs=GraniteSpeechDummyInputsBuilder)
dummy_inputs=GraniteSpeechDummyInputsBuilder,
)
class GraniteSpeechForConditionalGeneration(
nn.Module,
SupportsMultiModal,
SupportsPP,
SupportsLoRA,
nn.Module,
SupportsMultiModal,
SupportsPP,
SupportsLoRA,
):
merge_by_field_config = True
@@ -584,7 +600,8 @@ class GraniteSpeechForConditionalGeneration(
)
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
self.language_model.make_empty_intermediate_tensors
)
def _parse_and_validate_audio_input(
self,
@@ -602,17 +619,21 @@ class GraniteSpeechForConditionalGeneration(
# from the processor, but we handle rebuilding it here since
# vLLM generally processes everything independently + batches.
if input_features_mask is None:
input_features_mask = self._build_input_features_mask(
audio_embed_sizes)
input_features_mask = self._build_input_features_mask(audio_embed_sizes)
if not isinstance(input_features, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio input features. "
f"Got type: {type(input_features)}")
raise ValueError(
"Incorrect type of audio input features. "
f"Got type: {type(input_features)}"
)
if input_features_mask is not None and not isinstance(
input_features_mask, torch.Tensor):
raise ValueError("Incorrect type of audio input features mask. "
f"Got type: {type(input_features_mask)}")
input_features_mask, torch.Tensor
):
raise ValueError(
"Incorrect type of audio input features mask. "
f"Got type: {type(input_features_mask)}"
)
if isinstance(input_features, torch.Tensor):
# Granite speech currently only allows one audio token per instance
@@ -625,16 +646,17 @@ class GraniteSpeechForConditionalGeneration(
if len(input_features.shape) != 3:
raise ValueError(
"Squeezed input features should be 3D but are of shape "
f"{input_features.shape}")
input_features = input_features.to(
self.encoder.input_linear.weight.dtype)
f"{input_features.shape}"
)
input_features = input_features.to(self.encoder.input_linear.weight.dtype)
else:
# Otherwise we have a list of tensors, which are almost certainly
# differing in their respective numbers of audio features;
# stack them into a 3D tensor of size [bsz, most_num_features, 160].
input_features = self._pad_and_stack_input_features(
input_features, ).to(self.encoder.input_linear.weight.dtype)
input_features,
).to(self.encoder.input_linear.weight.dtype)
return GraniteSpeechAudioInputs(
input_features=input_features,
@@ -706,7 +728,7 @@ class GraniteSpeechForConditionalGeneration(
audio_input: GraniteSpeechAudioInputs,
) -> tuple[torch.Tensor]:
"""Compute the audio features to be merged into the LLM embeddings.
Args:
audio_input: GraniteSpeechAudioInputs
Audio inputs object containing Mel features, an input features
@@ -769,8 +791,9 @@ class GraniteSpeechForConditionalGeneration(
if intermediate_tensors is not None:
inputs_embeds = None
model_output = self.language_model(input_ids, positions,
intermediate_tensors, inputs_embeds)
model_output = self.language_model(
input_ids, positions, intermediate_tensors, inputs_embeds
)
return model_output
def compute_logits(