diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base index c6e972e89..e5a216c77 100644 --- a/docker/Dockerfile.rocm_base +++ b/docker/Dockerfile.rocm_base @@ -44,7 +44,7 @@ ENV DEBIAN_FRONTEND=noninteractive # Install Python and other dependencies RUN apt-get update -y \ - && apt-get install -y software-properties-common git curl sudo vim less libgfortran5 libopenmpi-dev libpci-dev \ + && apt-get install -y software-properties-common git curl sudo vim less libgfortran5 libopenmpi-dev libpci-dev liblzma-dev pkg-config \ && for i in 1 2 3; do \ add-apt-repository -y ppa:deadsnakes/ppa && break || \ { echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \ diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index 0cdff9032..075f4e085 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -10,6 +10,7 @@ import vllm.envs as envs from vllm.distributed import get_dp_group, get_ep_group from vllm.forward_context import get_forward_context from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.utils.flashinfer import ( has_flashinfer_nvlink_one_sided, has_flashinfer_nvlink_two_sided, @@ -325,14 +326,20 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): assert num_rdma_bytes is not None assert num_qps_per_rank is not None - return dict( + # TODO: remove platform-specific logic + # once ROCm DeepEP is updated with the latest APIs. + kwargs = dict( group=self.cpu_group, num_nvl_bytes=num_nvl_bytes, num_rdma_bytes=num_rdma_bytes, low_latency_mode=False, num_qps_per_rank=num_qps_per_rank, - explicitly_destroy=True, ) + if not current_platform.is_rocm(): + kwargs.update( + explicitly_destroy=True, + ) + return kwargs def get_handle(self, kwargs): assert len(kwargs) == 0, ( @@ -397,16 +404,22 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): ) assert num_rdma_bytes is not None - return dict( + # TODO: remove platform-specific logic + # once ROCm DeepEP is updated with the latest APIs. + kwargs = dict( group=self.cpu_group, num_nvl_bytes=num_nvl_bytes, num_rdma_bytes=num_rdma_bytes, low_latency_mode=True, num_qps_per_rank=num_qps_per_rank, - allow_nvlink_for_low_latency_mode=True, - allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL, - explicitly_destroy=True, ) + if not current_platform.is_rocm(): + kwargs.update( + allow_nvlink_for_low_latency_mode=True, + allow_mnnvl=envs.VLLM_DEEPEP_LOW_LATENCY_USE_MNNVL, + explicitly_destroy=True, + ) + return kwargs def get_handle(self, kwargs): """ diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py index 2eb0f4921..f4e3ed8e0 100644 --- a/vllm/model_executor/layers/fused_moe/config.py +++ b/vllm/model_executor/layers/fused_moe/config.py @@ -346,7 +346,7 @@ class FusedMoEQuantConfig: @property def use_fp8_w8a8(self) -> bool: - return self.quant_dtype == torch.float8_e4m3fn + return self.quant_dtype == current_platform.fp8_dtype() @property def use_int8_w8a8(self) -> bool: @@ -566,7 +566,7 @@ def fp8_w8a8_moe_quant_config( Construct a quant config for fp8 activations and fp8 weights. """ return FusedMoEQuantConfig.make( - torch.float8_e4m3fn, + current_platform.fp8_dtype(), w1_scale=w1_scale, g1_alphas=g1_alphas, w2_scale=w2_scale, diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index e1d2d5740..a3266f5e8 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe.utils import ( moe_kernel_quantize_input, normalize_batched_scales_shape, ) +from vllm.platforms import current_platform from vllm.v1.worker.ubatching import ( dbo_current_ubatch_id, dbo_enabled, @@ -290,23 +291,46 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalizeModular): # Dispatch dispatch_topk_ids = self._map_global_to_physical_ids(topk_ids) - expert_x, expert_num_tokens, handle, _, hook = self.buffer.low_latency_dispatch( - a1, - dispatch_topk_ids, - self.max_tokens_per_rank, - num_experts, - use_fp8=self.use_fp8_dispatch, - round_scale=self.use_ue8m0_dispatch, - use_ue8m0=self.use_ue8m0_dispatch, - **(dict(use_nvfp4=True) if use_nvfp4 else dict()), - **( - dict(x_global_scale=qc_a1_gscale_or_scale) - if qc_a1_gscale_or_scale is not None - else dict() - ), - async_finish=False, - return_recv_hook=True, - ) + if current_platform.is_rocm(): + ( + expert_x, + expert_num_tokens, + handle, + _, + hook, + ) = self.buffer.low_latency_dispatch( + a1, + dispatch_topk_ids, + self.max_tokens_per_rank, + num_experts, + use_fp8=self.use_fp8_dispatch, + async_finish=False, + return_recv_hook=True, + ) + else: + ( + expert_x, + expert_num_tokens, + handle, + _, + hook, + ) = self.buffer.low_latency_dispatch( + a1, + dispatch_topk_ids, + self.max_tokens_per_rank, + num_experts, + use_fp8=self.use_fp8_dispatch, + round_scale=self.use_ue8m0_dispatch, + use_ue8m0=self.use_ue8m0_dispatch, + **(dict(use_nvfp4=True) if use_nvfp4 else dict()), + **( + dict(x_global_scale=qc_a1_gscale_or_scale) + if qc_a1_gscale_or_scale is not None + else dict() + ), + async_finish=False, + return_recv_hook=True, + ) self.handles[a2a_idx] = handle return ( diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 9df94b72d..e2b5a8f67 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -1017,6 +1017,7 @@ class BatchedTritonExperts(mk.FusedMoEExpertsModular): torch.float16, torch.bfloat16, torch.float8_e4m3fn, + torch.float8_e4m3fnuz, ] assert expert_tokens_meta is not None @@ -1046,7 +1047,7 @@ class BatchedTritonExperts(mk.FusedMoEExpertsModular): compute_type = tl.float16 elif hidden_states.dtype == torch.float32: compute_type = tl.float32 - elif hidden_states.dtype == torch.float8_e4m3fn: + elif hidden_states.dtype == current_platform.fp8_dtype(): compute_type = tl.bfloat16 else: raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}") diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 03ca8ba11..d5b8feb3c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1616,7 +1616,7 @@ def _get_config_quant_dtype( fused_experts_impl. """ if use_fp8_w8a8: - return torch.float8_e4m3fn + return current_platform.fp8_dtype() elif use_int8_w8a8: return torch.int8 elif ocp_mx_scheme == "w_mxfp4_a_mxfp4": diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index c733f233f..ba4494f6c 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -25,6 +25,7 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( per_tensor_dequantize, ) +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import is_torch_equal_or_newer @@ -265,7 +266,7 @@ def moe_kernel_quantize_input( # weights are already dequantized, and we proceed with normal # activation quantization below. - if quant_dtype == torch.float8_e4m3fn: + if quant_dtype == current_platform.fp8_dtype(): return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape) elif quant_dtype == torch.int8: return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)