# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from vllm.model_executor.models.utils import (
|
|
AutoWeightsLoader,
|
|
_merge_multimodal_embeddings,
|
|
)
|
|
from vllm.platforms import current_platform
|
|
|
|
|
|
class ModuleWithBatchNorm(torch.nn.Module):
    """Minimal module wrapping a single BatchNorm1d layer.

    Used to verify that batchnorm running statistics (buffers, not
    parameters) survive a round trip through the auto weights loader.
    """

    def __init__(self):
        super().__init__()
        # Two features keep the running-stat tensors small and easy to compare.
        self.bn = torch.nn.BatchNorm1d(2)

    def forward(self, x):
        out = self.bn(x)
        return out
|
|
|
|
|
|
class ModuleWithNestedBatchNorm(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.nested_mod = ModuleWithBatchNorm()
|
|
|
|
def forward(self, x):
|
|
return self.nested_mod(x)
|
|
|
|
|
|
@pytest.mark.cpu_test
def test_module_with_batchnorm_can_load():
    """Ensure the auto weight loader can load batchnorm stats."""
    src = ModuleWithBatchNorm()

    # Push one batch through so the running stats move away from defaults.
    src(torch.Tensor([[1, 2], [3, 4]]))

    def weight_generator():
        # Stream the trained state dict as (name, tensor) pairs.
        yield from src.state_dict().items()

    dst = ModuleWithBatchNorm()

    # A fresh module must start with untouched batchnorm statistics.
    assert not torch.all(dst.bn.running_mean == src.bn.running_mean)
    assert not torch.all(dst.bn.running_var == src.bn.running_var)
    assert dst.bn.num_batches_tracked.item() == 0

    AutoWeightsLoader(dst).load_weights(weight_generator())

    # After loading, the buffers must match the trained module exactly.
    assert torch.all(dst.bn.running_mean == src.bn.running_mean)
    assert torch.all(dst.bn.running_var == src.bn.running_var)
    assert dst.bn.num_batches_tracked.item() == 1
|
|
|
|
|
|
@pytest.mark.cpu_test
def test_module_with_child_containing_batchnorm_can_autoload():
    """Ensure the auto weight loader can load nested modules batchnorm stats."""
    src = ModuleWithNestedBatchNorm()

    # Push one batch through so the running stats move away from defaults.
    src(torch.Tensor([[1, 2], [3, 4]]))

    def weight_generator():
        # Stream the trained state dict, including nested buffer names.
        yield from src.state_dict().items()

    dst = ModuleWithNestedBatchNorm()
    src_bn, dst_bn = src.nested_mod.bn, dst.nested_mod.bn

    # A fresh module must start with untouched batchnorm statistics.
    assert not torch.all(dst_bn.running_mean == src_bn.running_mean)
    assert not torch.all(dst_bn.running_var == src_bn.running_var)
    assert dst_bn.num_batches_tracked.item() == 0

    AutoWeightsLoader(dst).load_weights(weight_generator())

    # After loading, the nested buffers must match the trained module.
    assert torch.all(dst_bn.running_mean == src_bn.running_mean)
    assert torch.all(dst_bn.running_var == src_bn.running_var)
    assert dst_bn.num_batches_tracked.item() == 1
|
|
|
|
|
|
@pytest.mark.cpu_test
def test_module_skip_prefix():
    """Ensure the auto weight loader can skip prefix."""
    src = ModuleWithNestedBatchNorm()

    # Push one batch through so the running stats move away from defaults.
    src(torch.Tensor([[1, 2], [3, 4]]))

    def weight_generator():
        # Extra entries that the loader is expected to filter out by prefix.
        extra = {
            "prefix.bn.weight": torch.Tensor([1, 2]),
            "prefix.bn.bias": torch.Tensor([3, 4]),
        }
        yield from (src.state_dict() | extra).items()

    dst = ModuleWithNestedBatchNorm()
    src_bn, dst_bn = src.nested_mod.bn, dst.nested_mod.bn

    # A fresh module must start with untouched batchnorm statistics.
    assert not torch.all(dst_bn.running_mean == src_bn.running_mean)
    assert not torch.all(dst_bn.running_var == src_bn.running_var)
    assert dst_bn.num_batches_tracked.item() == 0

    AutoWeightsLoader(dst, skip_prefixes=["prefix."]).load_weights(weight_generator())

    # Loading succeeded despite the redundant prefixed entries.
    assert torch.all(dst_bn.running_mean == src_bn.running_mean)
    assert torch.all(dst_bn.running_var == src_bn.running_var)
    assert dst_bn.num_batches_tracked.item() == 1
|
|
|
|
|
|
@pytest.mark.cpu_test
def test_module_skip_substr():
    """Ensure the auto weight loader can skip weights by substring."""
    mod = ModuleWithNestedBatchNorm()
    # Run some data through the module with batchnorm
    mod(torch.Tensor([[1, 2], [3, 4]]))

    # Try to load the weights to a new instance
    def weight_generator():
        # weights needed to be filtered out: every name containing "substr."
        redundant_weights = {
            "nested_mod.0.substr.weight": torch.Tensor([1, 2]),
            "nested_mod.0.substr.bias": torch.Tensor([3, 4]),
            "nested_mod.substr.weight": torch.Tensor([1, 2]),
            "nested_mod.substr.bias": torch.Tensor([3, 4]),
        }
        yield from (mod.state_dict() | redundant_weights).items()

    new_mod = ModuleWithNestedBatchNorm()

    # A fresh module must start with untouched batchnorm statistics.
    assert not torch.all(
        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean
    )
    assert not torch.all(
        new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var
    )
    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 0

    loader = AutoWeightsLoader(new_mod, skip_substrs=["substr."])
    loader.load_weights(weight_generator())

    # Ensure the stats are updated and the redundant entries were ignored
    assert torch.all(
        new_mod.nested_mod.bn.running_mean == mod.nested_mod.bn.running_mean
    )
    assert torch.all(new_mod.nested_mod.bn.running_var == mod.nested_mod.bn.running_var)
    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
|
|
|
|
|
|
class raise_if_cuda_sync:
    """Context manager that turns any CUDA synchronization into an error.

    On entry the current CUDA sync debug mode is saved and switched to
    "error"; on exit the saved mode is restored, even if the body raised.
    """

    def __enter__(self):
        self._saved_mode = torch.cuda.get_sync_debug_mode()
        torch.cuda.set_sync_debug_mode("error")

    def __exit__(self, exception_type, exception_value, traceback):
        # Restore whatever mode was active before entry.
        torch.cuda.set_sync_debug_mode(self._saved_mode)
|
|
|
|
|
|
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
def test_merge_multimodal_embeddings_no_sync():
    """Merging GPU embeddings with a CPU mask must not trigger a CUDA sync."""
    device = "cuda:0"
    inputs_embeds = torch.zeros([5, 10], dtype=torch.bfloat16, device=device)
    multimodal_embeddings = [
        torch.ones([3, 10], dtype=torch.bfloat16, device=device)
    ]
    # The mask deliberately lives on CPU; merging must not sync to read it.
    is_multimodal = torch.tensor([True, False, True, True, False], device="cpu")
    with raise_if_cuda_sync():
        _merge_multimodal_embeddings(
            inputs_embeds, multimodal_embeddings, is_multimodal
        )