Remove unused kwargs from model definitions (#13555)

Harry Mellor
2025-02-25 01:13:52 +00:00
committed by GitHub
parent f61528d46d
commit cdc1fa12eb
104 changed files with 436 additions and 1654 deletions
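The change is mechanical but wide-reaching: the attention layer is now called as self.attn(q, k, v), with the KV cache and attention metadata resolved inside the attention layer itself rather than threaded as arguments through every model forward(), so model code that only passed kv_caches/attn_metadata down to its attention layers can drop those parameters entirely. A minimal, self-contained sketch of the pattern, using made-up stand-ins (ToyAttention, ToyEncoderLayer, _forward_context) rather than the real vLLM classes:

import torch
import torch.nn as nn

# Per-step state that the attention layer reads for itself instead of
# receiving as forward() arguments. In vLLM this role is played by the
# engine's forward context; here it is just a module-level placeholder.
_forward_context = {"attn_metadata": None, "kv_cache": None}


class ToyAttention(nn.Module):
    """Stand-in for an attention layer that no longer takes kv_cache or
    attn_metadata as call-time arguments."""

    def forward(self, q: torch.Tensor, k: torch.Tensor,
                v: torch.Tensor) -> torch.Tensor:
        # A real backend would consult _forward_context["attn_metadata"] and
        # its registered KV cache here; this toy just runs plain SDPA.
        return nn.functional.scaled_dot_product_attention(q, k, v)


class ToyEncoderLayer(nn.Module):
    def __init__(self, hidden_size: int = 64):
        super().__init__()
        self.self_attn_layer_norm = nn.LayerNorm(hidden_size)
        self.self_attn = ToyAttention()

    # After the cleanup, the layer signature carries only what it uses.
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(hidden_states, hidden_states,
                                       hidden_states)
        return residual + hidden_states


if __name__ == "__main__":
    layer = ToyEncoderLayer()
    print(layer(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])

Because no per-layer cache index is needed anymore, layer loops can also drop enumerate(), e.g. `for encoder_layer in self.layers:`, which is exactly what the Whisper hunks below do.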


@@ -10,7 +10,7 @@ from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
                           WhisperProcessor)
 from transformers.models.whisper.modeling_whisper import sinusoids
 
-from vllm.attention import Attention, AttentionMetadata, AttentionType
+from vllm.attention import Attention, AttentionType
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.logger import init_logger
@@ -134,13 +134,11 @@ class WhisperAttention(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        attn_output = self.attn(q, k, v)
         output, _ = self.out_proj(attn_output)
@@ -196,8 +194,6 @@ class WhisperCrossAttention(WhisperAttention):
         self,
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor],
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ):
         q, _ = self.q_proj(hidden_states)
@@ -209,13 +205,7 @@ class WhisperCrossAttention(WhisperAttention):
         else:
             k = v = None
-        attn_output = self.attn(
-            q,
-            k,
-            v,
-            kv_cache,
-            attn_metadata,
-        )
+        attn_output = self.attn(q, k, v)
         output, _ = self.out_proj(attn_output)
@@ -285,16 +275,10 @@ class WhisperEncoderLayer(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ):
         residual = hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)
-        hidden_states = self.self_attn(
-            hidden_states=hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
-        )
+        hidden_states = self.self_attn(hidden_states=hidden_states)
         hidden_states = residual + hidden_states
         residual = hidden_states
         hidden_states = self.final_layer_norm(hidden_states)
@@ -348,14 +332,10 @@ class WhisperDecoderLayer(nn.Module):
         self,
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor],
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
     ):
         residual = hidden_states
         hidden_states = self.self_attn_layer_norm(hidden_states)
-        hidden_states = self.self_attn(hidden_states=hidden_states,
-                                       kv_cache=kv_cache,
-                                       attn_metadata=attn_metadata)
+        hidden_states = self.self_attn(hidden_states=hidden_states)
         hidden_states = residual + hidden_states
         residual = hidden_states
@@ -363,8 +343,6 @@ class WhisperDecoderLayer(nn.Module):
         hidden_states = self.encoder_attn(
             hidden_states=hidden_states,
             encoder_hidden_states=encoder_hidden_states,
-            kv_cache=kv_cache,
-            attn_metadata=attn_metadata,
         )
         hidden_states = residual + hidden_states
@@ -411,12 +389,7 @@ class WhisperEncoder(nn.Module):
             self.embed_positions.weight.copy_(
                 sinusoids(*self.embed_positions.weight.shape))
 
-    def forward(
-        self,
-        input_features: Union[torch.Tensor, List[torch.Tensor]],
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
-    ):
+    def forward(self, input_features: Union[torch.Tensor, List[torch.Tensor]]):
         hidden_states = []
         for features in input_features:
             embeds = nn.functional.gelu(self.conv1(features))
@@ -426,12 +399,8 @@ class WhisperEncoder(nn.Module):
             hidden_states.append(embeds)
         hidden_states = torch.cat(hidden_states)
 
-        for idx, encoder_layer in enumerate(self.layers):
-            hidden_states = encoder_layer(
-                hidden_states,
-                kv_cache=kv_caches[idx],
-                attn_metadata=attn_metadata,
-            )
+        for encoder_layer in self.layers:
+            hidden_states = encoder_layer(hidden_states)
 
         hidden_states = self.layer_norm(hidden_states)
         return hidden_states
@@ -466,19 +435,15 @@ class WhisperDecoder(nn.Module):
         input_ids,
         positions: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor],
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
     ):
         inputs_embeds = self.get_input_embeddings(input_ids)
         positions = self.embed_positions(positions)
         hidden_states = inputs_embeds + positions
-        for idx, decoder_layer in enumerate(self.layers):
+        for decoder_layer in self.layers:
             hidden_states = decoder_layer(
                 hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
-                kv_cache=kv_caches[idx],
-                attn_metadata=attn_metadata,
             )
         hidden_states = self.layer_norm(hidden_states)
@@ -505,36 +470,22 @@ class WhisperModel(nn.Module):
         input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]],
         input_ids: Optional[torch.Tensor],
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        encoder_outputs = self.get_encoder_outputs(
-            input_features,
-            kv_caches=kv_caches,
-            attn_metadata=attn_metadata,
-        )
+        encoder_outputs = self.get_encoder_outputs(input_features)
         decoder_outputs = self.decoder(
             input_ids=input_ids,
             positions=positions,
             encoder_hidden_states=encoder_outputs,
-            kv_caches=kv_caches,
-            attn_metadata=attn_metadata,
         )
         return decoder_outputs
 
     def get_encoder_outputs(
         self,
         input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]],
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
     ) -> Optional[torch.Tensor]:
         if input_features is None:
             return None
-        return self.encoder(
-            input_features,
-            kv_caches=kv_caches,
-            attn_metadata=attn_metadata,
-        )
+        return self.encoder(input_features)
 
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
@@ -733,8 +684,6 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         **kwargs,
     ) -> torch.Tensor:
         audio_input = self._parse_and_validate_audio_input(**kwargs)
@@ -742,31 +691,19 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
             input_features=audio_input["input_features"],
             input_ids=input_ids,
             positions=positions,
-            kv_caches=kv_caches,
-            attn_metadata=attn_metadata,
         )
         return decoder_outputs
 
-    def get_multimodal_embeddings(
-        self,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
-        **kwargs,
-    ) -> Optional[NestedTensors]:
+    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
         # TODO: This method does not obey the interface for SupportsMultiModal.
         # Refactor this once encoder/decoder support is implemented in V1.
         audio_input = self._parse_and_validate_audio_input(**kwargs)
-        return self.model.get_encoder_outputs(
-            audio_input["input_features"],
-            kv_caches=kv_caches,
-            attn_metadata=attn_metadata,
-        )
+        return self.model.get_encoder_outputs(audio_input["input_features"])
 
     def get_input_embeddings(
         self,
         input_ids: torch.Tensor,
         multimodal_embeddings: Optional[NestedTensors] = None,
-        attn_metadata: Optional[AttentionMetadata] = None,
     ) -> torch.Tensor:
         # TODO: This method just returns the decoder sequence embeddings since
         # Whisper does not have encoder text tokens. Refactor this once