[Models] Add remaining model PP support (#7168)

Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
Signed-off-by: Murali Andoorveedu <muralidhar.andoorveedu@centml.ai>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Murali Andoorveedu
2024-10-03 19:56:58 -07:00
committed by GitHub
parent 303d44790a
commit 0f6d7a9a34
69 changed files with 2585 additions and 1344 deletions

View File

@@ -3,7 +3,7 @@
import math
from array import array
from functools import lru_cache
from functools import cached_property, lru_cache
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union, cast)
@@ -22,12 +22,10 @@ from vllm.inputs.data import LLMInputs
from vllm.inputs.registry import InputContext
from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.utils import (flatten_bn,
group_weights_with_prefix,
init_vllm_registered_model,
@@ -37,9 +35,12 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.base import MultiModalInputs, NestedTensors
from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens)
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
from .interfaces import SupportsMultiModal, SupportsPP
_AUDIO_PLACEHOLDER_TOKEN = 128002
_AUDIO_TOKENS_PER_SECOND = 6.25
@@ -323,7 +324,7 @@ class ModifiedWhisperEncoder(WhisperEncoder):
"audio", get_ultravox_max_audio_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox)
@INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox)
class UltravoxModel(nn.Module, SupportsMultiModal):
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self,
config: UltravoxConfig,
@@ -353,6 +354,16 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
revision=None,
prefix="language_model."))
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@cached_property
def sampler(self):
if hasattr(self.language_model, "sampler"):
return self.language_model.sampler
return Sampler()
def _audio_features_to_embeddings(
self, input_features: torch.Tensor) -> torch.Tensor:
audio_input = input_features.to(self.audio_tower.dtype)
@@ -425,7 +436,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[torch.Tensor],
**kwargs) -> SamplerOutput:
**kwargs) -> Union[torch.Tensor, IntermediateTensors]:
"""Run forward pass for Ultravox
One key thing to understand is the `input_ids` already accounts for the
@@ -438,18 +449,22 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
Args:
audio_features: A batch of audio inputs [B, N, 80, M].
"""
audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is not None:
audio_embeddings = self._process_audio_input(audio_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, audio_embeddings,
_AUDIO_PLACEHOLDER_TOKEN)
if intermediate_tensors is not None:
input_ids = None
else:
inputs_embeds = None
else:
audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is not None:
audio_embeddings = self._process_audio_input(audio_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, audio_embeddings,
_AUDIO_PLACEHOLDER_TOKEN)
input_ids = None
else:
inputs_embeds = None
hidden_states = self.language_model.model(
input_ids=input_ids,