[ROCm] Enable DeepEP ROCm as all2allbackend for AMD GPUs. (#34692)

Signed-off-by: Tej Kiran <vpolamre@amd.com>
Co-authored-by: Tej Kiran <vpolamre@amd.com>
This commit is contained in:
Chaitanya Sri Krishna Lolla
2026-03-21 13:02:31 +05:30
committed by GitHub
parent 02eec7ecbe
commit 3982bc2cd0
7 changed files with 68 additions and 29 deletions

View File

@@ -346,7 +346,7 @@ class FusedMoEQuantConfig:
@property
def use_fp8_w8a8(self) -> bool:
return self.quant_dtype == torch.float8_e4m3fn
return self.quant_dtype == current_platform.fp8_dtype()
@property
def use_int8_w8a8(self) -> bool:
@@ -566,7 +566,7 @@ def fp8_w8a8_moe_quant_config(
Construct a quant config for fp8 activations and fp8 weights.
"""
return FusedMoEQuantConfig.make(
torch.float8_e4m3fn,
current_platform.fp8_dtype(),
w1_scale=w1_scale,
g1_alphas=g1_alphas,
w2_scale=w2_scale,

View File

@@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input,
normalize_batched_scales_shape,
)
from vllm.platforms import current_platform
from vllm.v1.worker.ubatching import (
dbo_current_ubatch_id,
dbo_enabled,
@@ -290,23 +291,46 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular):
# Dispatch
dispatch_topk_ids = self._map_global_to_physical_ids(topk_ids)
expert_x, expert_num_tokens, handle, _, hook = self.buffer.low_latency_dispatch(
a1,
dispatch_topk_ids,
self.max_tokens_per_rank,
num_experts,
use_fp8=self.use_fp8_dispatch,
round_scale=self.use_ue8m0_dispatch,
use_ue8m0=self.use_ue8m0_dispatch,
**(dict(use_nvfp4=True) if use_nvfp4 else dict()),
**(
dict(x_global_scale=qc_a1_gscale_or_scale)
if qc_a1_gscale_or_scale is not None
else dict()
),
async_finish=False,
return_recv_hook=True,
)
if current_platform.is_rocm():
(
expert_x,
expert_num_tokens,
handle,
_,
hook,
) = self.buffer.low_latency_dispatch(
a1,
dispatch_topk_ids,
self.max_tokens_per_rank,
num_experts,
use_fp8=self.use_fp8_dispatch,
async_finish=False,
return_recv_hook=True,
)
else:
(
expert_x,
expert_num_tokens,
handle,
_,
hook,
) = self.buffer.low_latency_dispatch(
a1,
dispatch_topk_ids,
self.max_tokens_per_rank,
num_experts,
use_fp8=self.use_fp8_dispatch,
round_scale=self.use_ue8m0_dispatch,
use_ue8m0=self.use_ue8m0_dispatch,
**(dict(use_nvfp4=True) if use_nvfp4 else dict()),
**(
dict(x_global_scale=qc_a1_gscale_or_scale)
if qc_a1_gscale_or_scale is not None
else dict()
),
async_finish=False,
return_recv_hook=True,
)
self.handles[a2a_idx] = handle
return (

View File

@@ -1017,6 +1017,7 @@ class BatchedTritonExperts(mk.FusedMoEExpertsModular):
torch.float16,
torch.bfloat16,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
]
assert expert_tokens_meta is not None
@@ -1046,7 +1047,7 @@ class BatchedTritonExperts(mk.FusedMoEExpertsModular):
compute_type = tl.float16
elif hidden_states.dtype == torch.float32:
compute_type = tl.float32
elif hidden_states.dtype == torch.float8_e4m3fn:
elif hidden_states.dtype == current_platform.fp8_dtype():
compute_type = tl.bfloat16
else:
raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

View File

@@ -1616,7 +1616,7 @@ def _get_config_quant_dtype(
fused_experts_impl.
"""
if use_fp8_w8a8:
return torch.float8_e4m3fn
return current_platform.fp8_dtype()
elif use_int8_w8a8:
return torch.int8
elif ocp_mx_scheme == "w_mxfp4_a_mxfp4":

View File

@@ -25,6 +25,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
per_tensor_dequantize,
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import is_torch_equal_or_newer
@@ -265,7 +266,7 @@ def moe_kernel_quantize_input(
# weights are already dequantized, and we proceed with normal
# activation quantization below.
if quant_dtype == torch.float8_e4m3fn:
if quant_dtype == current_platform.fp8_dtype():
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == torch.int8:
return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)