[ROCm] Enable DeepEP ROCm as all2allbackend for AMD GPUs. (#34692)
Signed-off-by: Tej Kiran <vpolamre@amd.com> Co-authored-by: Tej Kiran <vpolamre@amd.com>
This commit is contained in:
committed by
GitHub
parent
02eec7ecbe
commit
3982bc2cd0
@@ -44,7 +44,7 @@ ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install Python and other dependencies
|
||||
RUN apt-get update -y \
|
||||
&& apt-get install -y software-properties-common git curl sudo vim less libgfortran5 libopenmpi-dev libpci-dev \
|
||||
&& apt-get install -y software-properties-common git curl sudo vim less libgfortran5 libopenmpi-dev libpci-dev liblzma-dev pkg-config \
|
||||
&& for i in 1 2 3; do \
|
||||
add-apt-repository -y ppa:deadsnakes/ppa && break || \
|
||||
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
|
||||
|
||||
@@ -10,6 +10,7 @@ import vllm.envs as envs
|
||||
from vllm.distributed import get_dp_group, get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import (
|
||||
has_flashinfer_nvlink_one_sided,
|
||||
has_flashinfer_nvlink_two_sided,
|
||||
@@ -325,14 +326,20 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
|
||||
assert num_rdma_bytes is not None
|
||||
assert num_qps_per_rank is not None
|
||||
return dict(
|
||||
# TODO: remove platform-specific logic
|
||||
# once ROCm DeepEP is updated with the latest APIs.
|
||||
kwargs = dict(
|
||||
group=self.cpu_group,
|
||||
num_nvl_bytes=num_nvl_bytes,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=False,
|
||||
num_qps_per_rank=num_qps_per_rank,
|
||||
explicitly_destroy=True,
|
||||
)
|
||||
if not current_platform.is_rocm():
|
||||
kwargs.update(
|
||||
explicitly_destroy=True,
|
||||
)
|
||||
return kwargs
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
assert len(kwargs) == 0, (
|
||||
@@ -397,16 +404,22 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
|
||||
)
|
||||
|
||||
assert num_rdma_bytes is not None
|
||||
return dict(
|
||||
# TODO: remove platform-specific logic
|
||||
# once ROCm DeepEP is updated with the latest APIs.
|
||||
kwargs = dict(
|
||||
group=self.cpu_group,
|
||||
num_nvl_bytes=num_nvl_bytes,
|
||||
num_rdma_bytes=num_rdma_bytes,
|
||||
low_latency_mode=True,
|
||||
num_qps_per_rank=num_qps_per_rank,
|
||||
allow_nvlink_for_low_latency_mode=True,
|
||||
allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
|
||||
explicitly_destroy=True,
|
||||
)
|
||||
if not current_platform.is_rocm():
|
||||
kwargs.update(
|
||||
allow_nvlink_for_low_latency_mode=True,
|
||||
allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL,
|
||||
explicitly_destroy=True,
|
||||
)
|
||||
return kwargs
|
||||
|
||||
def get_handle(self, kwargs):
|
||||
"""
|
||||
|
||||
@@ -346,7 +346,7 @@ class FusedMoEQuantConfig:
|
||||
|
||||
@property
|
||||
def use_fp8_w8a8(self) -> bool:
|
||||
return self.quant_dtype == torch.float8_e4m3fn
|
||||
return self.quant_dtype == current_platform.fp8_dtype()
|
||||
|
||||
@property
|
||||
def use_int8_w8a8(self) -> bool:
|
||||
@@ -566,7 +566,7 @@ def fp8_w8a8_moe_quant_config(
|
||||
Construct a quant config for fp8 activations and fp8 weights.
|
||||
"""
|
||||
return FusedMoEQuantConfig.make(
|
||||
torch.float8_e4m3fn,
|
||||
current_platform.fp8_dtype(),
|
||||
w1_scale=w1_scale,
|
||||
g1_alphas=g1_alphas,
|
||||
w2_scale=w2_scale,
|
||||
|
||||
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe.utils import (
|
||||
moe_kernel_quantize_input,
|
||||
normalize_batched_scales_shape,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.worker.ubatching import (
|
||||
dbo_current_ubatch_id,
|
||||
dbo_enabled,
|
||||
@@ -290,23 +291,46 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
|
||||
|
||||
# Dispatch
|
||||
dispatch_topk_ids = self._map_global_to_physical_ids(topk_ids)
|
||||
expert_x, expert_num_tokens, handle, _, hook = self.buffer.low_latency_dispatch(
|
||||
a1,
|
||||
dispatch_topk_ids,
|
||||
self.max_tokens_per_rank,
|
||||
num_experts,
|
||||
use_fp8=self.use_fp8_dispatch,
|
||||
round_scale=self.use_ue8m0_dispatch,
|
||||
use_ue8m0=self.use_ue8m0_dispatch,
|
||||
**(dict(use_nvfp4=True) if use_nvfp4 else dict()),
|
||||
**(
|
||||
dict(x_global_scale=qc_a1_gscale_or_scale)
|
||||
if qc_a1_gscale_or_scale is not None
|
||||
else dict()
|
||||
),
|
||||
async_finish=False,
|
||||
return_recv_hook=True,
|
||||
)
|
||||
if current_platform.is_rocm():
|
||||
(
|
||||
expert_x,
|
||||
expert_num_tokens,
|
||||
handle,
|
||||
_,
|
||||
hook,
|
||||
) = self.buffer.low_latency_dispatch(
|
||||
a1,
|
||||
dispatch_topk_ids,
|
||||
self.max_tokens_per_rank,
|
||||
num_experts,
|
||||
use_fp8=self.use_fp8_dispatch,
|
||||
async_finish=False,
|
||||
return_recv_hook=True,
|
||||
)
|
||||
else:
|
||||
(
|
||||
expert_x,
|
||||
expert_num_tokens,
|
||||
handle,
|
||||
_,
|
||||
hook,
|
||||
) = self.buffer.low_latency_dispatch(
|
||||
a1,
|
||||
dispatch_topk_ids,
|
||||
self.max_tokens_per_rank,
|
||||
num_experts,
|
||||
use_fp8=self.use_fp8_dispatch,
|
||||
round_scale=self.use_ue8m0_dispatch,
|
||||
use_ue8m0=self.use_ue8m0_dispatch,
|
||||
**(dict(use_nvfp4=True) if use_nvfp4 else dict()),
|
||||
**(
|
||||
dict(x_global_scale=qc_a1_gscale_or_scale)
|
||||
if qc_a1_gscale_or_scale is not None
|
||||
else dict()
|
||||
),
|
||||
async_finish=False,
|
||||
return_recv_hook=True,
|
||||
)
|
||||
self.handles[a2a_idx] = handle
|
||||
|
||||
return (
|
||||
|
||||
@@ -1017,6 +1017,7 @@ class BatchedTritonExperts(mk.FusedMoEExpertsModular):
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
torch.float8_e4m3fn,
|
||||
torch.float8_e4m3fnuz,
|
||||
]
|
||||
assert expert_tokens_meta is not None
|
||||
|
||||
@@ -1046,7 +1047,7 @@ class BatchedTritonExperts(mk.FusedMoEExpertsModular):
|
||||
compute_type = tl.float16
|
||||
elif hidden_states.dtype == torch.float32:
|
||||
compute_type = tl.float32
|
||||
elif hidden_states.dtype == torch.float8_e4m3fn:
|
||||
elif hidden_states.dtype == current_platform.fp8_dtype():
|
||||
compute_type = tl.bfloat16
|
||||
else:
|
||||
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
|
||||
|
||||
@@ -1616,7 +1616,7 @@ def _get_config_quant_dtype(
|
||||
fused_experts_impl.
|
||||
"""
|
||||
if use_fp8_w8a8:
|
||||
return torch.float8_e4m3fn
|
||||
return current_platform.fp8_dtype()
|
||||
elif use_int8_w8a8:
|
||||
return torch.int8
|
||||
elif ocp_mx_scheme == "w_mxfp4_a_mxfp4":
|
||||
|
||||
@@ -25,6 +25,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
per_tensor_dequantize,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
@@ -265,7 +266,7 @@ def moe_kernel_quantize_input(
|
||||
# weights are already dequantized, and we proceed with normal
|
||||
# activation quantization below.
|
||||
|
||||
if quant_dtype == torch.float8_e4m3fn:
|
||||
if quant_dtype == current_platform.fp8_dtype():
|
||||
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
|
||||
elif quant_dtype == torch.int8:
|
||||
return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
|
||||
|
||||
Reference in New Issue
Block a user