diff --git a/vllm/attention/backends/mla/common.py b/vllm/attention/backends/mla/common.py index 8184b0732..109e8496f 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -1308,7 +1308,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ) kv_c_normed = workspace[:toks]\ - [..., :self.kv_lora_rank].unsqueeze(1) + [..., :self.kv_lora_rank] k_pe = workspace[:toks]\ [..., self.kv_lora_rank:].unsqueeze(1) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index c98262eea..0b55854de 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -874,7 +874,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ) kv_c_normed = workspace[:toks]\ - [..., :self.kv_lora_rank].unsqueeze(1) + [..., :self.kv_lora_rank] k_pe = workspace[:toks]\ [..., self.kv_lora_rank:].unsqueeze(1)