[Bugfix] Fix auto dtype casting for BatchFeature (#19316)
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -15,6 +15,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.utils import set_default_torch_num_threads
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
from vllm.v1.metrics.loggers import LoggingStatLogger
|
||||
|
||||
@@ -107,7 +108,8 @@ async def test_load(
|
||||
with monkeypatch.context() as m, ExitStack() as after:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
engine = AsyncLLM.from_engine_args(engine_args)
|
||||
with set_default_torch_num_threads(1):
|
||||
engine = AsyncLLM.from_engine_args(engine_args)
|
||||
after.callback(engine.shutdown)
|
||||
|
||||
NUM_REQUESTS = 100
|
||||
@@ -154,7 +156,8 @@ async def test_abort(
|
||||
with monkeypatch.context() as m, ExitStack() as after:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
engine = AsyncLLM.from_engine_args(engine_args)
|
||||
with set_default_torch_num_threads(1):
|
||||
engine = AsyncLLM.from_engine_args(engine_args)
|
||||
after.callback(engine.shutdown)
|
||||
|
||||
NUM_REQUESTS = 100
|
||||
@@ -226,7 +229,8 @@ async def test_finished_flag(
|
||||
with monkeypatch.context() as m, ExitStack() as after:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
engine = AsyncLLM.from_engine_args(engine_args)
|
||||
with set_default_torch_num_threads(1):
|
||||
engine = AsyncLLM.from_engine_args(engine_args)
|
||||
after.callback(engine.shutdown)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
@@ -260,7 +264,8 @@ async def test_mid_stream_cancellation(monkeypatch: pytest.MonkeyPatch,
|
||||
with monkeypatch.context() as m, ExitStack() as after:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
engine = AsyncLLM.from_engine_args(engine_args)
|
||||
with set_default_torch_num_threads(1):
|
||||
engine = AsyncLLM.from_engine_args(engine_args)
|
||||
after.callback(engine.shutdown)
|
||||
|
||||
NUM_REQUESTS = 100
|
||||
@@ -322,10 +327,11 @@ async def test_customize_loggers(monkeypatch):
|
||||
with monkeypatch.context() as m, ExitStack() as after:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
engine = AsyncLLM.from_engine_args(
|
||||
TEXT_ENGINE_ARGS,
|
||||
stat_loggers=[MockLoggingStatLogger],
|
||||
)
|
||||
with set_default_torch_num_threads(1):
|
||||
engine = AsyncLLM.from_engine_args(
|
||||
TEXT_ENGINE_ARGS,
|
||||
stat_loggers=[MockLoggingStatLogger],
|
||||
)
|
||||
after.callback(engine.shutdown)
|
||||
|
||||
await engine.do_log_stats()
|
||||
@@ -340,7 +346,8 @@ async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as m, ExitStack() as after:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
|
||||
with set_default_torch_num_threads(1):
|
||||
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
|
||||
after.callback(engine.shutdown)
|
||||
|
||||
sampling_params = SamplingParams(max_tokens=100,
|
||||
|
||||
@@ -12,6 +12,7 @@ from transformers import AutoTokenizer
|
||||
from vllm import SamplingParams
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import set_default_torch_num_threads
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.core import EngineCore
|
||||
from vllm.v1.executor.abstract import Executor, UniProcExecutor
|
||||
@@ -56,9 +57,10 @@ def test_engine_core(monkeypatch: pytest.MonkeyPatch):
|
||||
vllm_config = engine_args.create_engine_config()
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
|
||||
engine_core = EngineCore(vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=True)
|
||||
with set_default_torch_num_threads(1):
|
||||
engine_core = EngineCore(vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=True)
|
||||
"""Test basic request lifecycle."""
|
||||
|
||||
# First request.
|
||||
@@ -190,9 +192,10 @@ def test_engine_core_advanced_sampling(monkeypatch: pytest.MonkeyPatch):
|
||||
vllm_config = engine_args.create_engine_config()
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
|
||||
engine_core = EngineCore(vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=True)
|
||||
with set_default_torch_num_threads(1):
|
||||
engine_core = EngineCore(vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=True)
|
||||
"""Test basic request lifecycle."""
|
||||
# First request.
|
||||
request: EngineCoreRequest = make_request()
|
||||
@@ -286,9 +289,10 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
|
||||
enforce_eager=True,
|
||||
)
|
||||
vllm_config = engine_args.create_engine_config()
|
||||
engine_core = EngineCore(vllm_config=vllm_config,
|
||||
log_stats=False,
|
||||
executor_class=DummyExecutor)
|
||||
with set_default_torch_num_threads(1):
|
||||
engine_core = EngineCore(vllm_config=vllm_config,
|
||||
log_stats=False,
|
||||
executor_class=DummyExecutor)
|
||||
assert engine_core.batch_queue is not None
|
||||
|
||||
# Add two requests in a row. Each request has 12 prompt tokens.
|
||||
|
||||
@@ -19,6 +19,7 @@ from vllm.distributed.kv_events import (BlockStored, KVEventBatch,
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import set_default_torch_num_threads
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.core import EngineCore
|
||||
from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
|
||||
@@ -138,13 +139,15 @@ def test_engine_core_client(monkeypatch: pytest.MonkeyPatch,
|
||||
vllm_config = engine_args.create_engine_config(
|
||||
UsageContext.UNKNOWN_CONTEXT)
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
client = EngineCoreClient.make_client(
|
||||
multiprocess_mode=multiprocessing_mode,
|
||||
asyncio_mode=False,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=False,
|
||||
)
|
||||
|
||||
with set_default_torch_num_threads(1):
|
||||
client = EngineCoreClient.make_client(
|
||||
multiprocess_mode=multiprocessing_mode,
|
||||
asyncio_mode=False,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=False,
|
||||
)
|
||||
|
||||
MAX_TOKENS = 20
|
||||
params = SamplingParams(max_tokens=MAX_TOKENS)
|
||||
@@ -223,13 +226,15 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
|
||||
vllm_config = engine_args.create_engine_config(
|
||||
usage_context=UsageContext.UNKNOWN_CONTEXT)
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
client = EngineCoreClient.make_client(
|
||||
multiprocess_mode=True,
|
||||
asyncio_mode=True,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=True,
|
||||
)
|
||||
|
||||
with set_default_torch_num_threads(1):
|
||||
client = EngineCoreClient.make_client(
|
||||
multiprocess_mode=True,
|
||||
asyncio_mode=True,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=True,
|
||||
)
|
||||
|
||||
try:
|
||||
MAX_TOKENS = 20
|
||||
@@ -312,13 +317,14 @@ def test_kv_cache_events(
|
||||
UsageContext.UNKNOWN_CONTEXT)
|
||||
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
client = EngineCoreClient.make_client(
|
||||
multiprocess_mode=multiprocessing_mode,
|
||||
asyncio_mode=False,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=False,
|
||||
)
|
||||
with set_default_torch_num_threads(1):
|
||||
client = EngineCoreClient.make_client(
|
||||
multiprocess_mode=multiprocessing_mode,
|
||||
asyncio_mode=False,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=False,
|
||||
)
|
||||
endpoint = publisher_config.endpoint.replace("*", "127.0.0.1")
|
||||
subscriber = MockSubscriber(endpoint,
|
||||
topic=publisher_config.topic,
|
||||
@@ -394,13 +400,14 @@ async def test_kv_cache_events_dp(
|
||||
UsageContext.UNKNOWN_CONTEXT)
|
||||
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
client = EngineCoreClient.make_client(
|
||||
multiprocess_mode=multiprocessing_mode,
|
||||
asyncio_mode=True,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=False,
|
||||
)
|
||||
with set_default_torch_num_threads(1):
|
||||
client = EngineCoreClient.make_client(
|
||||
multiprocess_mode=multiprocessing_mode,
|
||||
asyncio_mode=True,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=False,
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Build endpoints for all DP ranks
|
||||
|
||||
Reference in New Issue
Block a user