[Bugfix] Fix auto dtype casting for BatchFeature (#19316)
Signed-off-by: Isotr0py <2037008807@qq.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -12,6 +12,7 @@ from transformers import AutoTokenizer
|
||||
from vllm import SamplingParams
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import set_default_torch_num_threads
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.core import EngineCore
|
||||
from vllm.v1.executor.abstract import Executor, UniProcExecutor
|
||||
@@ -56,9 +57,10 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
|
||||
vllm_config = engine_args.create_engine_config()
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
|
||||
engine_core = EngineCore(vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=True)
|
||||
with set_default_torch_num_threads(1):
|
||||
engine_core = EngineCore(vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=True)
|
||||
"""Test basic request lifecycle."""
|
||||
|
||||
# First request.
|
||||
@@ -190,9 +192,10 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
|
||||
vllm_config = engine_args.create_engine_config()
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
|
||||
engine_core = EngineCore(vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=True)
|
||||
with set_default_torch_num_threads(1):
|
||||
engine_core = EngineCore(vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=True)
|
||||
"""Test basic request lifecycle."""
|
||||
# First request.
|
||||
request: EngineCoreRequest = make_request()
|
||||
@@ -286,9 +289,10 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
|
||||
enforce_eager=True,
|
||||
)
|
||||
vllm_config = engine_args.create_engine_config()
|
||||
engine_core = EngineCore(vllm_config=vllm_config,
|
||||
log_stats=False,
|
||||
executor_class=DummyExecutor)
|
||||
with set_default_torch_num_threads(1):
|
||||
engine_core = EngineCore(vllm_config=vllm_config,
|
||||
log_stats=False,
|
||||
executor_class=DummyExecutor)
|
||||
assert engine_core.batch_queue is not None
|
||||
|
||||
# Add two requests in a row. Each request have 12 prompt tokens.
|
||||
|
||||
Reference in New Issue
Block a user