475 lines
14 KiB
Python
475 lines
14 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import json
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from vllm.config.multimodal import MultiModalConfig
|
|
from vllm.entrypoints.openai.engine.protocol import StreamOptions
|
|
from vllm.entrypoints.openai.models.protocol import BaseModelPath
|
|
from vllm.entrypoints.openai.models.serving import OpenAIServingModels
|
|
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest
|
|
from vllm.entrypoints.serve.disagg.serving import ServingTokens
|
|
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
|
|
from vllm.logprobs import Logprob
|
|
from vllm.outputs import CompletionOutput, RequestOutput
|
|
from vllm.renderers import renderer_from_config
|
|
from vllm.sampling_params import SamplingParams
|
|
from vllm.v1.engine.async_llm import AsyncLLM
|
|
|
|
# Model identifier used by every request in this module; the tokenizer for
# this model is resolved by the renderer built in _build_renderer().
MODEL_NAME = "openai-community/gpt2"
# Single served model registered with OpenAIServingModels in the tests below.
BASE_MODEL_PATHS = [
    BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME),
]
|
|
|
|
|
|
@dataclass
class MockHFConfig:
    """Minimal stand-in for a HuggingFace model config object."""

    # Generic placeholder; MockModelConfig exposes this via hf_config /
    # hf_text_config.
    model_type: str = "any"
|
|
|
|
|
|
@dataclass
class MockModelConfig:
    """Dataclass mimicking the parts of vLLM's ModelConfig these tests touch.

    NOTE(review): attributes declared *without* a type annotation (``task``,
    ``model``, ``hf_config``, ...) are plain class attributes shared by all
    instances and are NOT dataclass fields; only the annotated names below
    participate in the generated ``__init__``. Do not add annotations here
    without checking callers — it would change construction behavior.
    """

    task = "generate"
    runner_type = "generate"
    model = MODEL_NAME
    tokenizer = MODEL_NAME
    trust_remote_code = False
    tokenizer_mode = "auto"
    max_model_len = 100
    tokenizer_revision = None
    multimodal_config = MultiModalConfig()
    hf_config = MockHFConfig()
    hf_text_config = MockHFConfig()
    logits_processors: list[str] | None = None
    diff_sampling_param: dict | None = None
    allowed_local_media_path: str = ""
    allowed_media_domains: list[str] | None = None
    encoder_config = None
    generation_config: str = "auto"
    media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
    skip_tokenizer_init = False
    is_encoder_decoder: bool = False
    is_multimodal_model: bool = False
    renderer_num_workers: int = 1

    def get_diff_sampling_param(self):
        """Return per-model sampling-param overrides, or {} when unset."""
        return self.diff_sampling_param or {}
|
|
|
|
|
|
@dataclass
class MockParallelConfig:
    """Minimal stand-in for vLLM's ParallelConfig."""

    # Rank consumed by renderer_from_config via MockVllmConfig; 0 is
    # presumably the API-serving process — confirm against vLLM's
    # ParallelConfig if this test ever exercises other ranks.
    _api_process_rank: int = 0
|
|
|
|
|
|
@dataclass
class MockVllmConfig:
    """Pairs a model config with a parallel config, like vLLM's VllmConfig."""

    model_config: MockModelConfig
    parallel_config: MockParallelConfig
|
|
|
|
|
|
def _build_renderer(model_config: MockModelConfig):
    """Build a renderer from *model_config* wrapped in a mock vLLM config."""
    vllm_config = MockVllmConfig(model_config, parallel_config=MockParallelConfig())
    return renderer_from_config(vllm_config)
