[Bugfix] fix device_name for routing replay (#34336)
Signed-off-by: liyongwen <1310439159@qq.com>
This commit is contained in:
@@ -20,6 +20,7 @@ import torch
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.forward_context import get_forward_context
|
||||
+from vllm.platforms import current_platform
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -132,7 +133,7 @@ class RoutedExpertsCapturer:
|
||||
self._device_buffer = torch.zeros(
|
||||
(max_num_batched_tokens, num_layers, num_experts_per_tok),
|
||||
dtype=torch.int32,
|
||||
-device="cuda",
+device=current_platform.device_type,
|
||||
)
|
||||
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||
|
||||
|
||||
Reference in New Issue
Block a user