Add pt_load_map_location to allow loading to cuda (#16869)
Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
This commit is contained in:
@@ -5,7 +5,8 @@ from typing import Literal, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import ModelConfig, PoolerConfig, config, get_field
|
||||
from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig,
|
||||
config, get_field)
|
||||
from vllm.model_executor.layers.pooler import PoolingType
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -410,3 +411,16 @@ def test_generation_config_loading():
|
||||
override_generation_config=override_generation_config)
|
||||
|
||||
assert model_config.get_diff_sampling_param() == override_generation_config
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pt_load_map_location", [
|
||||
"cuda",
|
||||
{
|
||||
"": "cuda"
|
||||
},
|
||||
])
|
||||
def test_load_config_pt_load_map_location(pt_load_map_location):
|
||||
load_config = LoadConfig(pt_load_map_location=pt_load_map_location)
|
||||
config = VllmConfig(load_config=load_config)
|
||||
|
||||
assert config.load_config.pt_load_map_location == pt_load_map_location
|
||||
|
||||
Reference in New Issue
Block a user