Remove unused kwargs from model definitions (#13555)
This commit is contained in:
@@ -10,7 +10,7 @@ from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
|
||||
WhisperProcessor)
|
||||
from transformers.models.whisper.modeling_whisper import sinusoids
|
||||
|
||||
from vllm.attention import Attention, AttentionMetadata, AttentionType
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.logger import init_logger
|
||||
@@ -134,13 +134,11 @@ class WhisperAttention(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
):
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
|
||||
attn_output = self.attn(q, k, v)
|
||||
|
||||
output, _ = self.out_proj(attn_output)
|
||||
|
||||
@@ -196,8 +194,6 @@ class WhisperCrossAttention(WhisperAttention):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor],
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
):
|
||||
q, _ = self.q_proj(hidden_states)
|
||||
|
||||
@@ -209,13 +205,7 @@ class WhisperCrossAttention(WhisperAttention):
|
||||
else:
|
||||
k = v = None
|
||||
|
||||
attn_output = self.attn(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
kv_cache,
|
||||
attn_metadata,
|
||||
)
|
||||
attn_output = self.attn(q, k, v)
|
||||
|
||||
output, _ = self.out_proj(attn_output)
|
||||
|
||||
@@ -285,16 +275,10 @@ class WhisperEncoderLayer(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
):
|
||||
residual = hidden_states
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states = self.self_attn(hidden_states=hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
residual = hidden_states
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
@@ -348,14 +332,10 @@ class WhisperDecoderLayer(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor],
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
):
|
||||
residual = hidden_states
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
hidden_states = self.self_attn(hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata)
|
||||
hidden_states = self.self_attn(hidden_states=hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
@@ -363,8 +343,6 @@ class WhisperDecoderLayer(nn.Module):
|
||||
hidden_states = self.encoder_attn(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
@@ -411,12 +389,7 @@ class WhisperEncoder(nn.Module):
|
||||
self.embed_positions.weight.copy_(
|
||||
sinusoids(*self.embed_positions.weight.shape))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_features: Union[torch.Tensor, List[torch.Tensor]],
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
):
|
||||
def forward(self, input_features: Union[torch.Tensor, List[torch.Tensor]]):
|
||||
hidden_states = []
|
||||
for features in input_features:
|
||||
embeds = nn.functional.gelu(self.conv1(features))
|
||||
@@ -426,12 +399,8 @@ class WhisperEncoder(nn.Module):
|
||||
hidden_states.append(embeds)
|
||||
hidden_states = torch.cat(hidden_states)
|
||||
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
hidden_states = encoder_layer(
|
||||
hidden_states,
|
||||
kv_cache=kv_caches[idx],
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
for encoder_layer in self.layers:
|
||||
hidden_states = encoder_layer(hidden_states)
|
||||
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
return hidden_states
|
||||
@@ -466,19 +435,15 @@ class WhisperDecoder(nn.Module):
|
||||
input_ids,
|
||||
positions: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor],
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
):
|
||||
inputs_embeds = self.get_input_embeddings(input_ids)
|
||||
positions = self.embed_positions(positions)
|
||||
hidden_states = inputs_embeds + positions
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
for decoder_layer in self.layers:
|
||||
hidden_states = decoder_layer(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
kv_cache=kv_caches[idx],
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
@@ -505,36 +470,22 @@ class WhisperModel(nn.Module):
|
||||
input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]],
|
||||
input_ids: Optional[torch.Tensor],
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> torch.Tensor:
|
||||
encoder_outputs = self.get_encoder_outputs(
|
||||
input_features,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
encoder_outputs = self.get_encoder_outputs(input_features)
|
||||
decoder_outputs = self.decoder(
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
encoder_hidden_states=encoder_outputs,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
return decoder_outputs
|
||||
|
||||
def get_encoder_outputs(
|
||||
self,
|
||||
input_features: Optional[Union[torch.Tensor, List[torch.Tensor]]],
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
) -> Optional[torch.Tensor]:
|
||||
if input_features is None:
|
||||
return None
|
||||
return self.encoder(
|
||||
input_features,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
return self.encoder(input_features)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
@@ -733,8 +684,6 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
||||
@@ -742,31 +691,19 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
|
||||
input_features=audio_input["input_features"],
|
||||
input_ids=input_ids,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
return decoder_outputs
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs,
|
||||
) -> Optional[NestedTensors]:
|
||||
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
|
||||
# TODO: This method does not obey the interface for SupportsMultiModal.
|
||||
# Refactor this once encoder/decoder support is implemented in V1.
|
||||
audio_input = self._parse_and_validate_audio_input(**kwargs)
|
||||
return self.model.get_encoder_outputs(
|
||||
audio_input["input_features"],
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
return self.model.get_encoder_outputs(audio_input["input_features"])
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[NestedTensors] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None,
|
||||
) -> torch.Tensor:
|
||||
# TODO: This method just returns the decoder sequence embeddings since
|
||||
# Whisper does not have encoder text tokens. Refactor this once
|
||||
|
||||
Reference in New Issue
Block a user