[Core] Encoder separation for Encode-Prefill-Decode Disaggregation (#25233)

Signed-off-by: n00909098 <nguyen.kha.long@huawei.com>
Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>
Signed-off-by: herotai214 <herotai214@gmail.com>
Signed-off-by: Khuong Le <khuong.le.manh@huawei.com>
Signed-off-by: Khuong Le <lemanhkhuong2611@gmail.com>
Co-authored-by: n00909098 <nguyen.kha.long@huawei.com>
Co-authored-by: knlnguyen1802 <knlnguyen1802@gmail.com>
Co-authored-by: herotai214 <herotai214@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Khuong Le <khuong.le.manh@huawei.com>
Co-authored-by: Khuong Le <lemanhkhuong2611@gmail.com>
This commit is contained in:
Chenguang Zheng
2025-11-12 10:58:33 +08:00
committed by GitHub
parent cbb799e314
commit 4ccffe561f
31 changed files with 5026 additions and 42 deletions

View File

@@ -5,6 +5,7 @@ import torch
from vllm.config import (
CacheConfig,
ECTransferConfig,
KVTransferConfig,
ModelConfig,
SchedulerConfig,
@@ -46,6 +47,8 @@ def create_scheduler(
num_speculative_tokens: int | None = None,
skip_tokenizer_init: bool = False,
async_scheduling: bool = False,
use_ec_connector: bool = False,
ec_role: str | None = None,
) -> Scheduler | AsyncScheduler:
"""Create scheduler under test.
@@ -107,12 +110,23 @@ def create_scheduler(
model="ngram", num_speculative_tokens=num_speculative_tokens
)
ec_transfer_config = (
ECTransferConfig(
ec_connector="ECSharedStorageConnector",
ec_role=ec_role,
ec_connector_extra_config={"shared_storage_path": "/tmp/ec_test"},
)
if use_ec_connector
else None
)
vllm_config = VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
speculative_config=speculative_config,
ec_transfer_config=ec_transfer_config,
)
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
@@ -140,12 +154,14 @@ _none_hash_initialized = False
def create_requests(
num_requests: int,
num_tokens: int = 10,
mm_hashes_list: list[list[str]] | None = None,
mm_positions: list[list[PlaceholderRange]] | None = None,
max_tokens: int = 16,
stop_token_ids: list[int] | None = None,
prompt_logprobs: int | None = None,
same_prompt: bool = False,
block_size: int = 16,
req_ids: list[str] | None = None,
) -> list[Request]:
global _none_hash_initialized
if not _none_hash_initialized:
@@ -160,25 +176,58 @@ def create_requests(
prompt_logprobs=prompt_logprobs,
)
requests = []
if mm_hashes_list is not None:
# NOTE: allow manual input; some mm items can have the same identifier
# no. of mm_hashes and mm_positions for each request should be identical
assert mm_positions is not None, (
"mm_positions must be provided when mm_hashes_list is provided"
)
assert len(mm_hashes_list) == len(mm_positions) == num_requests
assert [len(h) for h in mm_hashes_list] == [len(p) for p in mm_positions]
# Since same identifier would imply they are identical encoder output
# Verify mm items with identical identifier are having mm_position.length
seen_hashes: dict[str, int] = {}
if req_ids:
assert len(req_ids) == num_requests
else:
req_ids = [f"{i}" for i in range(num_requests)]
for i in range(num_requests):
mm_features = []
if mm_positions is not None:
mm_position = mm_positions[i]
for j, position in enumerate(mm_position):
# Dummy hash for each mm item should be unique
# since encoder cache tracks entries by hash
for j, position in enumerate(
mm_positions[i] if mm_positions is not None else []
):
if mm_hashes_list is not None:
identifier = mm_hashes_list[i][j]
# Verify if position length is identical
position_length = position.length
if identifier in seen_hashes:
assert seen_hashes[identifier] == position_length, (
f"mm_hash '{identifier}' has inconsistent position lengths: "
f"previously {seen_hashes[identifier]}, now {position_length} "
f"at request {i}, position {j}"
)
else:
seen_hashes[identifier] = position_length
else:
# Unique dummy hash for each mm item
identifier = f"hash{i}_{j}"
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
mm_position=position,
identifier=identifier,
modality="image",
)
mm_features.append(mm_feature)
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
mm_position=position,
identifier=identifier,
modality="image",
)
mm_features.append(mm_feature)
prompt_token_ids = [0] * num_tokens if same_prompt else [i] * num_tokens
request = Request(
request_id=f"{i}",
request_id=req_ids[i],
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
pooling_params=None,