[Bugfix] fix device_name for routing replay (#34336)

Signed-off-by: liyongwen <1310439159@qq.com>
This commit is contained in:
Li-Yongwen
2026-02-26 20:18:38 +08:00
committed by GitHub
parent c0615a296d
commit c6ca51598a

View File

@@ -20,6 +20,7 @@ import torch
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.forward_context import get_forward_context
from vllm.platforms import current_platform
logger = logging.getLogger(__name__)
@@ -132,7 +133,7 @@ class RoutedExpertsCapturer:
self._device_buffer = torch.zeros(
(max_num_batched_tokens, num_layers, num_experts_per_tok),
dtype=torch.int32,
device="cuda",
device=current_platform.device_type,
)
self.dp_rank = vllm_config.parallel_config.data_parallel_rank