feat: add data parallel rank to KVEventBatch (#18925)
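Adds a data_parallel_rank field to KVEventBatch so subscribers can tell which data-parallel engine produced a batch of KV cache events. On the test side, make_request now accepts custom prompt tokens, a loop_until_fully_done_async helper waits for every request to finish, and a new 4-GPU test_kv_cache_events_dp subscribes to one ZMQ endpoint per DP rank (built with ZmqEventPublisher.offset_endpoint_port) and asserts that events arrive from both ranks.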
@@ -12,8 +12,10 @@ from typing import Optional
 import pytest
 from transformers import AutoTokenizer
 
+from tests.utils import multi_gpu_test
 from vllm import SamplingParams
-from vllm.distributed.kv_events import BlockStored, KVEventBatch
+from vllm.distributed.kv_events import (BlockStored, KVEventBatch,
+                                        ZmqEventPublisher)
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.usage.usage_lib import UsageContext
@@ -37,10 +39,15 @@ PROMPT = "Hello my name is Robert and I love quantization kernels"
 PROMPT_TOKENS = TOKENIZER(PROMPT).input_ids
 
 
-def make_request(params: SamplingParams) -> EngineCoreRequest:
+def make_request(
+        params: SamplingParams,
+        prompt_tokens_ids: Optional[list[int]] = None) -> EngineCoreRequest:
+    if not prompt_tokens_ids:
+        prompt_tokens_ids = PROMPT_TOKENS
+
     return EngineCoreRequest(
         request_id=str(uuid.uuid4()),
-        prompt_token_ids=PROMPT_TOKENS,
+        prompt_token_ids=prompt_tokens_ids,
         mm_inputs=None,
         mm_hashes=None,
         mm_placeholders=None,
@@ -88,6 +95,25 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: dict):
             break
 
 
+async def loop_until_fully_done_async(client: EngineCoreClient, outputs: dict):
+
+    while True:
+        engine_core_outputs = (await client.get_output_async()).outputs
+
+        if len(engine_core_outputs) == 0:
+            continue
+
+        # Add outputs to the dict
+        for out in engine_core_outputs:
+            outputs[out.request_id].append(out)
+
+        # Check if all request IDs in outputs have finished
+        if all(outs and outs[-1].finished for outs in outputs.values()):
+            break
+
+        await asyncio.sleep(0.1)
+
+
 # Dummy utility function to monkey-patch into engine core.
 def echo(self, msg: str, err_msg: Optional[str] = None) -> str:
     print(f"echo util function called: {msg}, {err_msg}")
@@ -273,10 +299,12 @@ def test_kv_cache_events(
     block_size = 16
     num_blocks = 2
 
-    engine_args = EngineArgs(model=MODEL_NAME,
-                             enforce_eager=True,
-                             enable_prefix_caching=True,
-                             block_size=block_size)
+    engine_args = EngineArgs(
+        model=MODEL_NAME,
+        enforce_eager=True,
+        enable_prefix_caching=True,
+        block_size=block_size,
+    )
     engine_args.kv_events_config = publisher_config
 
     vllm_config = engine_args.create_engine_config(
@@ -297,19 +325,8 @@ def test_kv_cache_events(
 
     try:
         custom_tokens = list(range(num_blocks * block_size))
-        request = EngineCoreRequest(
-            request_id=str(uuid.uuid4()),
-            prompt_token_ids=custom_tokens,
-            mm_inputs=None,
-            mm_hashes=None,
-            mm_placeholders=None,
-            sampling_params=SamplingParams(
-                max_tokens=1),  # Short completion for speed
-            eos_token_id=None,
-            arrival_time=time.time(),
-            lora_request=None,
-            cache_salt=None,
-        )
+        sampling_params = SamplingParams(max_tokens=1)
+        request = make_request(sampling_params, custom_tokens)
         client.add_request(request)
 
         outputs: dict[str, list] = {request.request_id: []}
@@ -321,24 +338,130 @@ def test_kv_cache_events(
         seq, received = result
 
         assert seq == 0, "Sequence number mismatch"
-        assert len(received.events) == 1, (
-            "We should have exactly one BlockStored event")
+        assert (len(received.events) == 1
+                ), "We should have exactly one BlockStored event"
         event = received.events[0]
         assert isinstance(
-            event, BlockStored), ("We should have a BlockStored event")
-        assert len(event.block_hashes) == num_blocks, (
-            "We should have a BlockStored event with 2 block_hashes")
-        assert event.block_size == block_size, (
-            "Block size should be the same as the block size")
-        assert event.parent_block_hash is None, (
-            "Parent block hash should be None")
+            event, BlockStored), "We should have a BlockStored event"
+        assert (len(event.block_hashes) == num_blocks
+                ), "We should have a BlockStored event with 2 block_hashes"
+        assert (event.block_size == block_size
+                ), "Block size should be the same as the block size"
+        assert (event.parent_block_hash
+                is None), "Parent block hash should be None"
         assert event.lora_id is None, "Lora id should be None"
-        assert len(event.token_ids) == num_blocks * block_size, (
-            "Token ids should be the same as the custom tokens")
-        assert event.token_ids == custom_tokens, (
-            "Token ids should be the same as the custom tokens")
+        assert (len(event.token_ids) == num_blocks * block_size
+                ), "Token ids should be the same as the custom tokens"
+        assert (event.token_ids == custom_tokens
+                ), "Token ids should be the same as the custom tokens"
     finally:
         client.shutdown()
         subscriber.close()
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "multiprocessing_mode,publisher_config",
+    [(True, "tcp")],
+    indirect=["publisher_config"],
+)
+@multi_gpu_test(num_gpus=4)
+async def test_kv_cache_events_dp(
+    monkeypatch: pytest.MonkeyPatch,
+    multiprocessing_mode: bool,
+    publisher_config,
+):
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        block_size = 16
+        num_blocks = 2
+        dp_size = 2
+        tp_size = 2
+
+        engine_args = EngineArgs(
+            model=MODEL_NAME,
+            enforce_eager=True,
+            enable_prefix_caching=True,
+            data_parallel_size=dp_size,
+            tensor_parallel_size=tp_size,
+            block_size=block_size,
+        )
+        engine_args.kv_events_config = publisher_config
+
+        vllm_config = engine_args.create_engine_config(
+            UsageContext.UNKNOWN_CONTEXT)
+
+        executor_class = Executor.get_class(vllm_config)
+        client = EngineCoreClient.make_client(
+            multiprocess_mode=multiprocessing_mode,
+            asyncio_mode=True,
+            vllm_config=vllm_config,
+            executor_class=executor_class,
+            log_stats=False,
+        )
+        await asyncio.sleep(1)
+
+        # Build endpoints for all DP ranks
+        base_endpoint = publisher_config.endpoint.replace("*", "127.0.0.1")
+        endpoints = []
+        for i in range(dp_size):
+            offset_endpoint = ZmqEventPublisher.offset_endpoint_port(
+                base_endpoint, i)
+            endpoints.append(offset_endpoint)
+
+        subscriber = MockSubscriber(endpoints,
+                                    topic=publisher_config.topic,
+                                    decode_type=KVEventBatch)
+
+        try:
+            custom_tokens = list(range(num_blocks * block_size))
+            sampling_params = SamplingParams(max_tokens=1)
+            all_request_ids = []
+
+            # Create and add 25 requests
+            # NOTE: attempts to force routing to both dp groups but can be flaky
+            for i in range(25):
+                await asyncio.sleep(0.01)
+                request = make_request(sampling_params, custom_tokens)
+                await client.add_request_async(request)
+                all_request_ids.append(request.request_id)
+
+            await asyncio.sleep(0.1)
+
+            # Initialize outputs dict for all requests
+            outputs: dict[str, list] = {
+                req_id: []
+                for req_id in all_request_ids
+            }
+
+            print("processing requests...")
+            await asyncio.wait_for(loop_until_fully_done_async(
+                client, outputs),
+                                   timeout=20.0)
+
+            # Receive from subscriber until no more messages
+            print("collecting results...")
+            results = []
+            while True:
+                result = subscriber.receive_one(timeout=1)
+                print(result)
+                if result is None:
+                    break
+                results.append(result)
+
+            # Collect all events and data_parallel_ranks from all results
+            all_dp_ranks = [
+                received.data_parallel_rank for (_, received) in results
+            ]
+            unique_dps = set(all_dp_ranks)
+            assert (
+                len(unique_dps) == 2
+            ), f"Expected 2 unique data_parallel_ranks, got {len(unique_dps)}"
+
+        finally:
+            client.shutdown()
+            subscriber.close()
+
+
+@pytest.mark.timeout(20)
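As a usage note, below is a minimal sketch of how a consumer might fan in event batches from every DP rank, mirroring the test above. MockSubscriber (with its receive_one/close methods and (seq, batch) return shape) is the helper this test file already uses, not a public vLLM API; the endpoint construction and the new data_parallel_rank field follow the calls visible in the diff.

# Sketch only: MockSubscriber comes from the test helpers above; its
# receive_one() returns a (seq, KVEventBatch) tuple, or None on timeout.
from collections import defaultdict

from vllm.distributed.kv_events import KVEventBatch, ZmqEventPublisher


def collect_events_by_dp_rank(base_endpoint: str, topic: str, dp_size: int):
    # One endpoint per DP rank, derived from the base endpoint the same
    # way the test does it.
    endpoints = [
        ZmqEventPublisher.offset_endpoint_port(base_endpoint, rank)
        for rank in range(dp_size)
    ]
    subscriber = MockSubscriber(endpoints,
                                topic=topic,
                                decode_type=KVEventBatch)
    events_by_rank: dict[int, list] = defaultdict(list)
    try:
        while True:
            result = subscriber.receive_one(timeout=1)
            if result is None:  # no more buffered messages
                break
            _seq, batch = result
            # New in this commit: the batch records which DP rank sent it.
            events_by_rank[batch.data_parallel_rank].extend(batch.events)
    finally:
        subscriber.close()
    return events_by_rank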