[V1][Metrics] add support for kv event publishing (#16750)
Signed-off-by: alec-flowers <aflowers@nvidia.com> Signed-off-by: Mark McLoughlin <markmc@redhat.com> Co-authored-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
@@ -13,6 +13,8 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
||||
|
||||
from ...distributed.conftest import publisher_config, random_port # noqa: F401
|
||||
|
||||
from tests.v1.engine.utils import FULL_STRINGS # isort: skip
|
||||
|
||||
EngineCoreSampleLogprobsType = list[tuple[torch.Tensor, torch.Tensor]]
|
||||
|
||||
@@ -11,6 +11,7 @@ import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.distributed.kv_events import BlockStored, KVEventBatch
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
@@ -20,6 +21,7 @@ from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
|
||||
SyncMPClient)
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
|
||||
from ...distributed.conftest import MockSubscriber
|
||||
from ...utils import create_new_process_for_each_test
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
@@ -199,54 +201,142 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
|
||||
log_stats=True,
|
||||
)
|
||||
|
||||
MAX_TOKENS = 20
|
||||
params = SamplingParams(max_tokens=MAX_TOKENS)
|
||||
"""Normal Request Cycle."""
|
||||
try:
|
||||
MAX_TOKENS = 20
|
||||
params = SamplingParams(max_tokens=MAX_TOKENS)
|
||||
"""Normal Request Cycle."""
|
||||
|
||||
requests = [make_request(params) for _ in range(10)]
|
||||
request_ids = [req.request_id for req in requests]
|
||||
requests = [make_request(params) for _ in range(10)]
|
||||
request_ids = [req.request_id for req in requests]
|
||||
|
||||
# Add requests to the engine.
|
||||
for request in requests:
|
||||
await client.add_request_async(request)
|
||||
await asyncio.sleep(0.01)
|
||||
# Add requests to the engine.
|
||||
for request in requests:
|
||||
await client.add_request_async(request)
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
outputs: dict[str, list] = {req_id: [] for req_id in request_ids}
|
||||
await loop_until_done_async(client, outputs)
|
||||
outputs: dict[str, list] = {req_id: [] for req_id in request_ids}
|
||||
await loop_until_done_async(client, outputs)
|
||||
|
||||
for req_id in request_ids:
|
||||
assert len(outputs[req_id]) == MAX_TOKENS, (
|
||||
f"{outputs[req_id]=}, {MAX_TOKENS=}")
|
||||
"""Abort Request Cycle."""
|
||||
|
||||
# Add requests to the engine.
|
||||
for idx, request in enumerate(requests):
|
||||
await client.add_request_async(request)
|
||||
await asyncio.sleep(0.01)
|
||||
if idx % 2 == 0:
|
||||
await client.abort_requests_async([request.request_id])
|
||||
|
||||
outputs = {req_id: [] for req_id in request_ids}
|
||||
await loop_until_done_async(client, outputs)
|
||||
|
||||
for idx, req_id in enumerate(request_ids):
|
||||
if idx % 2 == 0:
|
||||
assert len(outputs[req_id]) < MAX_TOKENS, (
|
||||
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
|
||||
else:
|
||||
for req_id in request_ids:
|
||||
assert len(outputs[req_id]) == MAX_TOKENS, (
|
||||
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
|
||||
"""Utility method invocation"""
|
||||
f"{outputs[req_id]=}, {MAX_TOKENS=}")
|
||||
"""Abort Request Cycle."""
|
||||
|
||||
core_client: AsyncMPClient = client
|
||||
# Add requests to the engine.
|
||||
for idx, request in enumerate(requests):
|
||||
await client.add_request_async(request)
|
||||
await asyncio.sleep(0.01)
|
||||
if idx % 2 == 0:
|
||||
await client.abort_requests_async([request.request_id])
|
||||
|
||||
result = await core_client.call_utility_async("echo", "testarg")
|
||||
assert result == "testarg"
|
||||
outputs = {req_id: [] for req_id in request_ids}
|
||||
await loop_until_done_async(client, outputs)
|
||||
|
||||
with pytest.raises(Exception) as e_info:
|
||||
await core_client.call_utility_async("echo", None, "help!")
|
||||
for idx, req_id in enumerate(request_ids):
|
||||
if idx % 2 == 0:
|
||||
assert len(outputs[req_id]) < MAX_TOKENS, (
|
||||
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
|
||||
else:
|
||||
assert len(outputs[req_id]) == MAX_TOKENS, (
|
||||
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
|
||||
"""Utility method invocation"""
|
||||
|
||||
assert str(e_info.value) == "Call to echo method failed: help!"
|
||||
core_client: AsyncMPClient = client
|
||||
|
||||
result = await core_client.call_utility_async("echo", "testarg")
|
||||
assert result == "testarg"
|
||||
|
||||
with pytest.raises(Exception) as e_info:
|
||||
await core_client.call_utility_async("echo", None, "help!")
|
||||
|
||||
assert str(e_info.value) == "Call to echo method failed: help!"
|
||||
finally:
|
||||
client.shutdown()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"multiprocessing_mode,publisher_config",
|
||||
[(True, "tcp"), (False, "inproc")],
|
||||
indirect=["publisher_config"],
|
||||
)
|
||||
def test_kv_cache_events(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
multiprocessing_mode: bool,
|
||||
publisher_config,
|
||||
):
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
block_size = 16
|
||||
num_blocks = 2
|
||||
|
||||
engine_args = EngineArgs(model=MODEL_NAME,
|
||||
enforce_eager=True,
|
||||
enable_prefix_caching=True,
|
||||
block_size=block_size)
|
||||
engine_args.kv_events_config = publisher_config
|
||||
|
||||
vllm_config = engine_args.create_engine_config(
|
||||
UsageContext.UNKNOWN_CONTEXT)
|
||||
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
client = EngineCoreClient.make_client(
|
||||
multiprocess_mode=multiprocessing_mode,
|
||||
asyncio_mode=False,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=False,
|
||||
)
|
||||
endpoint = publisher_config.endpoint.replace("*", "127.0.0.1")
|
||||
time.sleep(0.1)
|
||||
subscriber = MockSubscriber(endpoint,
|
||||
topic=publisher_config.topic,
|
||||
decode_type=KVEventBatch)
|
||||
|
||||
try:
|
||||
custom_tokens = list(range(num_blocks * block_size))
|
||||
request = EngineCoreRequest(
|
||||
request_id=str(uuid.uuid4()),
|
||||
prompt_token_ids=custom_tokens,
|
||||
mm_inputs=None,
|
||||
mm_hashes=None,
|
||||
mm_placeholders=None,
|
||||
sampling_params=SamplingParams(
|
||||
max_tokens=1), # Short completion for speed
|
||||
eos_token_id=None,
|
||||
arrival_time=time.time(),
|
||||
lora_request=None,
|
||||
)
|
||||
client.add_request(request)
|
||||
|
||||
outputs: dict[str, list] = {request.request_id: []}
|
||||
loop_until_done(client, outputs)
|
||||
|
||||
result = subscriber.receive_one(timeout=1000)
|
||||
assert result is not None, "No message received"
|
||||
|
||||
seq, received = result
|
||||
|
||||
assert seq == 0, "Sequence number mismatch"
|
||||
assert len(received.events) == 1, (
|
||||
"We should have exactly one BlockStored event")
|
||||
event = received.events[0]
|
||||
assert isinstance(
|
||||
event, BlockStored), ("We should have a BlockStored event")
|
||||
assert len(event.block_hashes) == num_blocks, (
|
||||
"We should have a BlockStored event with 2 block_hashes")
|
||||
assert event.block_size == block_size, (
|
||||
"Block size should be the same as the block size")
|
||||
assert event.parent_block_hash is None, (
|
||||
"Parent block hash should be None")
|
||||
assert event.lora_id is None, "Lora id should be None"
|
||||
assert len(event.token_ids) == num_blocks * block_size, (
|
||||
"Token ids should be the same as the custom tokens")
|
||||
assert event.token_ids == custom_tokens, (
|
||||
"Token ids should be the same as the custom tokens")
|
||||
finally:
|
||||
client.shutdown()
|
||||
return
|
||||
|
||||
|
||||
@pytest.mark.timeout(10)
|
||||
|
||||
Reference in New Issue
Block a user