[Core] Simplify multimodal masking (#34246)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
Author: Lukas Geiger
Date: 2026-04-01 09:18:22 +01:00
Committed by: GitHub
parent 36d7f19897
commit 4f6eed3bd4
9 changed files with 54 additions and 51 deletions
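The test below exercises the simplified path: _merge_multimodal_embeddings now takes the is_multimodal boolean mask directly and is expected to complete without any implicit CUDA synchronization. As a rough illustration only (not vLLM's actual implementation; the function name, signature, and masked_scatter_ strategy here are assumptions), a mask-based merge can avoid host round-trips like .nonzero() or .item():

# Hedged sketch of a sync-free multimodal merge; illustrative, not vLLM's API.
import torch

def merge_multimodal_embeddings_sketch(
    inputs_embeds: torch.Tensor,  # [num_tokens, hidden_size] text embeddings
    multimodal_embeddings: list[torch.Tensor],  # per-item [n_i, hidden_size]
    is_multimodal: torch.Tensor,  # [num_tokens] bool mask (may live on CPU)
) -> torch.Tensor:
    # Flatten all multimodal embeddings into one [sum(n_i), hidden_size] tensor.
    flattened = torch.cat(list(multimodal_embeddings))
    # Broadcast the mask over the hidden dimension; moving it to the device is
    # an ordinary host-to-device copy, not a synchronization point.
    mask = is_multimodal.to(inputs_embeds.device).unsqueeze(-1)
    # masked_scatter_ consumes the mask on-device and fills masked rows from
    # `flattened` in order, so no .nonzero()/.item() host round-trip is needed
    # to locate the multimodal token positions.
    return inputs_embeds.masked_scatter_(mask, flattened)

With the tensors from the test (5 tokens, 3 of them multimodal, hidden size 10), the mask selects 3 rows and flattened supplies exactly 3 rows, so the scatter is well-formed.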


@@ -4,9 +4,11 @@
 import pytest
 import torch
 
-from vllm.model_executor.models.utils import AutoWeightsLoader
-
-pytestmark = pytest.mark.cpu_test
+from vllm.model_executor.models.utils import (
+    AutoWeightsLoader,
+    _merge_multimodal_embeddings,
+)
+from vllm.platforms import current_platform
 
 
 class ModuleWithBatchNorm(torch.nn.Module):
@@ -27,6 +29,7 @@ class ModuleWithNestedBatchNorm(torch.nn.Module):
         return self.nested_mod(x)
 
 
+@pytest.mark.cpu_test
 def test_module_with_batchnorm_can_load():
     """Ensure the auto weight loader can load batchnorm stats."""
     mod = ModuleWithBatchNorm()
@@ -52,6 +55,7 @@ def test_module_with_batchnorm_can_load():
     assert new_mod.bn.num_batches_tracked.item() == 1
 
 
+@pytest.mark.cpu_test
 def test_module_with_child_containing_batchnorm_can_autoload():
     """Ensure the auto weight loader can load nested modules batchnorm stats."""
     mod = ModuleWithNestedBatchNorm()
@@ -83,6 +87,7 @@ def test_module_with_child_containing_batchnorm_can_autoload():
     assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
 
 
+@pytest.mark.cpu_test
 def test_module_skip_prefix():
     """Ensure the auto weight loader can skip prefix."""
     mod = ModuleWithNestedBatchNorm()
@@ -119,6 +124,7 @@ def test_module_skip_prefix():
     assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
 
 
+@pytest.mark.cpu_test
 def test_module_skip_substr():
     """Ensure the auto weight loader can skip prefix."""
     mod = ModuleWithNestedBatchNorm()
@@ -155,3 +161,23 @@ def test_module_skip_substr():
     )
     assert torch.all(new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
     assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
+
+
+class raise_if_cuda_sync:
+    def __enter__(self):
+        self.previous_debug_mode = torch.cuda.get_sync_debug_mode()
+        torch.cuda.set_sync_debug_mode("error")
+
+    def __exit__(self, exception_type, exception_value, traceback):
+        torch.cuda.set_sync_debug_mode(self.previous_debug_mode)
+
+
+@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
+def test_merge_multimodal_embeddings_no_sync():
+    inputs_embeds = torch.zeros([5, 10], dtype=torch.bfloat16, device="cuda:0")
+    multimodal_embeddings = [torch.ones([3, 10], dtype=torch.bfloat16, device="cuda:0")]
+    is_multimodal = torch.tensor([True, False, True, True, False], device="cpu")
+    with raise_if_cuda_sync():
+        _merge_multimodal_embeddings(
+            inputs_embeds, multimodal_embeddings, is_multimodal
+        )
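For context on the guard: torch.cuda.set_sync_debug_mode("error") makes PyTorch raise on operations that implicitly synchronize the host with the device, so the new test fails if the merge path reintroduces a blocking call. A minimal standalone illustration of that mechanism (hypothetical, not part of this commit):

# Hypothetical demo of the sync guard used above: with sync debug mode set to
# "error", an implicitly synchronizing op such as .item() on a CUDA tensor
# (a blocking device-to-host copy) raises instead of silently stalling.
import torch

if torch.cuda.is_available():
    torch.cuda.set_sync_debug_mode("error")
    try:
        torch.ones(1, device="cuda").item()  # blocking D2H copy -> RuntimeError
    except RuntimeError as exc:
        print(f"caught implicit sync: {exc}")
    finally:
        torch.cuda.set_sync_debug_mode("default")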