[Core] Simplify multimodal masking (#34246)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
@@ -4,9 +4,11 @@
 import pytest
 import torch
 
-from vllm.model_executor.models.utils import AutoWeightsLoader
-
-pytestmark = pytest.mark.cpu_test
+from vllm.model_executor.models.utils import (
+    AutoWeightsLoader,
+    _merge_multimodal_embeddings,
+)
+from vllm.platforms import current_platform
 
 
 class ModuleWithBatchNorm(torch.nn.Module):
@@ -27,6 +29,7 @@ class ModuleWithNestedBatchNorm(torch.nn.Module):
         return self.nested_mod(x)
 
 
+@pytest.mark.cpu_test
 def test_module_with_batchnorm_can_load():
     """Ensure the auto weight loader can load batchnorm stats."""
     mod = ModuleWithBatchNorm()
@@ -52,6 +55,7 @@ def test_module_with_batchnorm_can_load():
     assert new_mod.bn.num_batches_tracked.item() == 1
 
 
+@pytest.mark.cpu_test
 def test_module_with_child_containing_batchnorm_can_autoload():
     """Ensure the auto weight loader can load nested modules batchnorm stats."""
     mod = ModuleWithNestedBatchNorm()
@@ -83,6 +87,7 @@ def test_module_with_child_containing_batchnorm_can_autoload():
     assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
 
 
+@pytest.mark.cpu_test
 def test_module_skip_prefix():
     """Ensure the auto weight loader can skip prefix."""
     mod = ModuleWithNestedBatchNorm()
@@ -119,6 +124,7 @@ def test_module_skip_prefix():
     assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
 
 
+@pytest.mark.cpu_test
 def test_module_skip_substr():
     """Ensure the auto weight loader can skip substrings."""
     mod = ModuleWithNestedBatchNorm()
@@ -155,3 +161,23 @@ def test_module_skip_substr():
     )
     assert torch.all(new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
     assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
+
+
+class raise_if_cuda_sync:
+    def __enter__(self):
+        self.previous_debug_mode = torch.cuda.get_sync_debug_mode()
+        torch.cuda.set_sync_debug_mode("error")
+
+    def __exit__(self, exception_type, exception_value, traceback):
+        torch.cuda.set_sync_debug_mode(self.previous_debug_mode)
+
+
+@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
+def test_merge_multimodal_embeddings_no_sync():
+    inputs_embeds = torch.zeros([5, 10], dtype=torch.bfloat16, device="cuda:0")
+    multimodal_embeddings = [torch.ones([3, 10], dtype=torch.bfloat16, device="cuda:0")]
+    is_multimodal = torch.tensor([True, False, True, True, False], device="cpu")
+    with raise_if_cuda_sync():
+        _merge_multimodal_embeddings(
+            inputs_embeds, multimodal_embeddings, is_multimodal
+        )
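
The new test relies on PyTorch's sync debug facility: `torch.cuda.set_sync_debug_mode("error")` makes any operation that implicitly synchronizes the host with the GPU raise an error, so the test fails if `_merge_multimodal_embeddings` ever blocks on the device. For illustration only, here is a minimal sketch of how such a sync-free merge can be written; the helper name and signature are hypothetical, not necessarily the actual vLLM implementation. The key point is that boolean-mask assignment (`inputs_embeds[mask] = ...`) routes through `nonzero()` and forces a host sync, while `masked_scatter_` fills the masked rows entirely on-device.

import torch

def merge_multimodal_embeddings_sketch(
    inputs_embeds: torch.Tensor,  # [num_tokens, hidden_size], on the GPU
    multimodal_embeddings: list[torch.Tensor],  # each [n_i, hidden_size], on the GPU
    is_multimodal: torch.Tensor,  # [num_tokens] bool mask, possibly on the CPU
) -> torch.Tensor:
    # Request a non-blocking host-to-device copy of the mask (truly
    # asynchronous only when the CPU tensor is in pinned memory).
    mask = is_multimodal.to(inputs_embeds.device, non_blocking=True)
    # Flatten the per-item embeddings into one [sum(n_i), hidden_size] tensor.
    flattened = torch.cat(multimodal_embeddings)
    # masked_scatter_ copies rows of `flattened`, in order, into the True
    # positions of the (broadcast) mask without ever calling nonzero(),
    # so the host never has to wait on the device.
    inputs_embeds.masked_scatter_(mask.unsqueeze(-1), flattened)
    return inputs_embeds

Run inside the `raise_if_cuda_sync()` context manager from the diff above, this version passes, whereas replacing the `masked_scatter_` call with `inputs_embeds[mask] = flattened` would trip the "error" debug mode.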