[Core] Implement disagg prefill by StatelessProcessGroup (#10502)

This PR provides initial support for single-node disaggregated prefill in the 1P1D (one prefill instance, one decode instance) scenario.
Signed-off-by: KuntaiDu <kuntai@uchicago.edu>
Co-authored-by: ApostaC <yihua98@uchicago.edu>
Co-authored-by: YaoJiayi <120040070@link.cuhk.edu.cn>
This commit is contained in:
Kuntai Du
2024-12-01 19:01:00 -06:00
committed by GitHub
parent c11f172187
commit 0590ec3fd9
33 changed files with 2525 additions and 21 deletions

View File

@@ -2052,6 +2052,88 @@ class ObservabilityConfig:
f"installed. Original error:\n{otel_import_error_traceback}")
class KVTransferConfig(BaseModel):
    """Configuration for distributed KV cache transfer."""

    # The KV connector for vLLM to transmit KV caches between vLLM instances.
    kv_connector: Optional[str] = None

    # The device used by kv connector to buffer the KV cache.
    # Currently only 'cuda' is supported.
    kv_buffer_device: Optional[str] = "cuda"

    # The buffer size for TorchDistributedConnector. Measured in number of
    # bytes. Recommended value: 1e9 (about 1GB).
    kv_buffer_size: float = 1e9

    # Whether this vLLM instance produces, consumes KV cache, or both. Choices
    # are 'kv_producer', 'kv_consumer', and 'kv_both'.
    kv_role: Optional[str] = None

    # The rank of this vLLM instance in the KV cache transfer. Typical value:
    # 0 for prefill instance, 1 for decode instance.
    # Currently only 1P1D is supported.
    kv_rank: Optional[int] = None

    # The number of parallel instances for KV cache transfer. For
    # PyNcclConnector, this should be 2.
    kv_parallel_size: int = 1

    # The KV connector ip, used to build distributed connection
    kv_ip: str = "127.0.0.1"

    # The KV connector port, used to build distributed connection
    kv_port: int = 14579

    @classmethod
    def from_cli(cls, cli_value: str) -> "KVTransferConfig":
        """Parse the CLI value (a JSON string) for the KV transfer config."""
        # Go through `cls` so that a subclass gets an instance of itself.
        return cls.model_validate_json(cli_value)

    def model_post_init(self, __context: Any) -> None:
        """Check that the connector and role settings are consistent.

        Raises:
            ValueError: If `kv_connector` is set to anything other than
                `PyNcclConnector`, if `kv_role` is set to an unknown role,
                or if `kv_connector` is set without a `kv_role`.
        """
        if self.kv_connector not in (None, "PyNcclConnector"):
            raise ValueError(f"Unsupported kv_connector: {self.kv_connector}. "
                             f"Supported connectors are "
                             f"`PyNcclConnector`.")

        if self.kv_role not in (None, "kv_producer", "kv_consumer",
                                "kv_both"):
            raise ValueError(
                f"Unsupported kv_role: {self.kv_role}. "
                f"Supported roles are `kv_producer`, `kv_consumer`, "
                f"and `kv_both`")

        # A connector is useless without knowing which side of the transfer
        # this instance is on, so a role is mandatory once one is configured.
        if self.kv_connector is not None and self.kv_role is None:
            raise ValueError("Please specify kv_role when kv_connector "
                             "is set, supported roles are `kv_producer`, "
                             "`kv_consumer`, and `kv_both`")

    @property
    def is_kv_transfer_instance(self) -> bool:
        """Whether a connector is set and the role is a recognized one."""
        return self.kv_connector is not None and \
            self.kv_role in ("kv_producer", "kv_consumer", "kv_both")

    @property
    def need_kv_parallel_group(self) -> bool:
        """Whether a connector is set and `kv_parallel_size` exceeds 1."""
        # for those database-based connector, vLLM does not need to create
        # parallel group, and in that case the kv parallel size will be 1.
        return self.kv_connector is not None and self.kv_parallel_size > 1

    @property
    def is_kv_producer(self) -> bool:
        """Whether a connector is set and the role produces KV cache."""
        return self.kv_connector is not None and \
            self.kv_role in ("kv_producer", "kv_both")

    @property
    def is_kv_consumer(self) -> bool:
        """Whether a connector is set and the role consumes KV cache."""
        return self.kv_connector is not None and \
            self.kv_role in ("kv_consumer", "kv_both")
class CompilationLevel:
# constants for the levels of the compilation process
NO_COMPILATION = 0
@@ -2317,6 +2399,8 @@ class VllmConfig:
quant_config: Optional[QuantizationConfig] = None
compilation_config: CompilationConfig = field(default=None,
init=True) # type: ignore
kv_transfer_config: KVTransferConfig = field(default=None,
init=True) # type: ignore
@staticmethod
def _get_quantization_config(