[Core] Encoder separation for Encode-Prefill-Decode Disaggregation (#25233)

Signed-off-by: n00909098 <nguyen.kha.long@huawei.com>
Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>
Signed-off-by: herotai214 <herotai214@gmail.com>
Signed-off-by: Khuong Le <khuong.le.manh@huawei.com>
Signed-off-by: Khuong Le <lemanhkhuong2611@gmail.com>
Co-authored-by: n00909098 <nguyen.kha.long@huawei.com>
Co-authored-by: knlnguyen1802 <knlnguyen1802@gmail.com>
Co-authored-by: herotai214 <herotai214@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Khuong Le <khuong.le.manh@huawei.com>
Co-authored-by: Khuong Le <lemanhkhuong2611@gmail.com>
This commit is contained in:
Chenguang Zheng
2025-11-12 10:58:33 +08:00
committed by GitHub
parent cbb799e314
commit 4ccffe561f
31 changed files with 5026 additions and 42 deletions

View File

@@ -7,6 +7,11 @@ from collections.abc import Iterable
from typing import Any
from vllm.config import VllmConfig
from vllm.distributed.ec_transfer.ec_connector.base import (
ECConnectorMetadata,
ECConnectorRole,
)
from vllm.distributed.ec_transfer.ec_connector.factory import ECConnectorFactory
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.distributed.kv_transfer.kv_connector.v1 import (
@@ -14,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import (
KVConnectorRole,
SupportsHMA,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
@@ -104,6 +110,11 @@ class Scheduler(SchedulerInterface):
self.kv_events_config,
self.parallel_config.data_parallel_rank,
)
self.ec_connector = None
if self.vllm_config.ec_transfer_config is not None:
self.ec_connector = ECConnectorFactory.create_connector(
config=self.vllm_config, role=ECConnectorRole.SCHEDULER
)
num_gpu_blocks = self.cache_config.num_gpu_blocks
assert num_gpu_blocks is not None and num_gpu_blocks > 0
@@ -230,12 +241,14 @@ class Scheduler(SchedulerInterface):
# Schedule encoder inputs.
encoder_inputs_to_schedule = None
external_load_encoder_input: list[int] = []
new_encoder_compute_budget = encoder_compute_budget
if request.has_encoder_inputs:
(
encoder_inputs_to_schedule,
num_new_tokens,
new_encoder_compute_budget,
external_load_encoder_input,
) = self._try_schedule_encoder_inputs(
request,
request.num_computed_tokens,
@@ -342,6 +355,11 @@ class Scheduler(SchedulerInterface):
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_compute_budget = new_encoder_compute_budget
if external_load_encoder_input:
for i in external_load_encoder_input:
self.encoder_cache_manager.allocate(request, i)
if self.ec_connector is not None:
self.ec_connector.update_state_after_alloc(request, i)
# Record the LoRAs in scheduled_running_reqs
scheduled_loras: set[int] = set()
@@ -445,6 +463,7 @@ class Scheduler(SchedulerInterface):
num_computed_tokens = request.num_computed_tokens
encoder_inputs_to_schedule = None
external_load_encoder_input = []
new_encoder_compute_budget = encoder_compute_budget
# KVTransfer: loading remote KV, do not allocate for new work.
@@ -480,6 +499,7 @@ class Scheduler(SchedulerInterface):
encoder_inputs_to_schedule,
num_new_tokens,
new_encoder_compute_budget,
external_load_encoder_input,
) = self._try_schedule_encoder_inputs(
request,
num_computed_tokens,
@@ -583,7 +603,12 @@ class Scheduler(SchedulerInterface):
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_compute_budget = new_encoder_compute_budget
# Allocate for external load encoder cache
if external_load_encoder_input:
for i in external_load_encoder_input:
self.encoder_cache_manager.allocate(request, i)
if self.ec_connector is not None:
self.ec_connector.update_state_after_alloc(request, i)
# Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests:
self.waiting.prepend_requests(skipped_waiting_requests)
@@ -591,6 +616,7 @@ class Scheduler(SchedulerInterface):
# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
assert token_budget >= 0
assert len(self.running) <= self.max_num_running_reqs
# Since some requests in the RUNNING queue may not be scheduled in
@@ -653,8 +679,18 @@ class Scheduler(SchedulerInterface):
# 2. Wrap up all the KV cache load / save ops into an opaque object
# 3. Clear the internal states of the connector
if self.connector is not None:
meta = self.connector.build_connector_meta(scheduler_output)
meta: KVConnectorMetadata = self.connector.build_connector_meta(
scheduler_output
)
scheduler_output.kv_connector_metadata = meta
# Build the connector meta for ECConnector
if self.ec_connector is not None:
ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta(
scheduler_output
)
scheduler_output.ec_connector_metadata = ec_meta
with record_function_or_nullcontext("schedule: update_after_schedule"):
self._update_after_schedule(scheduler_output)
return scheduler_output
@@ -755,7 +791,7 @@ class Scheduler(SchedulerInterface):
num_computed_tokens: int,
num_new_tokens: int,
encoder_compute_budget: int,
) -> tuple[list[int], int, int]:
) -> tuple[list[int], int, int, list[int]]:
"""
Determine which encoder inputs need to be scheduled in the current step,
and update `num_new_tokens` and encoder token budget accordingly.
@@ -765,6 +801,7 @@ class Scheduler(SchedulerInterface):
in this step, i.e.,
[num_computed_tokens, num_computed_tokens + num_new_tokens).
- It is not already computed and stored in the encoder cache.
- It is not exist on remote encoder cache (via ECConnector)
- There is sufficient encoder token budget to process it.
- The encoder cache has space to store it.
@@ -776,12 +813,16 @@ class Scheduler(SchedulerInterface):
blocks and externally cached blocks (via KVConnector).
"""
if num_new_tokens == 0 or not request.has_encoder_inputs:
return [], num_new_tokens, encoder_compute_budget
return [], num_new_tokens, encoder_compute_budget, []
encoder_inputs_to_schedule: list[int] = []
mm_features = request.mm_features
assert mm_features is not None
assert len(mm_features) > 0
external_load_encoder_input = []
# Check remote cache first
if self.ec_connector is not None:
remote_cache_has_item = self.ec_connector.has_caches(request)
# NOTE: since scheduler operates on the request level (possibly with
# multiple encoder inputs per request), we need to create temporary
# trackers for accounting at the encoder input level.
@@ -862,6 +903,12 @@ class Scheduler(SchedulerInterface):
num_new_tokens = 0
break
if self.ec_connector is not None and remote_cache_has_item[i]:
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
external_load_encoder_input.append(i)
num_tokens_to_schedule += num_encoder_tokens
continue
num_tokens_to_schedule += num_encoder_tokens
encoder_compute_budget -= num_encoder_tokens
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
@@ -871,6 +918,7 @@ class Scheduler(SchedulerInterface):
encoder_inputs_to_schedule,
num_new_tokens,
encoder_compute_budget,
external_load_encoder_input,
)
def get_grammar_bitmask(