[NIXL][Bugfix] metrics & testing minor bug (#36051)
Signed-off-by: Andy Lo <andy@mistral.ai>
This commit is contained in:
@@ -694,16 +694,18 @@ class TestNixlHandshake:
|
||||
)
|
||||
@pytest.mark.parametrize("local_tp_size", [1, 2])
|
||||
def test_prefill_tp_size_greater_than_decode_tp_size(
|
||||
self, local_tp_size: int, default_vllm_config, dist_init
|
||||
self, local_tp_size: int, default_vllm_config, dist_init, monkeypatch
|
||||
):
|
||||
"""
|
||||
Verify remote TP > local TP handshake succeeds with different
|
||||
remote configurations.
|
||||
"""
|
||||
monkeypatch.setattr(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size",
|
||||
lambda: local_tp_size,
|
||||
)
|
||||
|
||||
vllm_config = create_vllm_config()
|
||||
local_tp_size = 1
|
||||
vllm_config.parallel_config.tensor_parallel_size = local_tp_size
|
||||
|
||||
connector = NixlConnector(
|
||||
vllm_config, KVConnectorRole.WORKER, make_kv_cache_config(block_size=16)
|
||||
@@ -738,10 +740,10 @@ class TestNixlHandshake:
|
||||
remote_agents = worker._nixl_handshake(
|
||||
host="localhost",
|
||||
port=1234,
|
||||
remote_tp_size=2,
|
||||
remote_tp_size=4,
|
||||
expected_engine_id=worker.REMOTE_ENGINE_ID,
|
||||
)
|
||||
check_handshake(2)
|
||||
check_handshake(4)
|
||||
|
||||
# NOTE flexibility: a second remote with higher number of ranks is
|
||||
# discovered. This is not a scenario we actively support right now, but
|
||||
@@ -759,9 +761,8 @@ class TestNixlHandshake:
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||
FakeNixlWrapper,
|
||||
)
|
||||
@pytest.mark.parametrize("local_tp_size", [1, 2])
|
||||
def test_prefill_tp_size_greater_than_decode_tp_size_mla(
|
||||
self, local_tp_size: int, default_vllm_config, dist_init
|
||||
self, default_vllm_config, dist_init
|
||||
):
|
||||
"""
|
||||
Verify remote TP > local TP handshake succeeds with different
|
||||
|
||||
@@ -1318,12 +1318,12 @@ class NixlConnectorWorker:
|
||||
f"Expected {expected_engine_id},"
|
||||
f"received {metadata.engine_id}."
|
||||
)
|
||||
setup_agent_time = time.perf_counter()
|
||||
|
||||
# Register Remote agent.
|
||||
remote_agent_name = self.add_remote_agent(
|
||||
metadata, remote_rank, remote_tp_size
|
||||
)
|
||||
setup_agent_time = time.perf_counter()
|
||||
logger.debug(
|
||||
"NIXL handshake: add agent took: %s",
|
||||
setup_agent_time - got_metadata_time,
|
||||
|
||||
Reference in New Issue
Block a user