Add pt_load_map_location to allow loading to cuda (#16869)

Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
This commit is contained in:
Jerry Zhang
2025-05-01 23:23:42 -07:00
committed by GitHub
parent f192ca90e6
commit 109e15a335
6 changed files with 74 additions and 3 deletions

View File

@@ -5,7 +5,8 @@ from typing import Literal, Union
import pytest
from vllm.config import ModelConfig, PoolerConfig, config, get_field
from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig,
config, get_field)
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform
@@ -410,3 +411,16 @@ def test_generation_config_loading():
override_generation_config=override_generation_config)
assert model_config.get_diff_sampling_param() == override_generation_config
@pytest.mark.parametrize("pt_load_map_location", [
"cuda",
{
"": "cuda"
},
])
def test_load_config_pt_load_map_location(pt_load_map_location):
load_config = LoadConfig(pt_load_map_location=pt_load_map_location)
config = VllmConfig(load_config=load_config)
assert config.load_config.pt_load_map_location == pt_load_map_location