|
|
|
|
|
|
def _build_serving_tokens(engine: AsyncLLM, **kwargs) -> ServingTokens:
    """Wire a ServingTokens instance around *engine* with mocked rendering.

    Extra keyword arguments are forwarded verbatim to the ServingTokens
    constructor (e.g. ``enable_prompt_tokens_details=True``).
    """
    model_registry = OpenAIServingModels(
        engine_client=engine,
        base_model_paths=BASE_MODEL_PATHS,
    )
    render = OpenAIServingRender(
        model_config=engine.model_config,
        renderer=engine.renderer,
        io_processor=engine.io_processor,
        model_registry=model_registry.registry,
        request_logger=None,
        chat_template=None,
        chat_template_content_format="auto",
    )
    serving = ServingTokens(
        engine,
        model_registry,
        openai_serving_render=render,
        request_logger=None,
        **kwargs,
    )

    async def _fake_preprocess(*args, **kwargs):
        # Tokenization is stubbed out: every prompt becomes [1, 2, 3].
        return [{"prompt_token_ids": [1, 2, 3]}]

    serving.openai_serving_render.preprocess_completion = AsyncMock(
        side_effect=_fake_preprocess
    )
    return serving
|
|
|
|
|
|
def _make_request_output(
    request_id: str,
    token_ids: list[int],
    finish_reason: str | None = None,
    finished: bool = False,
    prompt_token_ids: list[int] | None = None,
    logprobs: list[dict[int, Any] | None] | None = None,
    num_cached_tokens: int | None = None,
    index: int = 0,
) -> RequestOutput:
    """Build a RequestOutput carrying exactly one CompletionOutput.

    ``prompt_token_ids`` falls back to ``[1, 2, 3]`` when None or empty,
    matching the stubbed preprocessing in _build_serving_tokens.
    """
    completion = CompletionOutput(
        index=index,
        text="",
        token_ids=token_ids,
        cumulative_logprob=None,
        logprobs=logprobs,
        finish_reason=finish_reason,
    )
    return RequestOutput(
        request_id=request_id,
        prompt=None,
        prompt_token_ids=prompt_token_ids or [1, 2, 3],
        prompt_logprobs=None,
        outputs=[completion],
        finished=finished,
        metrics=None,
        lora_request=None,
        encoder_prompt=None,
        encoder_prompt_token_ids=None,
        num_cached_tokens=num_cached_tokens,
    )
|
|
|
|
|
|
def _mock_engine() -> MagicMock:
    """Create a MagicMock AsyncLLM carrying the attributes ServingTokens reads."""
    fake = MagicMock(spec=AsyncLLM)
    fake.errored = False
    fake.model_config = MockModelConfig()
    fake.input_processor = MagicMock()
    fake.io_processor = MagicMock()
    # A real renderer (over the mock config) so token rendering works.
    fake.renderer = _build_renderer(fake.model_config)
    return fake
