[V1][Spec Decoding] Use model_loader.get_model() to load models (#18273)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
@@ -117,34 +117,13 @@ def test_prepare_inputs():
|
||||
])
|
||||
@mock.patch('vllm.v1.spec_decode.eagle.get_pp_group')
|
||||
@mock.patch('vllm.v1.spec_decode.eagle.get_layers_from_vllm_config')
|
||||
@mock.patch('vllm.v1.spec_decode.eagle.ModelRegistry')
|
||||
@mock.patch('vllm.v1.spec_decode.eagle.get_model_loader')
|
||||
@mock.patch('vllm.v1.spec_decode.eagle.set_default_torch_dtype')
|
||||
@mock.patch('vllm.v1.spec_decode.eagle.set_current_vllm_config')
|
||||
def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
|
||||
mock_registry, mock_get_layers, mock_get_pp_group, method,
|
||||
@mock.patch('vllm.v1.spec_decode.eagle.get_model')
|
||||
def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
|
||||
proposer_helper, draft_model_dir, target_attribute_path):
|
||||
|
||||
# Setup mock for model class
|
||||
mock_model_cls = mock.MagicMock()
|
||||
mock_registry.resolve_model_cls.return_value = (mock_model_cls,
|
||||
"test_arch")
|
||||
|
||||
# Create a real context manager for mocks
|
||||
class MockContextManager:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
return None
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
return False
|
||||
|
||||
# Make the mocks return actual context manager objects
|
||||
mock_set_dtype.return_value = MockContextManager()
|
||||
mock_set_config.return_value = MockContextManager()
|
||||
# Setup model mock
|
||||
mock_model = mock.MagicMock()
|
||||
mock_get_model.return_value = mock_model
|
||||
|
||||
# Setup mocks for attention layers
|
||||
target_attn_layers = {
|
||||
@@ -164,25 +143,6 @@ def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
|
||||
mock_pp_group.world_size = 2 if method == "eagle" else 1
|
||||
mock_get_pp_group.return_value = mock_pp_group
|
||||
|
||||
# Setup model loader mock
|
||||
mock_loader = mock.MagicMock()
|
||||
mock_get_loader.return_value = mock_loader
|
||||
|
||||
# Setup model mock
|
||||
mock_model = mock.MagicMock()
|
||||
mock_model_cls.return_value = mock_model
|
||||
mock_model.to.return_value = mock_model
|
||||
|
||||
# Configure mock to test the attribute sharing path
|
||||
if method == "eagle":
|
||||
# For eagle, test the lm_head path
|
||||
mock_model.load_weights.return_value = {
|
||||
"model.embed_tokens.weight": torch.zeros(1)
|
||||
}
|
||||
else:
|
||||
# For eagle3, test the embed_tokens path
|
||||
mock_model.load_weights.return_value = {}
|
||||
|
||||
# Setup target model with the appropriate attributes
|
||||
target_model = mock.MagicMock()
|
||||
|
||||
@@ -204,13 +164,7 @@ def test_load_model(mock_set_config, mock_set_dtype, mock_get_loader,
|
||||
proposer.load_model(target_model)
|
||||
|
||||
# Verify common interactions
|
||||
mock_get_loader.assert_called_once()
|
||||
mock_model_cls.assert_called_once()
|
||||
mock_model.to.assert_called_once()
|
||||
mock_model.load_weights.assert_called_once()
|
||||
|
||||
# Verify the loader was called with the right config
|
||||
mock_get_loader.assert_called_once_with(proposer.vllm_config.load_config)
|
||||
mock_get_model.assert_called_once()
|
||||
|
||||
# Verify the specific attribute sharing based on the method
|
||||
if method == "eagle":
|
||||
|
||||
Reference in New Issue
Block a user