[Misc] adjust for ipv6 for mookcacke url parse (#20107)
Signed-off-by: Andy Xie <andy.xning@gmail.com>
This commit is contained in:
@@ -20,10 +20,11 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
|||||||
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
|
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
|
||||||
MemorySnapshot, PlaceholderModule, StoreBoolean,
|
MemorySnapshot, PlaceholderModule, StoreBoolean,
|
||||||
bind_kv_cache, common_broadcastable_dtype,
|
bind_kv_cache, common_broadcastable_dtype,
|
||||||
deprecate_kwargs, get_open_port, is_lossless_cast,
|
deprecate_kwargs, get_open_port, get_tcp_uri,
|
||||||
make_zmq_path, make_zmq_socket, memory_profiling,
|
is_lossless_cast, join_host_port, make_zmq_path,
|
||||||
merge_async_iterators, sha256, split_zmq_path,
|
make_zmq_socket, memory_profiling,
|
||||||
supports_kw, swap_dict_values)
|
merge_async_iterators, sha256, split_host_port,
|
||||||
|
split_zmq_path, supports_kw, swap_dict_values)
|
||||||
|
|
||||||
from .utils import create_new_process_for_each_test, error_on_warning
|
from .utils import create_new_process_for_each_test, error_on_warning
|
||||||
|
|
||||||
@@ -876,3 +877,44 @@ def test_make_zmq_socket_ipv6():
|
|||||||
def test_make_zmq_path():
|
def test_make_zmq_path():
|
||||||
assert make_zmq_path("tcp", "127.0.0.1", "5555") == "tcp://127.0.0.1:5555"
|
assert make_zmq_path("tcp", "127.0.0.1", "5555") == "tcp://127.0.0.1:5555"
|
||||||
assert make_zmq_path("tcp", "::1", "5555") == "tcp://[::1]:5555"
|
assert make_zmq_path("tcp", "::1", "5555") == "tcp://[::1]:5555"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_tcp_uri():
|
||||||
|
assert get_tcp_uri("127.0.0.1", 5555) == "tcp://127.0.0.1:5555"
|
||||||
|
assert get_tcp_uri("::1", 5555) == "tcp://[::1]:5555"
|
||||||
|
|
||||||
|
|
||||||
|
def test_split_host_port():
|
||||||
|
# valid ipv4
|
||||||
|
assert split_host_port("127.0.0.1:5555") == ("127.0.0.1", 5555)
|
||||||
|
# invalid ipv4
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
# multi colon
|
||||||
|
assert split_host_port("127.0.0.1::5555")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
# tailing colon
|
||||||
|
assert split_host_port("127.0.0.1:5555:")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
# no colon
|
||||||
|
assert split_host_port("127.0.0.15555")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
# none int port
|
||||||
|
assert split_host_port("127.0.0.1:5555a")
|
||||||
|
|
||||||
|
# valid ipv6
|
||||||
|
assert split_host_port("[::1]:5555") == ("::1", 5555)
|
||||||
|
# invalid ipv6
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
# multi colon
|
||||||
|
assert split_host_port("[::1]::5555")
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
# no colon
|
||||||
|
assert split_host_port("[::1]5555")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
# none int port
|
||||||
|
assert split_host_port("[::1]:5555a")
|
||||||
|
|
||||||
|
|
||||||
|
def test_join_host_port():
|
||||||
|
assert join_host_port("127.0.0.1", 5555) == "127.0.0.1:5555"
|
||||||
|
assert join_host_port("::1", 5555) == "[::1]:5555"
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from safetensors.torch import save as safetensors_save
|
|||||||
from vllm.config import KVTransferConfig
|
from vllm.config import KVTransferConfig
|
||||||
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
|
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils import join_host_port, make_zmq_path, split_host_port
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
NONE_INT = -150886311
|
NONE_INT = -150886311
|
||||||
@@ -79,18 +80,19 @@ class MooncakeTransferEngine:
|
|||||||
logger.error(
|
logger.error(
|
||||||
"An error occurred while loading the configuration: %s", exc)
|
"An error occurred while loading the configuration: %s", exc)
|
||||||
raise
|
raise
|
||||||
prefill_host, base_prefill_port = self.config.prefill_url.split(':')
|
prefill_host, base_prefill_port = split_host_port(
|
||||||
decode_host, base_decode_port = self.config.decode_url.split(':')
|
self.config.prefill_url)
|
||||||
|
decode_host, base_decode_port = split_host_port(self.config.decode_url)
|
||||||
|
|
||||||
# Avoid ports conflict when running prefill and decode on the same node
|
# Avoid ports conflict when running prefill and decode on the same node
|
||||||
if prefill_host == decode_host and \
|
if prefill_host == decode_host and \
|
||||||
base_prefill_port == base_decode_port:
|
base_prefill_port == base_decode_port:
|
||||||
base_decode_port = str(int(base_decode_port) + 100)
|
base_decode_port = base_decode_port + 100
|
||||||
|
|
||||||
prefill_port = int(base_prefill_port) + self.local_rank
|
prefill_port = base_prefill_port + self.local_rank
|
||||||
decode_port = int(base_decode_port) + self.local_rank
|
decode_port = base_decode_port + self.local_rank
|
||||||
self.prefill_url = ':'.join([prefill_host, str(prefill_port)])
|
self.prefill_url = join_host_port(prefill_host, prefill_port)
|
||||||
self.decode_url = ':'.join([decode_host, str(decode_port)])
|
self.decode_url = join_host_port(decode_host, decode_port)
|
||||||
|
|
||||||
self.initialize(self.prefill_url if kv_rank == 0 else self.decode_url,
|
self.initialize(self.prefill_url if kv_rank == 0 else self.decode_url,
|
||||||
self.config.metadata_server, self.config.protocol,
|
self.config.metadata_server, self.config.protocol,
|
||||||
@@ -110,22 +112,30 @@ class MooncakeTransferEngine:
|
|||||||
self._setup_metadata_sockets(kv_rank, prefill_host, base_prefill_port,
|
self._setup_metadata_sockets(kv_rank, prefill_host, base_prefill_port,
|
||||||
decode_host, base_decode_port)
|
decode_host, base_decode_port)
|
||||||
|
|
||||||
def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: str,
|
def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: int,
|
||||||
d_host: str, d_port: str) -> None:
|
d_host: str, d_port: int) -> None:
|
||||||
"""Set up ZeroMQ sockets for sending and receiving data."""
|
"""Set up ZeroMQ sockets for sending and receiving data."""
|
||||||
# Offsets < 8 are left for initialization in case tp and pp are enabled
|
# Offsets < 8 are left for initialization in case tp and pp are enabled
|
||||||
p_rank_offset = int(p_port) + 8 + self.local_rank * 2
|
p_rank_offset = p_port + 8 + self.local_rank * 2
|
||||||
d_rank_offset = int(d_port) + 8 + self.local_rank * 2
|
d_rank_offset = d_port + 8 + self.local_rank * 2
|
||||||
if kv_rank == 0:
|
if kv_rank == 0:
|
||||||
self.sender_socket.bind(f"tcp://{p_host}:{p_rank_offset + 1}")
|
self.sender_socket.bind(
|
||||||
self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}")
|
make_zmq_path("tcp", p_host, p_rank_offset + 1))
|
||||||
self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}")
|
self.receiver_socket.connect(
|
||||||
self.receiver_ack.bind(f"tcp://{p_host}:{p_rank_offset + 2}")
|
make_zmq_path("tcp", d_host, d_rank_offset + 1))
|
||||||
|
self.sender_ack.connect(
|
||||||
|
make_zmq_path("tcp", d_host, d_rank_offset + 2))
|
||||||
|
self.receiver_ack.bind(
|
||||||
|
make_zmq_path("tcp", p_host, p_rank_offset + 2))
|
||||||
else:
|
else:
|
||||||
self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}")
|
self.receiver_socket.connect(
|
||||||
self.sender_socket.bind(f"tcp://{d_host}:{d_rank_offset + 1}")
|
make_zmq_path("tcp", p_host, p_rank_offset + 1))
|
||||||
self.receiver_ack.bind(f"tcp://{d_host}:{d_rank_offset + 2}")
|
self.sender_socket.bind(
|
||||||
self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}")
|
make_zmq_path("tcp", d_host, d_rank_offset + 1))
|
||||||
|
self.receiver_ack.bind(
|
||||||
|
make_zmq_path("tcp", d_host, d_rank_offset + 2))
|
||||||
|
self.sender_ack.connect(
|
||||||
|
make_zmq_path("tcp", p_host, p_rank_offset + 2))
|
||||||
|
|
||||||
def initialize(self, local_hostname: str, metadata_server: str,
|
def initialize(self, local_hostname: str, metadata_server: str,
|
||||||
protocol: str, device_name: str,
|
protocol: str, device_name: str,
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ from dataclasses import dataclass, field
|
|||||||
from functools import cache, lru_cache, partial, wraps
|
from functools import cache, lru_cache, partial, wraps
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
|
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
|
||||||
Optional, TypeVar, Union, cast, overload)
|
Optional, Tuple, TypeVar, Union, cast, overload)
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
@@ -628,14 +628,34 @@ def is_valid_ipv6_address(address: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def split_host_port(host_port: str) -> Tuple[str, int]:
|
||||||
|
# ipv6
|
||||||
|
if host_port.startswith('['):
|
||||||
|
host, port = host_port.rsplit(']', 1)
|
||||||
|
host = host[1:]
|
||||||
|
port = port.split(':')[1]
|
||||||
|
return host, int(port)
|
||||||
|
else:
|
||||||
|
host, port = host_port.split(':')
|
||||||
|
return host, int(port)
|
||||||
|
|
||||||
|
|
||||||
|
def join_host_port(host: str, port: int) -> str:
|
||||||
|
if is_valid_ipv6_address(host):
|
||||||
|
return f"[{host}]:{port}"
|
||||||
|
else:
|
||||||
|
return f"{host}:{port}"
|
||||||
|
|
||||||
|
|
||||||
def get_distributed_init_method(ip: str, port: int) -> str:
|
def get_distributed_init_method(ip: str, port: int) -> str:
|
||||||
return get_tcp_uri(ip, port)
|
return get_tcp_uri(ip, port)
|
||||||
|
|
||||||
|
|
||||||
def get_tcp_uri(ip: str, port: int) -> str:
|
def get_tcp_uri(ip: str, port: int) -> str:
|
||||||
# Brackets are not permitted in ipv4 addresses,
|
if is_valid_ipv6_address(ip):
|
||||||
# see https://github.com/python/cpython/issues/103848
|
return f"tcp://[{ip}]:{port}"
|
||||||
return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
|
else:
|
||||||
|
return f"tcp://{ip}:{port}"
|
||||||
|
|
||||||
|
|
||||||
def get_open_zmq_ipc_path() -> str:
|
def get_open_zmq_ipc_path() -> str:
|
||||||
|
|||||||
Reference in New Issue
Block a user