|
|
|
|
|
|
def _parse_sse_chunks(chunks: list[str]) -> list[Any]:
|
|
"""Parse SSE chunks into dicts (JSON) or raw strings ([DONE])."""
|
|
parsed: list[Any] = []
|
|
for chunk in chunks:
|
|
assert chunk.startswith("data: ") and chunk.endswith("\n\n")
|
|
payload = chunk[len("data: ") : -len("\n\n")]
|
|
if payload == "[DONE]":
|
|
parsed.append("[DONE]")
|
|
else:
|
|
parsed.append(json.loads(payload))
|
|
return parsed
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_stream_basic():
    """Streaming returns SSE chunks with correct token_ids and ends with [DONE]."""
    engine = _mock_engine()

    async def fake_stream(*args, **kwargs):
        yield _make_request_output("req-1", token_ids=[10])
        yield _make_request_output("req-1", token_ids=[20, 30])
        yield _make_request_output(
            "req-1", token_ids=[40], finish_reason="stop", finished=True
        )

    engine.generate = MagicMock(side_effect=fake_stream)
    serving = _build_serving_tokens(engine)

    request = GenerateRequest(
        token_ids=[1, 2, 3],
        sampling_params=SamplingParams(max_tokens=10),
        model=MODEL_NAME,
        stream=True,
    )

    response = await serving.serve_tokens(request)
    parsed = _parse_sse_chunks([chunk async for chunk in response])

    # 3 data chunks + [DONE]
    assert parsed[-1] == "[DONE]"
    data_chunks = [item for item in parsed if item != "[DONE]"]
    assert len(data_chunks) == 3

    expected = [[10], [20, 30], [40]]
    for chunk, tokens in zip(data_chunks, expected):
        assert chunk["choices"][0]["token_ids"] == tokens
    assert data_chunks[2]["choices"][0]["finish_reason"] == "stop"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_stream_error_mid_generation():
    """finish_reason='error' mid-stream yields error chunk then [DONE]."""
    engine = _mock_engine()

    async def fake_stream(*args, **kwargs):
        yield _make_request_output("req-1", token_ids=[10])
        yield _make_request_output(
            "req-1", token_ids=[20], finish_reason="error", finished=True
        )

    engine.generate = MagicMock(side_effect=fake_stream)
    serving = _build_serving_tokens(engine)

    request = GenerateRequest(
        token_ids=[1, 2, 3],
        sampling_params=SamplingParams(max_tokens=10),
        model=MODEL_NAME,
        stream=True,
    )

    response = await serving.serve_tokens(request)
    chunks = [chunk async for chunk in response]

    # At least the pre-error data chunk plus the [DONE] sentinel.
    assert len(chunks) >= 2
    has_error = any("Internal server error" in chunk for chunk in chunks)
    assert has_error, f"Expected error message in chunks: {chunks}"
    assert chunks[-1] == "data: [DONE]\n\n"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_stream_error_with_empty_delta():
    """finish_reason='error' with empty delta_token_ids still raises."""
    engine = _mock_engine()

    async def fake_stream(*args, **kwargs):
        yield _make_request_output("req-1", token_ids=[10])
        yield _make_request_output(
            "req-1", token_ids=[], finish_reason="error", finished=True
        )

    engine.generate = MagicMock(side_effect=fake_stream)
    serving = _build_serving_tokens(engine)

    request = GenerateRequest(
        token_ids=[1, 2, 3],
        sampling_params=SamplingParams(max_tokens=10),
        model=MODEL_NAME,
        stream=True,
    )

    response = await serving.serve_tokens(request)
    chunks = [chunk async for chunk in response]

    # The empty delta must not suppress error reporting.
    has_error = any("Internal server error" in chunk for chunk in chunks)
    assert has_error, f"Expected error message in chunks: {chunks}"
    assert chunks[-1] == "data: [DONE]\n\n"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_stream_skips_empty_token_output():
    """Outputs with empty token_ids are skipped (no chunk emitted)."""
    engine = _mock_engine()

    async def fake_stream(*args, **kwargs):
        yield _make_request_output("req-1", token_ids=[10])
        yield _make_request_output("req-1", token_ids=[])
        yield _make_request_output(
            "req-1", token_ids=[20], finish_reason="stop", finished=True
        )

    engine.generate = MagicMock(side_effect=fake_stream)
    serving = _build_serving_tokens(engine)

    request = GenerateRequest(
        token_ids=[1, 2, 3],
        sampling_params=SamplingParams(max_tokens=10),
        model=MODEL_NAME,
        stream=True,
    )

    response = await serving.serve_tokens(request)
    parsed = _parse_sse_chunks([chunk async for chunk in response])

    assert parsed[-1] == "[DONE]"
    data_chunks = [item for item in parsed if item != "[DONE]"]

    # Only 2 data chunks — the empty one is skipped
    assert len(data_chunks) == 2
    assert [c["choices"][0]["token_ids"] for c in data_chunks] == [[10], [20]]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_stream_include_usage():
    """stream_options.include_usage emits a final usage-only chunk."""
    engine = _mock_engine()

    async def fake_stream(*args, **kwargs):
        yield _make_request_output("req-1", token_ids=[10])
        yield _make_request_output(
            "req-1", token_ids=[20], finish_reason="stop", finished=True
        )

    engine.generate = MagicMock(side_effect=fake_stream)
    serving = _build_serving_tokens(engine)

    request = GenerateRequest(
        token_ids=[1, 2, 3],
        sampling_params=SamplingParams(max_tokens=10),
        model=MODEL_NAME,
        stream=True,
        stream_options=StreamOptions(include_usage=True),
    )

    response = await serving.serve_tokens(request)
    parsed = _parse_sse_chunks([chunk async for chunk in response])
    assert parsed[-1] == "[DONE]"

    # The chunk right before [DONE] carries usage only (no choices).
    usage_chunk = parsed[-2]
    assert usage_chunk["choices"] == []
    usage = usage_chunk["usage"]
    assert usage["prompt_tokens"] == 3
    assert usage["completion_tokens"] == 2
    assert usage["total_tokens"] == 5
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_stream_continuous_usage():
    """continuous_usage_stats adds usage to every data chunk."""
    engine = _mock_engine()

    async def fake_stream(*args, **kwargs):
        yield _make_request_output("req-1", token_ids=[10])
        yield _make_request_output(
            "req-1", token_ids=[20], finish_reason="stop", finished=True
        )

    engine.generate = MagicMock(side_effect=fake_stream)
    serving = _build_serving_tokens(engine)

    request = GenerateRequest(
        token_ids=[1, 2, 3],
        sampling_params=SamplingParams(max_tokens=10),
        model=MODEL_NAME,
        stream=True,
        stream_options=StreamOptions(
            include_usage=True,
            continuous_usage_stats=True,
        ),
    )

    response = await serving.serve_tokens(request)
    parsed = _parse_sse_chunks([chunk async for chunk in response])
    data_chunks = [c for c in parsed if isinstance(c, dict) and c.get("choices")]

    # Every data chunk must carry cumulative usage stats.
    for i, dc in enumerate(data_chunks):
        assert dc["usage"] is not None, f"chunk {i} missing usage"
        assert dc["usage"]["prompt_tokens"] == 3

    first_usage = data_chunks[0]["usage"]
    second_usage = data_chunks[1]["usage"]
    # First chunk: 1 completion token; second: 2 (cumulative).
    assert (first_usage["completion_tokens"], first_usage["total_tokens"]) == (1, 4)
    assert (second_usage["completion_tokens"], second_usage["total_tokens"]) == (2, 5)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_stream_with_logprobs():
    """Streaming with logprobs includes logprob data in each chunk."""
    engine = _mock_engine()

    async def fake_stream(*args, **kwargs):
        yield _make_request_output(
            "req-1",
            token_ids=[10],
            logprobs=[{10: Logprob(logprob=-0.5)}],
        )
        yield _make_request_output(
            "req-1",
            token_ids=[20],
            logprobs=[{20: Logprob(logprob=-1.0)}],
            finish_reason="stop",
            finished=True,
        )

    engine.generate = MagicMock(side_effect=fake_stream)
    serving = _build_serving_tokens(engine)

    request = GenerateRequest(
        token_ids=[1, 2, 3],
        sampling_params=SamplingParams(max_tokens=10, logprobs=1),
        model=MODEL_NAME,
        stream=True,
    )

    response = await serving.serve_tokens(request)
    parsed = _parse_sse_chunks([chunk async for chunk in response])
    data_chunks = [c for c in parsed if isinstance(c, dict) and c.get("choices")]

    for dc in data_chunks:
        logprob_info = dc["choices"][0]["logprobs"]
        assert logprob_info is not None
        content = logprob_info["content"]
        assert len(content) == 1
        # Tokens are rendered as "token_id:<id>" placeholders, not text.
        assert content[0]["token"].startswith("token_id:")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_stream_prompt_tokens_details():
    """enable_prompt_tokens_details includes cached_tokens in final usage."""
    engine = _mock_engine()

    async def fake_stream(*args, **kwargs):
        yield _make_request_output(
            "req-1",
            token_ids=[10],
            finish_reason="stop",
            finished=True,
            num_cached_tokens=2,
        )

    engine.generate = MagicMock(side_effect=fake_stream)
    serving = _build_serving_tokens(engine, enable_prompt_tokens_details=True)

    request = GenerateRequest(
        token_ids=[1, 2, 3],
        sampling_params=SamplingParams(max_tokens=10),
        model=MODEL_NAME,
        stream=True,
        stream_options=StreamOptions(include_usage=True),
    )

    response = await serving.serve_tokens(request)
    parsed = _parse_sse_chunks([chunk async for chunk in response])

    # Usage-only chunk (before [DONE])
    usage_chunk = parsed[-2]
    assert usage_chunk["choices"] == []
    assert usage_chunk["usage"]["prompt_tokens_details"]["cached_tokens"] == 2
|