[RL] [Weight Sync] Guard IPC update-info pickle deserialization behind insecure serialization flag (#35928)

Co-authored-by: Cursor Agent <cursoragent@cursor.com>
This commit is contained in:
Simon Mo
2026-03-04 14:05:32 -08:00
committed by GitHub
parent be0a3f7570
commit f678c3f61a
2 changed files with 26 additions and 2 deletions

View File

@@ -456,11 +456,13 @@ class TestIPCWeightTransferUpdateInfoValidation:
ipc_handles=ipc_handles,
)
def test_valid_update_info_from_pickled(self):
def test_valid_update_info_from_pickled(self, monkeypatch):
"""Test creating IPCWeightTransferUpdateInfo from pickled handles."""
if torch.cuda.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
dummy_tensor = torch.ones(10, 10, device="cuda:0")
ipc_handle = reduce_tensor(dummy_tensor)
gpu_uuid = str(torch.cuda.get_device_properties(0).uuid)
@@ -477,6 +479,18 @@ class TestIPCWeightTransferUpdateInfoValidation:
assert info.ipc_handles == ipc_handles
assert info.ipc_handles_pickled is None
def test_pickled_requires_insecure_serialization_flag(self, monkeypatch):
"""Test that pickled handles are rejected unless env flag is enabled."""
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0")
with pytest.raises(ValueError, match="VLLM_ALLOW_INSECURE_SERIALIZATION=1"):
IPCWeightTransferUpdateInfo(
names=[],
dtype_names=[],
shapes=[],
ipc_handles_pickled=base64.b64encode(pickle.dumps([])).decode("utf-8"),
)
def test_both_handles_and_pickled_raises(self):
"""Test that providing both ipc_handles and ipc_handles_pickled raises."""
if torch.cuda.device_count() < 1:
@@ -556,11 +570,13 @@ class TestIPCEngineParsing:
assert update_info.shapes == [[100, 100], [50]]
assert len(update_info.ipc_handles) == 2
def test_parse_update_info_pickled(self):
def test_parse_update_info_pickled(self, monkeypatch):
"""Test parsing update info with pickled IPC handles (HTTP path)."""
if torch.cuda.device_count() < 1:
pytest.skip("Need at least 1 GPU for this test")
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
config = WeightTransferConfig(backend="ipc")
parallel_config = create_mock_parallel_config()
engine = IPCWeightTransferEngine(config, parallel_config)