[Core] Encoder separation for Encode-Prefill-Decode Disaggregation (#25233)
Signed-off-by: n00909098 <nguyen.kha.long@huawei.com>
Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>
Signed-off-by: herotai214 <herotai214@gmail.com>
Signed-off-by: Khuong Le <khuong.le.manh@huawei.com>
Signed-off-by: Khuong Le <lemanhkhuong2611@gmail.com>
Co-authored-by: n00909098 <nguyen.kha.long@huawei.com>
Co-authored-by: knlnguyen1802 <knlnguyen1802@gmail.com>
Co-authored-by: herotai214 <herotai214@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Khuong Le <khuong.le.manh@huawei.com>
Co-authored-by: Khuong Le <lemanhkhuong2611@gmail.com>
@@ -5,6 +5,7 @@ import torch

 from vllm.config import (
     CacheConfig,
+    ECTransferConfig,
     KVTransferConfig,
     ModelConfig,
     SchedulerConfig,
@@ -46,6 +47,8 @@ def create_scheduler(
     num_speculative_tokens: int | None = None,
     skip_tokenizer_init: bool = False,
     async_scheduling: bool = False,
+    use_ec_connector: bool = False,
+    ec_role: str | None = None,
 ) -> Scheduler | AsyncScheduler:
     """Create scheduler under test.

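For context, a test enabling the encoder-cache connector would pass the two new flags through create_scheduler. A minimal sketch, assuming an "ec_producer" role name that mirrors the kv_role naming used by KVTransferConfig; the valid ec_role values are not shown in this diff:

    scheduler = create_scheduler(
        use_ec_connector=True,
        ec_role="ec_producer",  # assumed role name, not confirmed by this diff
    )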
@@ -107,12 +110,23 @@ def create_scheduler(
         model="ngram", num_speculative_tokens=num_speculative_tokens
     )

+    ec_transfer_config = (
+        ECTransferConfig(
+            ec_connector="ECSharedStorageConnector",
+            ec_role=ec_role,
+            ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test"},
+        )
+        if use_ec_connector
+        else None
+    )
+
     vllm_config = VllmConfig(
         scheduler_config=scheduler_config,
         model_config=model_config,
         cache_config=cache_config,
         kv_transfer_config=kv_transfer_config,
         speculative_config=speculative_config,
+        ec_transfer_config=ec_transfer_config,
     )
     kv_cache_config = KVCacheConfig(
         num_blocks=num_blocks,  # A large number of blocks to hold all requests
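The conditional expression above keeps ec_transfer_config as None unless use_ec_connector is set, so existing tests that never pass the flag build a VllmConfig without any encoder-cache connector. A sketch of the equivalent if/else form, using only names from this diff:

    if use_ec_connector:
        ec_transfer_config = ECTransferConfig(
            ec_connector="ECSharedStorageConnector",
            ec_role=ec_role,
            ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test"},
        )
    else:
        ec_transfer_config = None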
@@ -140,12 +154,14 @@ _none_hash_initialized = False
 def create_requests(
     num_requests: int,
     num_tokens: int = 10,
+    mm_hashes_list: list[list[str]] | None = None,
     mm_positions: list[list[PlaceholderRange]] | None = None,
     max_tokens: int = 16,
     stop_token_ids: list[int] | None = None,
     prompt_logprobs: int | None = None,
     same_prompt: bool = False,
     block_size: int = 16,
+    req_ids: list[str] | None = None,
 ) -> list[Request]:
     global _none_hash_initialized
     if not _none_hash_initialized:
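The two new parameters let tests reuse encoder-cache identifiers across requests and pin request IDs. A minimal sketch of a call sharing one encoder output between two requests, assuming PlaceholderRange(offset=..., length=...) as used by these test utilities:

    reqs = create_requests(
        num_requests=2,
        num_tokens=10,
        # The shared hash "img_a" marks both mm items as the same encoder output.
        mm_hashes_list=[["img_a"], ["img_a"]],
        mm_positions=[
            [PlaceholderRange(offset=0, length=5)],
            [PlaceholderRange(offset=0, length=5)],
        ],
        req_ids=["req-0", "req-1"],
    )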
@@ -160,25 +176,58 @@ def create_requests(
         prompt_logprobs=prompt_logprobs,
     )
     requests = []
+
+    if mm_hashes_list is not None:
+        # NOTE: allow manual input; some mm items can share the same identifier.
+        # The number of mm_hashes and mm_positions must match for each request.
+        assert mm_positions is not None, (
+            "mm_positions must be provided when mm_hashes_list is provided"
+        )
+        assert len(mm_hashes_list) == len(mm_positions) == num_requests
+        assert [len(h) for h in mm_hashes_list] == [len(p) for p in mm_positions]
+
+    # Since the same identifier implies identical encoder output, verify that
+    # mm items with an identical identifier have the same mm_position.length.
+    seen_hashes: dict[str, int] = {}
+
+    if req_ids:
+        assert len(req_ids) == num_requests
+    else:
+        req_ids = [f"{i}" for i in range(num_requests)]
+
     for i in range(num_requests):
         mm_features = []
-        if mm_positions is not None:
-            mm_position = mm_positions[i]
-            for j, position in enumerate(mm_position):
-                # Dummy hash for each mm item should be unique
-                # since encoder cache tracks entries by hash
+
+        for j, position in enumerate(
+            mm_positions[i] if mm_positions is not None else []
+        ):
+            if mm_hashes_list is not None:
+                identifier = mm_hashes_list[i][j]
+
+                # Verify that the position length is consistent for this hash.
+                position_length = position.length
+                if identifier in seen_hashes:
+                    assert seen_hashes[identifier] == position_length, (
+                        f"mm_hash '{identifier}' has inconsistent position lengths: "
+                        f"previously {seen_hashes[identifier]}, now {position_length} "
+                        f"at request {i}, position {j}"
+                    )
+                else:
+                    seen_hashes[identifier] = position_length
+            else:
+                # Unique dummy hash for each mm item
                 identifier = f"hash{i}_{j}"
-                mm_feature = MultiModalFeatureSpec(
-                    data=MultiModalKwargsItem.dummy("dummy_m"),
-                    mm_position=position,
-                    identifier=identifier,
-                    modality="image",
-                )
-                mm_features.append(mm_feature)
+            mm_feature = MultiModalFeatureSpec(
+                data=MultiModalKwargsItem.dummy("dummy_m"),
+                mm_position=position,
+                identifier=identifier,
+                modality="image",
+            )
+            mm_features.append(mm_feature)

         prompt_token_ids = [0] * num_tokens if same_prompt else [i] * num_tokens
         request = Request(
-            request_id=f"{i}",
+            request_id=req_ids[i],
             prompt_token_ids=prompt_token_ids,
             sampling_params=sampling_params,
             pooling_params=None,
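Under the validation above, reusing a hash with a different placeholder length should fail fast, since identical identifiers must describe identical encoder output. A sketch of the failing case, with the same hypothetical names as the earlier example:

    create_requests(
        num_requests=2,
        mm_hashes_list=[["img_a"], ["img_a"]],
        mm_positions=[
            [PlaceholderRange(offset=0, length=5)],
            [PlaceholderRange(offset=0, length=7)],  # mismatch -> AssertionError
        ],
    )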