[Core] Implement disagg prefill by StatelessProcessGroup (#10502)
This PR provides initial support for single-node disaggregated prefill in 1P1D scenario. Signed-off-by: KuntaiDu <kuntai@uchicago.edu> Co-authored-by: ApostaC <yihua98@uchicago.edu> Co-authored-by: YaoJiayi <120040070@link.cuhk.edu.cn>
This commit is contained in:
@@ -2052,6 +2052,88 @@ class ObservabilityConfig:
|
||||
f"installed. Original error:\n{otel_import_error_traceback}")
|
||||
|
||||
|
||||
class KVTransferConfig(BaseModel):
|
||||
"""Configuration for distributed KV cache transfer."""
|
||||
|
||||
# The KV connector for vLLM to transmit KV caches between vLLM instances.
|
||||
kv_connector: Optional[str] = None
|
||||
|
||||
# The device used by kv connector to buffer the KV cache.
|
||||
# Currently only support 'cuda'.
|
||||
kv_buffer_device: Optional[str] = "cuda"
|
||||
|
||||
# The buffer size for TorchDistributedConnector. Measured in number of
|
||||
# bytes. Recommended value: 1e9 (about 1GB).
|
||||
kv_buffer_size: float = 1e9
|
||||
|
||||
# Whether this vLLM instance produces, consumes KV cache, or both. Choices
|
||||
# are 'kv_producer', 'kv_consumer', and 'both'.
|
||||
kv_role: Optional[str] = None
|
||||
|
||||
# The rank of this vLLM instance in the KV cache transfer. Typical value:
|
||||
# 0 for prefill instance, 1 for decode instance.
|
||||
# Currently only 1P1D is supported.
|
||||
kv_rank: Optional[int] = None
|
||||
|
||||
# The number of parallel instances for KV cache transfer. For
|
||||
# PyNcclConnector, this should be 2.
|
||||
kv_parallel_size: int = 1
|
||||
|
||||
# The KV connector ip, used to build distributed connection
|
||||
kv_ip: str = "127.0.0.1"
|
||||
|
||||
# The KV connector port, used to build distributed connection
|
||||
kv_port: int = 14579
|
||||
|
||||
@classmethod
|
||||
def from_cli(cls, cli_value: str) -> "KVTransferConfig":
|
||||
"""Parse the CLI value for the compilation config."""
|
||||
return KVTransferConfig.model_validate_json(cli_value)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
if all([
|
||||
self.kv_connector is not None,
|
||||
self.kv_connector != "PyNcclConnector"
|
||||
]):
|
||||
raise ValueError(f"Unsupported kv_connector: {self.kv_connector}. "
|
||||
f"Supported connectors are "
|
||||
f"`PyNcclConnector`.")
|
||||
|
||||
if self.kv_role is not None and self.kv_role not in [
|
||||
"kv_producer", "kv_consumer", "kv_both"
|
||||
]:
|
||||
raise ValueError(
|
||||
f"Unsupported kv_role: {self.kv_role}. "
|
||||
f"Supported roles are `kv_producer`, `kv_consumer`, "
|
||||
f"and `kv_both`")
|
||||
|
||||
if self.kv_connector is not None and self.kv_role is None:
|
||||
raise ValueError("Please specify kv_disagg_role when kv_connector "
|
||||
"is set, supported roles are `kv_producer`, "
|
||||
"`kv_consumer`, and `kv_both`")
|
||||
|
||||
@property
|
||||
def is_kv_transfer_instance(self) -> bool:
|
||||
return self.kv_connector is not None and \
|
||||
self.kv_role in ["kv_producer", "kv_consumer", "kv_both"]
|
||||
|
||||
@property
|
||||
def need_kv_parallel_group(self) -> bool:
|
||||
# for those database-based connector, vLLM does not need to create
|
||||
# parallel group, and in that case the kv parallel size will be 1.
|
||||
return self.kv_connector is not None and self.kv_parallel_size > 1
|
||||
|
||||
@property
|
||||
def is_kv_producer(self) -> bool:
|
||||
return self.kv_connector is not None and \
|
||||
self.kv_role in ["kv_producer", "kv_both"]
|
||||
|
||||
@property
|
||||
def is_kv_consumer(self) -> bool:
|
||||
return self.kv_connector is not None and \
|
||||
self.kv_role in ["kv_consumer", "kv_both"]
|
||||
|
||||
|
||||
class CompilationLevel:
|
||||
# constants for the levels of the compilation process
|
||||
NO_COMPILATION = 0
|
||||
@@ -2317,6 +2399,8 @@ class VllmConfig:
|
||||
quant_config: Optional[QuantizationConfig] = None
|
||||
compilation_config: CompilationConfig = field(default=None,
|
||||
init=True) # type: ignore
|
||||
kv_transfer_config: KVTransferConfig = field(default=None,
|
||||
init=True) # type: ignore
|
||||
|
||||
@staticmethod
|
||||
def _get_quantization_config(
|
||||
|
||||
Reference in New Issue
Block a user