Improve configs - TokenizerPoolConfig + DeviceConfig (#16603)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1,14 +1,36 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from dataclasses import asdict
|
||||
from dataclasses import MISSING, Field, asdict, dataclass, field
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.config import ModelConfig, PoolerConfig
|
||||
from vllm.config import ModelConfig, PoolerConfig, get_field
|
||||
from vllm.model_executor.layers.pooler import PoolingType
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
def test_get_field():
|
||||
|
||||
@dataclass
|
||||
class TestConfig:
|
||||
a: int
|
||||
b: dict = field(default_factory=dict)
|
||||
c: str = "default"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
get_field(TestConfig, "a")
|
||||
|
||||
b = get_field(TestConfig, "b")
|
||||
assert isinstance(b, Field)
|
||||
assert b.default is MISSING
|
||||
assert b.default_factory is dict
|
||||
|
||||
c = get_field(TestConfig, "c")
|
||||
assert isinstance(c, Field)
|
||||
assert c.default == "default"
|
||||
assert c.default_factory is MISSING
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_id", "expected_runner_type", "expected_task"),
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user