diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 8b5b87cba..16c5799f7 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -587,6 +587,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
         prefix: str = "",
         use_sparse: bool = False,
         indexer: object | None = None,
+        **extra_impl_args,
     ):
         super().__init__()
         self.num_heads = num_heads
@@ -639,6 +640,7 @@ class MLAAttention(nn.Module, AttentionLayerBase):
             v_head_dim=self.v_head_dim,
             kv_b_proj=kv_b_proj,
             indexer=indexer,
+            **extra_impl_args,
         )

         self.use_direct_call = not current_platform.opaque_attention_op()
diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index de8083313..576977b00 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -17,9 +17,13 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors

-from .deepseek_v2 import DeepseekV2DecoderLayer, get_spec_layer_idx_from_weight_name
+from .deepseek_v2 import (
+    DeepseekV2DecoderLayer,
+    get_spec_layer_idx_from_weight_name,
+)
 from .interfaces import SupportsPP
 from .utils import maybe_prefix
@@ -56,6 +60,8 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
         self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False)

+        self.device = current_platform.device_type
+
         self.is_v32 = hasattr(config, "index_topk")
         if self.is_v32:
             topk_tokens = config.index_topk
@@ -63,7 +69,7 @@
                 vllm_config.scheduler_config.max_num_batched_tokens,
                 topk_tokens,
                 dtype=torch.int32,
-                device="cuda",
+                device=self.device,
             )
         else:
             topk_indices_buffer = None
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 970fa8082..3d26327c7 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -1165,6 +1165,7 @@ class DeepseekV2Model(nn.Module):
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
         self.config = config
+        self.device = current_platform.device_type
         self.vocab_size = config.vocab_size

         self.is_v32 = hasattr(config, "index_topk")
@@ -1174,7 +1175,7 @@
                 vllm_config.scheduler_config.max_num_batched_tokens,
                 topk_tokens,
                 dtype=torch.int32,
-                device="cuda",
+                device=self.device,
             )
         else:
             topk_indices_buffer = None