[RL] [Weight Sync] Guard IPC update-info pickle deserialization behind insecure serialization flag (#35928)
Co-authored-by: Cursor Agent <cursoragent@cursor.com>
This commit is contained in:
@@ -456,11 +456,13 @@ class TestIPCWeightTransferUpdateInfoValidation:
|
||||
ipc_handles=ipc_handles,
|
||||
)
|
||||
|
||||
def test_valid_update_info_from_pickled(self):
|
||||
def test_valid_update_info_from_pickled(self, monkeypatch):
|
||||
"""Test creating IPCWeightTransferUpdateInfo from pickled handles."""
|
||||
if torch.cuda.device_count() < 1:
|
||||
pytest.skip("Need at least 1 GPU for this test")
|
||||
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
dummy_tensor = torch.ones(10, 10, device="cuda:0")
|
||||
ipc_handle = reduce_tensor(dummy_tensor)
|
||||
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
|
||||
@@ -477,6 +479,18 @@ class TestIPCWeightTransferUpdateInfoValidation:
|
||||
assert info.ipc_handles == ipc_handles
|
||||
assert info.ipc_handles_pickled is None
|
||||
|
||||
def test_pickled_requires_insecure_serialization_flag(self, monkeypatch):
|
||||
"""Test that pickled handles are rejected unless env flag is enabled."""
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0")
|
||||
|
||||
with pytest.raises(ValueError, match="VLLM_ALLOW_INSECURE_SERIALIZATION=1"):
|
||||
IPCWeightTransferUpdateInfo(
|
||||
names=[],
|
||||
dtype_names=[],
|
||||
shapes=[],
|
||||
ipc_handles_pickled=base64.b64encode(pickle.dumps([])).decode("utf-8"),
|
||||
)
|
||||
|
||||
def test_both_handles_and_pickled_raises(self):
|
||||
"""Test that providing both ipc_handles and ipc_handles_pickled raises."""
|
||||
if torch.cuda.device_count() < 1:
|
||||
@@ -556,11 +570,13 @@ class TestIPCEngineParsing:
|
||||
assert update_info.shapes == [[100, 100], [50]]
|
||||
assert len(update_info.ipc_handles) == 2
|
||||
|
||||
def test_parse_update_info_pickled(self):
|
||||
def test_parse_update_info_pickled(self, monkeypatch):
|
||||
"""Test parsing update info with pickled IPC handles (HTTP path)."""
|
||||
if torch.cuda.device_count() < 1:
|
||||
pytest.skip("Need at least 1 GPU for this test")
|
||||
|
||||
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
config = WeightTransferConfig(backend="ipc")
|
||||
parallel_config = create_mock_parallel_config()
|
||||
engine = IPCWeightTransferEngine(config, parallel_config)
|
||||
|
||||
Reference in New Issue
Block a user