[NIXL][Bugfix] metrics & testing minor bug (#36051)

Signed-off-by: Andy Lo <andy@mistral.ai>
This commit is contained in:
Andy Lo
2026-03-18 13:39:14 +00:00
committed by GitHub
parent cef1f302d2
commit 98b09ddc27
2 changed files with 9 additions and 8 deletions

View File

@@ -694,16 +694,18 @@ class TestNixlHandshake:
)
@pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size(
self, local_tp_size: int, default_vllm_config, dist_init
self, local_tp_size: int, default_vllm_config, dist_init, monkeypatch
):
"""
Verify remote TP > local TP handshake succeeds with different
remote configurations.
"""
monkeypatch.setattr(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size",
lambda: local_tp_size,
)
vllm_config = create_vllm_config()
local_tp_size = 1
vllm_config.parallel_config.tensor_parallel_size = local_tp_size
connector = NixlConnector(
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
@@ -738,10 +740,10 @@ class TestNixlHandshake:
remote_agents = worker._nixl_handshake(
host="localhost",
port=1234,
remote_tp_size=2,
remote_tp_size=4,
expected_engine_id=worker.REMOTE_ENGINE_ID,
)
check_handshake(2)
check_handshake(4)
# NOTE flexibility: a second remote with a higher number of ranks is
# discovered. This is not a scenario we actively support right now, but
@@ -759,9 +761,8 @@ class TestNixlHandshake:
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper,
)
@pytest.mark.parametrize("local_tp_size", [1, 2])
def test_prefill_tp_size_greater_than_decode_tp_size_mla(
self, local_tp_size: int, default_vllm_config, dist_init
self, default_vllm_config, dist_init
):
"""
Verify remote TP > local TP handshake succeeds with different

View File

@@ -1318,12 +1318,12 @@ class NixlConnectorWorker:
f"Expected {expected_engine_id},"
f"received {metadata.engine_id}."
)
setup_agent_time = time.perf_counter()
# Register Remote agent.
remote_agent_name = self.add_remote_agent(
metadata, remote_rank, remote_tp_size
)
setup_agent_time = time.perf_counter()
logger.debug(
"NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time,