[ROCm] Enable DeepEP as the all2all backend on AMD GPUs (ROCm). (#34692)

Signed-off-by: Tej Kiran <vpolamre@amd.com>
Co-authored-by: Tej Kiran <vpolamre@amd.com>
This commit is contained in:
Chaitanya Sri Krishna Lolla
2026-03-21 13:02:31 +05:30
committed by GitHub
parent 02eec7ecbe
commit 3982bc2cd0
7 changed files with 68 additions and 29 deletions

View File

@@ -44,7 +44,7 @@ ENV DEBIAN_FRONTEND=noninteractive
# Install Python and other dependencies
RUN apt-get update -y \
&& apt-get install -y software-properties-common git curl sudo vim less libgfortran5 libopenmpi-dev libpci-dev \
&& apt-get install -y software-properties-common git curl sudo vim less libgfortran5 libopenmpi-dev libpci-dev liblzma-dev pkg-config \
&& for i in 1 2 3; do \
add-apt-repository -y ppa:deadsnakes/ppa && break || \
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \

View File

@@ -10,6 +10,7 @@ import vllm.envs as envs
from vllm.distributed import get_dp_group, get_ep_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.flashinfer import (
has_flashinfer_nvlink_one_sided,
has_flashinfer_nvlink_two_sided,
@@ -325,14 +326,20 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
assert num_rdma_bytes is not None
assert num_qps_per_rank is not None
return dict(
# TODO: remove platform-specific logic
# once ROCm DeepEP is updated with the latest APIs.
kwargs = dict(
group=self.cpu_group,
num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=False,
num_qps_per_rank=num_qps_per_rank,
explicitly_destroy=True,
)
if not current_platform.is_rocm():
kwargs.update(
explicitly_destroy=True,
)
return kwargs
def get_handle(self, kwargs):
assert len(kwargs) == 0, (
@@ -397,16 +404,22 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
)
assert num_rdma_bytes is not None
return dict(
# TODO: remove platform-specific logic
# once ROCm DeepEP is updated with the latest APIs.
kwargs = dict(
group=self.cpu_group,
num_nvl_bytes=num_nvl_bytes,
num_rdma_bytes=num_rdma_bytes,
low_latency_mode=True,
num_qps_per_rank=num_qps_per_rank,
allow_nvlink_for_low_latency_mode=True,
allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
explicitly_destroy=True,
)
if not current_platform.is_rocm():
kwargs.update(
allow_nvlink_for_low_latency_mode=True,
allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
explicitly_destroy=True,
)
return kwargs
def get_handle(self, kwargs):
"""

View File

@@ -346,7 +346,7 @@ class FusedMoEQuantConfig:
@property
def use_fp8_w8a8(self) -> bool:
return self.quant_dtype == torch.float8_e4m3fn
return self.quant_dtype == current_platform.fp8_dtype()
@property
def use_int8_w8a8(self) -> bool:
@@ -566,7 +566,7 @@ def fp8_w8a8_moe_quant_config(
Construct a quant config for fp8 activations and fp8 weights.
"""
return FusedMoEQuantConfig.make(
torch.float8_e4m3fn,
current_platform.fp8_dtype(),
w1_scale=w1_scale,
g1_alphas=g1_alphas,
w2_scale=w2_scale,

View File

@@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input,
normalize_batched_scales_shape,
)
from vllm.platforms import current_platform
from vllm.v1.worker.ubatching import (
dbo_current_ubatch_id,
dbo_enabled,
@@ -290,23 +291,46 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
# Dispatch
dispatch_topk_ids = self._map_global_to_physical_ids(topk_ids)
expert_x, expert_num_tokens, handle, _, hook = self.buffer.low_latency_dispatch(
a1,
dispatch_topk_ids,
self.max_tokens_per_rank,
num_experts,
use_fp8=self.use_fp8_dispatch,
round_scale=self.use_ue8m0_dispatch,
use_ue8m0=self.use_ue8m0_dispatch,
**(dict(use_nvfp4=True) if use_nvfp4 else dict()),
**(
dict(x_global_scale=qc_a1_gscale_or_scale)
if qc_a1_gscale_or_scale is not None
else dict()
),
async_finish=False,
return_recv_hook=True,
)
if current_platform.is_rocm():
(
expert_x,
expert_num_tokens,
handle,
_,
hook,
) = self.buffer.low_latency_dispatch(
a1,
dispatch_topk_ids,
self.max_tokens_per_rank,
num_experts,
use_fp8=self.use_fp8_dispatch,
async_finish=False,
return_recv_hook=True,
)
else:
(
expert_x,
expert_num_tokens,
handle,
_,
hook,
) = self.buffer.low_latency_dispatch(
a1,
dispatch_topk_ids,
self.max_tokens_per_rank,
num_experts,
use_fp8=self.use_fp8_dispatch,
round_scale=self.use_ue8m0_dispatch,
use_ue8m0=self.use_ue8m0_dispatch,
**(dict(use_nvfp4=True) if use_nvfp4 else dict()),
**(
dict(x_global_scale=qc_a1_gscale_or_scale)
if qc_a1_gscale_or_scale is not None
else dict()
),
async_finish=False,
return_recv_hook=True,
)
self.handles[a2a_idx] = handle
return (

View File

@@ -1017,6 +1017,7 @@ class BatchedTritonExperts(mk.FusedMoEExpertsModular):
torch.float16,
torch.bfloat16,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
]
assert expert_tokens_meta is not None
@@ -1046,7 +1047,7 @@ class BatchedTritonExperts(mk.FusedMoEExpertsModular):
compute_type = tl.float16
elif hidden_states.dtype == torch.float32:
compute_type = tl.float32
elif hidden_states.dtype == torch.float8_e4m3fn:
elif hidden_states.dtype == current_platform.fp8_dtype():
compute_type = tl.bfloat16
else:
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

View File

@@ -1616,7 +1616,7 @@ def _get_config_quant_dtype(
fused_experts_impl.
"""
if use_fp8_w8a8:
return torch.float8_e4m3fn
return current_platform.fp8_dtype()
elif use_int8_w8a8:
return torch.int8
elif ocp_mx_scheme == "w_mxfp4_a_mxfp4":

View File

@@ -25,6 +25,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
per_tensor_dequantize,
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import is_torch_equal_or_newer
@@ -265,7 +266,7 @@ def moe_kernel_quantize_input(
# weights are already dequantized, and we proceed with normal
# activation quantization below.
if quant_dtype == torch.float8_e4m3fn:
if quant_dtype == current_platform.fp8_dtype():
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == torch.int8:
return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)