[Core] Raise when non-multi-instance DP clients target a DP rank (#19227)
Signed-off-by: Jon Swenson <jmswen@gmail.com>
This commit is contained in:
@@ -250,3 +250,32 @@ async def test_customize_loggers(monkeypatch):
|
||||
assert len(engine.stat_loggers) == 1
|
||||
assert len(engine.stat_loggers[0]) == 1
|
||||
engine.stat_loggers[0][0].log.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="module")
|
||||
async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as m, ExitStack() as after:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
engine = AsyncLLM.from_engine_args(TEXT_ENGINE_ARGS)
|
||||
after.callback(engine.shutdown)
|
||||
|
||||
sampling_params = SamplingParams(max_tokens=100,
|
||||
output_kind=RequestOutputKind.DELTA,
|
||||
temperature=1.0,
|
||||
seed=33)
|
||||
|
||||
# Test with valid DP rank.
|
||||
async for _ in engine.generate(request_id="request-34",
|
||||
prompt=TEXT_PROMPT,
|
||||
sampling_params=sampling_params,
|
||||
data_parallel_rank=0):
|
||||
pass
|
||||
|
||||
# Test with out-of-range DP rank.
|
||||
with pytest.raises(ValueError):
|
||||
async for _ in engine.generate(request_id="request-35",
|
||||
prompt=TEXT_PROMPT,
|
||||
sampling_params=sampling_params,
|
||||
data_parallel_rank=1):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user