[V1] Scatter and gather placeholders in the model runner (#16076)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Jennifer Zhao <ai.jenniferzhao@gmail.com>
This commit is contained in:
Roger Wang
2025-04-07 19:43:41 -07:00
committed by GitHub
parent 1d01211264
commit f2ebb6f541
41 changed files with 521 additions and 1020 deletions

View File

@@ -40,7 +40,8 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
DictEmbeddingItems, ModalityData,
ModalityDataItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import ProcessorInputs
from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6,
@@ -50,7 +51,6 @@ from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6,
_minicpmv_field_config)
from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
maybe_prefix)
from .vision import scatter_patch_features
CPU_DEVICE = torch.device("cpu")
@@ -73,14 +73,6 @@ class MiniCPMOAudioFeatureInputs(TypedDict):
which equals to `audio_features.shape[-1]`
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which audio embeddings correspond
to patch tokens.
Shape: `(batch_size * num_audios, num_embeds)`
"""
class MiniCPMOAudioEmbeddingInputs(TypedDict):
type: Literal["audio_embeds"]
@@ -93,14 +85,6 @@ class MiniCPMOAudioEmbeddingInputs(TypedDict):
Length of each slice may vary, so pass it as a list.
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which audio embeddings correspond
to patch tokens.
Shape: `(batch_size * num_audios, num_embeds)`
"""
MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
MiniCPMOAudioEmbeddingInputs]
@@ -115,7 +99,6 @@ def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
audio_features=MultiModalFieldConfig.batched("audio"),
audio_feature_lens=MultiModalFieldConfig.batched("audio"),
audio_embeds=MultiModalFieldConfig.batched("audio"),
audio_embed_is_patch=MultiModalFieldConfig.batched("audio"),
audio_token_id=MultiModalFieldConfig.shared("audio", num_audios),
)
@@ -197,8 +180,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
pool_step = self.get_default_audio_pool_step()
fbank_feat_in_chunk = 100
cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) // 2 + 1
num_audio_tokens = (cnn_feat_in_chunk - pool_step) // pool_step + 1
return num_audio_tokens + 2 # <audio>(<unk>*N)</audio>
return (cnn_feat_in_chunk - pool_step) // pool_step + 1
def get_max_audio_chunks_with_most_features(self) -> int:
return 30
@@ -209,8 +191,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
def get_audio_len_by_num_chunks(self, num_chunks: int) -> int:
sampling_rate = self.get_default_audio_sampling_rate()
# exclude <audio> </audio>
num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk() - 2
num_tokens_per_chunk = self.get_max_audio_tokens_per_chunk()
return int(num_chunks * sampling_rate / num_tokens_per_chunk) + 1
def get_num_frames_with_most_features(
@@ -295,13 +276,6 @@ class MiniCPMOMultiModalProcessor(
if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems):
audio_inputs = {}
audio_lens = [
self.info.get_audio_len_by_num_chunks(
sum(map(len,
parsed_audios.get(i)["audio_embeds"])))
for i in range(len(parsed_audios))
]
else:
audio_inputs = self._base_call_hf_processor(
prompts=[self.info.audio_pattern] * len(parsed_audios),
@@ -323,27 +297,7 @@ class MiniCPMOMultiModalProcessor(
]
audio_inputs["audio_features"] = unpadded_audio_features
audio_lens = [
parsed_audios.get_audio_length(i)
for i in range(len(parsed_audios))
]
audio_repl_features = [
self.get_audio_prompt_texts(audio_len) for audio_len in audio_lens
]
tokenizer = self.info.get_tokenizer()
audio_repls_feature_tokens = [
tokenizer.encode(audio_repl, add_special_tokens=False)
for audio_repl in audio_repl_features
]
embed_is_patch = [
self.get_embed_is_patch(audio_repl_tokens)
for audio_repl_tokens in audio_repls_feature_tokens
]
audio_inputs["audio_embed_is_patch"] = embed_is_patch
unk_token_id = tokenizer.get_vocab()["<unk>"]
audio_inputs["audio_token_id"] = torch.tensor(unk_token_id)
@@ -384,7 +338,10 @@ class MiniCPMOMultiModalProcessor(
else:
audio_len = audios.get_audio_length(item_idx)
return self.get_audio_prompt_texts(audio_len)
return PromptUpdateDetails.select_text(
self.get_audio_prompt_texts(audio_len),
"<unk>",
)
return [
*base_updates,
@@ -713,13 +670,6 @@ class MiniCPMO(MiniCPMV2_6):
assert isinstance(audio_token_id, torch.Tensor)
self.mm_token_ids.add(audio_token_id.flatten().unique().item())
audio_embed_is_patch = kwargs.pop("audio_embed_is_patch")
if not isinstance(audio_embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_embed_is_patch. "
f"Got type: {type(audio_embed_is_patch)}")
audio_embed_is_patch = flatten_bn(audio_embed_is_patch)
if audio_embeds is not None:
if not isinstance(audio_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio_embeds. "
@@ -730,7 +680,6 @@ class MiniCPMO(MiniCPMV2_6):
return MiniCPMOAudioEmbeddingInputs(
type="audio_embeds",
audio_embeds=audio_embeds_flat,
embed_is_patch=audio_embed_is_patch,
)
if not isinstance(audio_features, (torch.Tensor, list)):
@@ -749,7 +698,6 @@ class MiniCPMO(MiniCPMV2_6):
type="audio_features",
audio_features=audio_features_flat,
audio_feature_lens=audio_feature_lens_flat,
embed_is_patch=audio_embed_is_patch,
)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
@@ -781,10 +729,6 @@ class MiniCPMO(MiniCPMV2_6):
if modality == "audios":
audio_input = modalities["audios"]
audio_features = self._process_audio_input(audio_input)
multimodal_embeddings += tuple(
scatter_patch_features(
audio_features,
audio_input["embed_is_patch"],
))
multimodal_embeddings += tuple(audio_features)
return multimodal_embeddings