diff --git a/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py b/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py index 7608e06aa..b061b3d38 100644 --- a/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py +++ b/vllm/model_executor/layers/fused_moe/routed_experts_capturer.py @@ -20,6 +20,7 @@ import torch from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_rank from vllm.forward_context import get_forward_context +from vllm.platforms import current_platform logger = logging.getLogger(__name__) @@ -132,7 +133,7 @@ class RoutedExpertsCapturer: self._device_buffer = torch.zeros( (max_num_batched_tokens, num_layers, num_experts_per_tok), dtype=torch.int32, - device="cuda", + device=current_platform.device_type, ) self.dp_rank = vllm_config.parallel_config.data_parallel_rank