[ROCm] Enable ROCm DeepEP as an all2all backend for AMD GPUs. (#34692)

Signed-off-by: Tej Kiran <vpolamre@amd.com>
Co-authored-by: Tej Kiran <vpolamre@amd.com>
This commit is contained in:
Chaitanya Sri Krishna Lolla
2026-03-21 13:02:31 +05:30
committed by GitHub
parent 02eec7ecbe
commit 3982bc2cd0
7 changed files with 68 additions and 29 deletions

View File

@@ -10,6 +10,7 @@ import vllm.envs as envs
from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.flashinfer import (
has_flashinfer_nvlink_one_sided,
has_flashinfer_nvlink_two_sided,
@@ -325,14 +326,20 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
assert num_rdma_bytes is not None
assert num_qps_per_rank is not None
return dict(
# TODO: remove platform-specific logic
# once ROCm DeepEP is updated with the latest APIs.
kwargs = dict(
group=self.cpu_group,
num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=False,
num_qps_per_rank=num_qps_per_rank,
explicitly_destroy=True,
)
if not current_platform.is_rocm():
kwargs.update(
explicitly_destroy=True,
)
return kwargs
def get_handle(self, kwargs):
assert len(kwargs) == 0, (
@@ -397,16 +404,22 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
)
assert num_rdma_bytes is not None
return dict(
# TODO: remove platform-specific logic
# once ROCm DeepEP is updated with the latest APIs.
kwargs = dict(
group=self.cpu_group,
num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=num_qps_per_rank,
allow_nvlink_for_low_latency_mode=True,
allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
explicitly_destroy=True,
)
if not current_platform.is_rocm():
kwargs.update(
allow_nvlink_for_low_latency_mode=True,
allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
explicitly_destroy=True,
)
return kwargs
def get_handle(self, kwargs):
"""