feat: add data parallel rank to KVEventBatch (#18925)

Yan Ru Pei
2025-06-03 17:14:20 -07:00
committed by GitHub
parent a8da78eac9
commit b712be98c7
6 changed files with 359 additions and 83 deletions

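For orientation before the diff: this change adds a data_parallel_rank field to each published KVEventBatch, and with data parallelism enabled events are published on one ZMQ endpoint per DP rank. The sketch below shows how a consumer could use the new field; it is illustrative only and not part of the commit. It assumes a subscriber that behaves like the MockSubscriber test helper used in the diff (receive_one returns a (seq, KVEventBatch) tuple, or None on timeout), and the base endpoint value is a placeholder.

from collections import defaultdict

from vllm.distributed.kv_events import KVEventBatch, ZmqEventPublisher

# One endpoint per data-parallel rank, derived the same way the new test does.
# "tcp://127.0.0.1:5557" is an example base endpoint, not a value from the diff.
endpoints = [
    ZmqEventPublisher.offset_endpoint_port("tcp://127.0.0.1:5557", rank)
    for rank in range(2)
]


def events_by_dp_rank(subscriber) -> dict[int, list[KVEventBatch]]:
    # Drain the subscriber and bucket decoded batches by the new field.
    buckets: dict[int, list[KVEventBatch]] = defaultdict(list)
    while True:
        result = subscriber.receive_one(timeout=1)
        if result is None:  # nothing left to read
            break
        _seq, batch = result
        buckets[batch.data_parallel_rank].append(batch)
    return buckets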

@@ -12,8 +12,10 @@ from typing import Optional
 import pytest
 from transformers import AutoTokenizer
 
+from tests.utils import multi_gpu_test
 from vllm import SamplingParams
-from vllm.distributed.kv_events import BlockStored, KVEventBatch
+from vllm.distributed.kv_events import (BlockStored, KVEventBatch,
+                                        ZmqEventPublisher)
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.usage.usage_lib import UsageContext
@@ -37,10 +39,15 @@ PROMPT = "Hello my name is Robert and I love quantization kernels"
 PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
 
 
-def make_request(params: SamplingParams) -> EngineCoreRequest:
+def make_request(
+        params: SamplingParams,
+        prompt_tokens_ids: Optional[list[int]] = None) -> EngineCoreRequest:
+    if not prompt_tokens_ids:
+        prompt_tokens_ids = PROMPT_TOKENS
+
     return EngineCoreRequest(
         request_id=str(uuid.uuid4()),
-        prompt_token_ids=PROMPT_TOKENS,
+        prompt_token_ids=prompt_tokens_ids,
         mm_inputs=None,
         mm_hashes=None,
         mm_placeholders=None,
@@ -88,6 +95,25 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: dict):
             break
 
 
+async def loop_until_fully_done_async(client: EngineCoreClient, outputs: dict):
+    while True:
+        engine_core_outputs = (await client.get_output_async()).outputs
+        if len(engine_core_outputs) == 0:
+            continue
+
+        # Add outputs to the dict
+        for out in engine_core_outputs:
+            outputs[out.request_id].append(out)
+
+        # Check if all request IDs in outputs have finished
+        if all(outs and outs[-1].finished for outs in outputs.values()):
+            break
+
+        await asyncio.sleep(0.1)
+
+
 # Dummy utility function to monkey-patch into engine core.
 def echo(self, msg: str, err_msg: Optional[str] = None) -> str:
     print(f"echo util function called: {msg}, {err_msg}")
@@ -273,10 +299,12 @@ def test_kv_cache_events(
         block_size = 16
         num_blocks = 2
 
-        engine_args = EngineArgs(model=MODEL_NAME,
-                                 enforce_eager=True,
-                                 enable_prefix_caching=True,
-                                 block_size=block_size)
+        engine_args = EngineArgs(
+            model=MODEL_NAME,
+            enforce_eager=True,
+            enable_prefix_caching=True,
+            block_size=block_size,
+        )
 
         engine_args.kv_events_config = publisher_config
         vllm_config = engine_args.create_engine_config(
@@ -297,19 +325,8 @@ def test_kv_cache_events(
         try:
             custom_tokens = list(range(num_blocks * block_size))
-            request = EngineCoreRequest(
-                request_id=str(uuid.uuid4()),
-                prompt_token_ids=custom_tokens,
-                mm_inputs=None,
-                mm_hashes=None,
-                mm_placeholders=None,
-                sampling_params=SamplingParams(
-                    max_tokens=1),  # Short completion for speed
-                eos_token_id=None,
-                arrival_time=time.time(),
-                lora_request=None,
-                cache_salt=None,
-            )
+            sampling_params = SamplingParams(max_tokens=1)
+            request = make_request(sampling_params, custom_tokens)
 
             client.add_request(request)
             outputs: dict[str, list] = {request.request_id: []}
@@ -321,24 +338,130 @@ def test_kv_cache_events(
             seq, received = result
             assert seq == 0, "Sequence number mismatch"
-            assert len(received.events) == 1, (
-                "We should have exactly one BlockStored event")
+            assert (len(received.events) == 1
+                    ), "We should have exactly one BlockStored event"
             event = received.events[0]
             assert isinstance(
-                event, BlockStored), ("We should have a BlockStored event")
-            assert len(event.block_hashes) == num_blocks, (
-                "We should have a BlockStored event with 2 block_hashes")
-            assert event.block_size == block_size, (
-                "Block size should be the same as the block size")
-            assert event.parent_block_hash is None, (
-                "Parent block hash should be None")
+                event, BlockStored), "We should have a BlockStored event"
+            assert (len(event.block_hashes) == num_blocks
+                    ), "We should have a BlockStored event with 2 block_hashes"
+            assert (event.block_size == block_size
+                    ), "Block size should be the same as the block size"
+            assert (event.parent_block_hash
+                    is None), "Parent block hash should be None"
             assert event.lora_id is None, "Lora id should be None"
-            assert len(event.token_ids) == num_blocks * block_size, (
-                "Token ids should be the same as the custom tokens")
-            assert event.token_ids == custom_tokens, (
-                "Token ids should be the same as the custom tokens")
+            assert (len(event.token_ids) == num_blocks * block_size
+                    ), "Token ids should be the same as the custom tokens"
+            assert (event.token_ids == custom_tokens
+                    ), "Token ids should be the same as the custom tokens"
         finally:
             client.shutdown()
             subscriber.close()
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "multiprocessing_mode,publisher_config",
+    [(True, "tcp")],
+    indirect=["publisher_config"],
+)
+@multi_gpu_test(num_gpus=4)
+async def test_kv_cache_events_dp(
+    monkeypatch: pytest.MonkeyPatch,
+    multiprocessing_mode: bool,
+    publisher_config,
+):
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        block_size = 16
+        num_blocks = 2
+        dp_size = 2
+        tp_size = 2
+
+        engine_args = EngineArgs(
+            model=MODEL_NAME,
+            enforce_eager=True,
+            enable_prefix_caching=True,
+            data_parallel_size=dp_size,
+            tensor_parallel_size=tp_size,
+            block_size=block_size,
+        )
+        engine_args.kv_events_config = publisher_config
+        vllm_config = engine_args.create_engine_config(
+            UsageContext.UNKNOWN_CONTEXT)
+        executor_class = Executor.get_class(vllm_config)
+
+        client = EngineCoreClient.make_client(
+            multiprocess_mode=multiprocessing_mode,
+            asyncio_mode=True,
+            vllm_config=vllm_config,
+            executor_class=executor_class,
+            log_stats=False,
+        )
+        await asyncio.sleep(1)
+
+        # Build endpoints for all DP ranks
+        base_endpoint = publisher_config.endpoint.replace("*", "127.0.0.1")
+        endpoints = []
+        for i in range(dp_size):
+            offset_endpoint = ZmqEventPublisher.offset_endpoint_port(
+                base_endpoint, i)
+            endpoints.append(offset_endpoint)
+
+        subscriber = MockSubscriber(endpoints,
+                                    topic=publisher_config.topic,
+                                    decode_type=KVEventBatch)
+
+        try:
+            custom_tokens = list(range(num_blocks * block_size))
+            sampling_params = SamplingParams(max_tokens=1)
+
+            all_request_ids = []
+            # Create and add 25 requests
+            # NOTE: attempts to force routing to both dp groups but can be flaky
+            for i in range(25):
+                await asyncio.sleep(0.01)
+                request = make_request(sampling_params, custom_tokens)
+                await client.add_request_async(request)
+                all_request_ids.append(request.request_id)
+            await asyncio.sleep(0.1)
+
+            # Initialize outputs dict for all requests
+            outputs: dict[str, list] = {
+                req_id: []
+                for req_id in all_request_ids
+            }
+
+            print("processing requests...")
+            await asyncio.wait_for(loop_until_fully_done_async(
+                client, outputs),
+                                   timeout=20.0)
+
+            # Receive from subscriber until no more messages
+            print("collecting results...")
+            results = []
+            while True:
+                result = subscriber.receive_one(timeout=1)
+                print(result)
+                if result is None:
+                    break
+                results.append(result)
+
+            # Collect all events and data_parallel_ranks from all results
+            all_dp_ranks = [
+                received.data_parallel_rank for (_, received) in results
+            ]
+            unique_dps = set(all_dp_ranks)
+
+            assert (
+                len(unique_dps) == 2
+            ), f"Expected 2 unique data_parallel_ranks, got {len(unique_dps)}"
+        finally:
+            client.shutdown()
+            subscriber.close()
+
+
 @pytest.mark.timeout(20)