[RL] [Weight Sync] Guard IPC update-info pickle deserialization behind insecure serialization flag (#35928)
Co-authored-by: Cursor Agent <cursoragent@cursor.com>
This commit is contained in:
@@ -456,11 +456,13 @@ class TestIPCWeightTransferUpdateInfoValidation:
|
|||||||
ipc_handles=ipc_handles,
|
ipc_handles=ipc_handles,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_valid_update_info_from_pickled(self):
|
def test_valid_update_info_from_pickled(self, monkeypatch):
|
||||||
"""Test creating IPCWeightTransferUpdateInfo from pickled handles."""
|
"""Test creating IPCWeightTransferUpdateInfo from pickled handles."""
|
||||||
if torch.cuda.device_count() < 1:
|
if torch.cuda.device_count() < 1:
|
||||||
pytest.skip("Need at least 1 GPU for this test")
|
pytest.skip("Need at least 1 GPU for this test")
|
||||||
|
|
||||||
|
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||||
|
|
||||||
dummy_tensor = torch.ones(10, 10, device="cuda:0")
|
dummy_tensor = torch.ones(10, 10, device="cuda:0")
|
||||||
ipc_handle = reduce_tensor(dummy_tensor)
|
ipc_handle = reduce_tensor(dummy_tensor)
|
||||||
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
|
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
|
||||||
@@ -477,6 +479,18 @@ class TestIPCWeightTransferUpdateInfoValidation:
|
|||||||
assert info.ipc_handles == ipc_handles
|
assert info.ipc_handles == ipc_handles
|
||||||
assert info.ipc_handles_pickled is None
|
assert info.ipc_handles_pickled is None
|
||||||
|
|
||||||
|
def test_pickled_requires_insecure_serialization_flag(self, monkeypatch):
    """Check that pickled IPC handles are refused when the env flag is off.

    With ``VLLM_ALLOW_INSECURE_SERIALIZATION`` set to ``"0"``, building an
    ``IPCWeightTransferUpdateInfo`` from ``ipc_handles_pickled`` is expected
    to raise ``ValueError`` whose message names the required flag setting.
    """
    # Pin the flag to "off" so the outcome does not depend on the ambient
    # environment of the test run.
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0")

    # A pickled empty list, base64-encoded to text, serves as the payload.
    pickled_handles = pickle.dumps([])
    encoded_handles = base64.b64encode(pickled_handles).decode("utf-8")

    rejection = pytest.raises(
        ValueError, match="VLLM_ALLOW_INSECURE_SERIALIZATION=1"
    )
    with rejection:
        IPCWeightTransferUpdateInfo(
            names=[],
            dtype_names=[],
            shapes=[],
            ipc_handles_pickled=encoded_handles,
        )
|
||||||
|
|
||||||
def test_both_handles_and_pickled_raises(self):
|
def test_both_handles_and_pickled_raises(self):
|
||||||
"""Test that providing both ipc_handles and ipc_handles_pickled raises."""
|
"""Test that providing both ipc_handles and ipc_handles_pickled raises."""
|
||||||
if torch.cuda.device_count() < 1:
|
if torch.cuda.device_count() < 1:
|
||||||
@@ -556,11 +570,13 @@ class TestIPCEngineParsing:
|
|||||||
assert update_info.shapes == [[100, 100], [50]]
|
assert update_info.shapes == [[100, 100], [50]]
|
||||||
assert len(update_info.ipc_handles) == 2
|
assert len(update_info.ipc_handles) == 2
|
||||||
|
|
||||||
def test_parse_update_info_pickled(self):
|
def test_parse_update_info_pickled(self, monkeypatch):
|
||||||
"""Test parsing update info with pickled IPC handles (HTTP path)."""
|
"""Test parsing update info with pickled IPC handles (HTTP path)."""
|
||||||
if torch.cuda.device_count() < 1:
|
if torch.cuda.device_count() < 1:
|
||||||
pytest.skip("Need at least 1 GPU for this test")
|
pytest.skip("Need at least 1 GPU for this test")
|
||||||
|
|
||||||
|
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||||
|
|
||||||
config = WeightTransferConfig(backend="ipc")
|
config = WeightTransferConfig(backend="ipc")
|
||||||
parallel_config = create_mock_parallel_config()
|
parallel_config = create_mock_parallel_config()
|
||||||
engine = IPCWeightTransferEngine(config, parallel_config)
|
engine = IPCWeightTransferEngine(config, parallel_config)
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import requests
|
|||||||
import torch
|
import torch
|
||||||
from torch.multiprocessing.reductions import reduce_tensor
|
from torch.multiprocessing.reductions import reduce_tensor
|
||||||
|
|
||||||
|
from vllm import envs
|
||||||
from vllm.config.parallel import ParallelConfig
|
from vllm.config.parallel import ParallelConfig
|
||||||
from vllm.config.weight_transfer import WeightTransferConfig
|
from vllm.config.weight_transfer import WeightTransferConfig
|
||||||
from vllm.distributed.weight_transfer.base import (
|
from vllm.distributed.weight_transfer.base import (
|
||||||
@@ -74,6 +75,13 @@ class IPCWeightTransferUpdateInfo(WeightTransferUpdateInfo):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Cannot specify both `ipc_handles` and `ipc_handles_pickled`"
|
"Cannot specify both `ipc_handles` and `ipc_handles_pickled`"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
||||||
|
raise ValueError(
|
||||||
|
"Refusing to deserialize `ipc_handles_pickled` without "
|
||||||
|
"VLLM_ALLOW_INSECURE_SERIALIZATION=1"
|
||||||
|
)
|
||||||
|
|
||||||
self.ipc_handles = pickle.loads(base64.b64decode(self.ipc_handles_pickled))
|
self.ipc_handles = pickle.loads(base64.b64decode(self.ipc_handles_pickled))
|
||||||
self.ipc_handles_pickled = None
|
self.ipc_handles_pickled = None
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user