[V1][Spec Decoding] Use model_loader.get_model() to load models (#18273)

Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
Mark McLoughlin
2025-05-23 03:05:44 +01:00
committed by GitHub
parent 04eb88dc80
commit c6b636f9fb
16 changed files with 59 additions and 135 deletions

View File

@@ -117,34 +117,13 @@ def test_prepare_inputs():
])
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
@mock.patch('vllm.v1.spec_decode.eagle.ModelRegistry')
@mock.patch('vllm.v1.spec_decode.eagle.get_model_loader')
@mock.patch('vllm.v1.spec_decode.eagle.set_default_torch_dtype')
@mock.patch('vllm.v1.spec_decode.eagle.set_current_vllm_config')
def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
mock_registry, mock_get_layers, mock_get_pp_group, method,
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
proposer_helper, draft_model_dir, target_attribute_path):
# Setup mock for model class
mock_model_cls = mock.MagicMock()
mock_registry.resolve_model_cls.return_value = (mock_model_cls,
"test_arch")
# Create a real context manager for mocks
class MockContextManager:
def __init__(self):
pass
def __enter__(self):
return None
def __exit__(self, exc_type, exc_val, exc_tb):
return False
# Make the mocks return actual context manager objects
mock_set_dtype.return_value = MockContextManager()
mock_set_config.return_value = MockContextManager()
# Setup model mock
mock_model = mock.MagicMock()
mock_get_model.return_value = mock_model
# Setup mocks for attention layers
target_attn_layers = {
@@ -164,25 +143,6 @@ def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
mock_pp_group.world_size = 2 if method == "eagle" else 1
mock_get_pp_group.return_value = mock_pp_group
# Setup model loader mock
mock_loader = mock.MagicMock()
mock_get_loader.return_value = mock_loader
# Setup model mock
mock_model = mock.MagicMock()
mock_model_cls.return_value = mock_model
mock_model.to.return_value = mock_model
# Configure mock to test the attribute sharing path
if method == "eagle":
# For eagle, test the lm_head path
mock_model.load_weights.return_value = {
"model.embed_tokens.weight": torch.zeros(1)
}
else:
# For eagle3, test the embed_tokens path
mock_model.load_weights.return_value = {}
# Setup target model with the appropriate attributes
target_model = mock.MagicMock()
@@ -204,13 +164,7 @@ def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
proposer.load_model(target_model)
# Verify common interactions
mock_get_loader.assert_called_once()
mock_model_cls.assert_called_once()
mock_model.to.assert_called_once()
mock_model.load_weights.assert_called_once()
# Verify the loader was called with the right config
mock_get_loader.assert_called_once_with(proposer.vllm_config.load_config)
mock_get_model.assert_called_once()
# Verify the specific attribute sharing based on the method
if method == "eagle":