[ROCm] Enable DeepEP ROCm as all2allbackend for AMD GPUs. (#34692)
Signed-off-by: Tej Kiran <vpolamre@amd.com> Co-authored-by: Tej Kiran <vpolamre@amd.com>
This commit is contained in:
committed by
GitHub
parent
02eec7ecbe
commit
3982bc2cd0
@@ -10,6 +10,7 @@ import vllm.envs as envs
|
||||
from vllm.distributed import get_dp_group, get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import (
|
||||
has_flashinfer_nvlink_one_sided,
|
||||
has_flashinfer_nvlink_two_sided,
|
||||
@@ -325,14 +326,20 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
|
||||
assert num_rdma_bytes is not None
|
||||
assert num_qps_per_rank is not None
|
||||
return dict(
|
||||
# TODO: remove platform-specific logic
|
||||
# once ROCm DeepEP is updated with the latest APIs.
|
||||
kwargs = dict(
|
||||
group=self.cpu_group,
|
||||
num_nvl_bytes=num_nvl_bytes,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=False,
|
||||
num_qps_per_rank=num_qps_per_rank,
|
||||
explicitly_destroy=True,
|
||||
)
|
||||
if not current_platform.is_rocm():
|
||||
kwargs.update(
|
||||
explicitly_destroy=True,
|
||||
)
|
||||
return kwargs
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
assert len(kwargs) == 0, (
|
||||
@@ -397,16 +404,22 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
)
|
||||
|
||||
assert num_rdma_bytes is not None
|
||||
return dict(
|
||||
# TODO: remove platform-specific logic
|
||||
# once ROCm DeepEP is updated with the latest APIs.
|
||||
kwargs = dict(
|
||||
group=self.cpu_group,
|
||||
num_nvl_bytes=num_nvl_bytes,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=True,
|
||||
num_qps_per_rank=num_qps_per_rank,
|
||||
allow_nvlink_for_low_latency_mode=True,
|
||||
allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
|
||||
explicitly_destroy=True,
|
||||
)
|
||||
if not current_platform.is_rocm():
|
||||
kwargs.update(
|
||||
allow_nvlink_for_low_latency_mode=True,
|
||||
allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
|
||||
explicitly_destroy=True,
|
||||
)
|
||||
return kwargs
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user