[Core] Encoder separation for Encode-Prefill-Decode Disaggregation (#25233)
Signed-off-by: n00909098 <nguyen.kha.long@huawei.com> Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com> Signed-off-by: herotai214 <herotai214@gmail.com> Signed-off-by: Khuong Le <khuong.le.manh@huawei.com> Signed-off-by: Khuong Le <lemanhkhuong2611@gmail.com> Co-authored-by: n00909098 <nguyen.kha.long@huawei.com> Co-authored-by: knlnguyen1802 <knlnguyen1802@gmail.com> Co-authored-by: herotai214 <herotai214@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Khuong Le <khuong.le.manh@huawei.com> Co-authored-by: Khuong Le <lemanhkhuong2611@gmail.com>
This commit is contained in:
@@ -14,6 +14,7 @@ if TYPE_CHECKING:
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
|
||||
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec
|
||||
@@ -21,6 +22,7 @@ if TYPE_CHECKING:
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.request import Request
|
||||
else:
|
||||
ECConnectorMetadata = object
|
||||
KVConnectorMetadata = object
|
||||
LoRARequest = object
|
||||
MultiModalFeatureSpec = object
|
||||
@@ -188,6 +190,9 @@ class SchedulerOutput:
|
||||
# KV Cache Connector metadata.
|
||||
kv_connector_metadata: KVConnectorMetadata | None = None
|
||||
|
||||
# EC Connector metadata.
|
||||
ec_connector_metadata: ECConnectorMetadata | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class GrammarOutput:
|
||||
|
||||
@@ -7,6 +7,11 @@ from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.ec_transfer.ec_connector.base import (
|
||||
ECConnectorMetadata,
|
||||
ECConnectorRole,
|
||||
)
|
||||
from vllm.distributed.ec_transfer.ec_connector.factory import ECConnectorFactory
|
||||
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
||||
@@ -14,6 +19,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1 import (
|
||||
KVConnectorRole,
|
||||
SupportsHMA,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
@@ -104,6 +110,11 @@ class Scheduler(SchedulerInterface):
|
||||
self.kv_events_config,
|
||||
self.parallel_config.data_parallel_rank,
|
||||
)
|
||||
self.ec_connector = None
|
||||
if self.vllm_config.ec_transfer_config is not None:
|
||||
self.ec_connector = ECConnectorFactory.create_connector(
|
||||
config=self.vllm_config, role=ECConnectorRole.SCHEDULER
|
||||
)
|
||||
|
||||
num_gpu_blocks = self.cache_config.num_gpu_blocks
|
||||
assert num_gpu_blocks is not None and num_gpu_blocks > 0
|
||||
@@ -230,12 +241,14 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
# Schedule encoder inputs.
|
||||
encoder_inputs_to_schedule = None
|
||||
external_load_encoder_input: list[int] = []
|
||||
new_encoder_compute_budget = encoder_compute_budget
|
||||
if request.has_encoder_inputs:
|
||||
(
|
||||
encoder_inputs_to_schedule,
|
||||
num_new_tokens,
|
||||
new_encoder_compute_budget,
|
||||
external_load_encoder_input,
|
||||
) = self._try_schedule_encoder_inputs(
|
||||
request,
|
||||
request.num_computed_tokens,
|
||||
@@ -342,6 +355,11 @@ class Scheduler(SchedulerInterface):
|
||||
for i in encoder_inputs_to_schedule:
|
||||
self.encoder_cache_manager.allocate(request, i)
|
||||
encoder_compute_budget = new_encoder_compute_budget
|
||||
if external_load_encoder_input:
|
||||
for i in external_load_encoder_input:
|
||||
self.encoder_cache_manager.allocate(request, i)
|
||||
if self.ec_connector is not None:
|
||||
self.ec_connector.update_state_after_alloc(request, i)
|
||||
|
||||
# Record the LoRAs in scheduled_running_reqs
|
||||
scheduled_loras: set[int] = set()
|
||||
@@ -445,6 +463,7 @@ class Scheduler(SchedulerInterface):
|
||||
num_computed_tokens = request.num_computed_tokens
|
||||
|
||||
encoder_inputs_to_schedule = None
|
||||
external_load_encoder_input = []
|
||||
new_encoder_compute_budget = encoder_compute_budget
|
||||
|
||||
# KVTransfer: loading remote KV, do not allocate for new work.
|
||||
@@ -480,6 +499,7 @@ class Scheduler(SchedulerInterface):
|
||||
encoder_inputs_to_schedule,
|
||||
num_new_tokens,
|
||||
new_encoder_compute_budget,
|
||||
external_load_encoder_input,
|
||||
) = self._try_schedule_encoder_inputs(
|
||||
request,
|
||||
num_computed_tokens,
|
||||
@@ -583,7 +603,12 @@ class Scheduler(SchedulerInterface):
|
||||
for i in encoder_inputs_to_schedule:
|
||||
self.encoder_cache_manager.allocate(request, i)
|
||||
encoder_compute_budget = new_encoder_compute_budget
|
||||
|
||||
# Allocate encoder cache space for externally loaded encoder inputs.
|
||||
if external_load_encoder_input:
|
||||
for i in external_load_encoder_input:
|
||||
self.encoder_cache_manager.allocate(request, i)
|
||||
if self.ec_connector is not None:
|
||||
self.ec_connector.update_state_after_alloc(request, i)
|
||||
# Put back any skipped requests at the head of the waiting queue
|
||||
if skipped_waiting_requests:
|
||||
self.waiting.prepend_requests(skipped_waiting_requests)
|
||||
@@ -591,6 +616,7 @@ class Scheduler(SchedulerInterface):
|
||||
# Check if the scheduling constraints are satisfied.
|
||||
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
|
||||
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
|
||||
|
||||
assert token_budget >= 0
|
||||
assert len(self.running) <= self.max_num_running_reqs
|
||||
# Since some requests in the RUNNING queue may not be scheduled in
|
||||
@@ -653,8 +679,18 @@ class Scheduler(SchedulerInterface):
|
||||
# 2. Wrap up all the KV cache load / save ops into an opaque object
|
||||
# 3. Clear the internal states of the connector
|
||||
if self.connector is not None:
|
||||
meta = self.connector.build_connector_meta(scheduler_output)
|
||||
meta: KVConnectorMetadata = self.connector.build_connector_meta(
|
||||
scheduler_output
|
||||
)
|
||||
scheduler_output.kv_connector_metadata = meta
|
||||
|
||||
# Build the connector meta for ECConnector
|
||||
if self.ec_connector is not None:
|
||||
ec_meta: ECConnectorMetadata = self.ec_connector.build_connector_meta(
|
||||
scheduler_output
|
||||
)
|
||||
scheduler_output.ec_connector_metadata = ec_meta
|
||||
|
||||
with record_function_or_nullcontext("schedule: update_after_schedule"):
|
||||
self._update_after_schedule(scheduler_output)
|
||||
return scheduler_output
|
||||
@@ -755,7 +791,7 @@ class Scheduler(SchedulerInterface):
|
||||
num_computed_tokens: int,
|
||||
num_new_tokens: int,
|
||||
encoder_compute_budget: int,
|
||||
) -> tuple[list[int], int, int]:
|
||||
) -> tuple[list[int], int, int, list[int]]:
|
||||
"""
|
||||
Determine which encoder inputs need to be scheduled in the current step,
|
||||
and update `num_new_tokens` and encoder token budget accordingly.
|
||||
@@ -765,6 +801,7 @@ class Scheduler(SchedulerInterface):
|
||||
in this step, i.e.,
|
||||
[num_computed_tokens, num_computed_tokens + num_new_tokens).
|
||||
- It is not already computed and stored in the encoder cache.
|
||||
- It does not already exist in the remote encoder cache (via ECConnector).
|
||||
- There is sufficient encoder token budget to process it.
|
||||
- The encoder cache has space to store it.
|
||||
|
||||
@@ -776,12 +813,16 @@ class Scheduler(SchedulerInterface):
|
||||
blocks and externally cached blocks (via KVConnector).
|
||||
"""
|
||||
if num_new_tokens == 0 or not request.has_encoder_inputs:
|
||||
return [], num_new_tokens, encoder_compute_budget
|
||||
return [], num_new_tokens, encoder_compute_budget, []
|
||||
encoder_inputs_to_schedule: list[int] = []
|
||||
mm_features = request.mm_features
|
||||
assert mm_features is not None
|
||||
assert len(mm_features) > 0
|
||||
external_load_encoder_input = []
|
||||
|
||||
# Check remote cache first
|
||||
if self.ec_connector is not None:
|
||||
remote_cache_has_item = self.ec_connector.has_caches(request)
|
||||
# NOTE: since scheduler operates on the request level (possibly with
|
||||
# multiple encoder inputs per request), we need to create temporary
|
||||
# trackers for accounting at the encoder input level.
|
||||
@@ -862,6 +903,12 @@ class Scheduler(SchedulerInterface):
|
||||
num_new_tokens = 0
|
||||
break
|
||||
|
||||
if self.ec_connector is not None and remote_cache_has_item[i]:
|
||||
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
|
||||
external_load_encoder_input.append(i)
|
||||
num_tokens_to_schedule += num_encoder_tokens
|
||||
continue
|
||||
|
||||
num_tokens_to_schedule += num_encoder_tokens
|
||||
encoder_compute_budget -= num_encoder_tokens
|
||||
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
|
||||
@@ -871,6 +918,7 @@ class Scheduler(SchedulerInterface):
|
||||
encoder_inputs_to_schedule,
|
||||
num_new_tokens,
|
||||
encoder_compute_budget,
|
||||
external_load_encoder_input,
|
||||
)
|
||||
|
||||
def get_grammar_bitmask(
|
||||
|
||||
Reference in New Issue
Block a user