diff --git a/vllm/patches/deepseek_v4_attention.py b/vllm/patches/deepseek_v4_attention.py index ef23ff2d..e90eace6 100644 --- a/vllm/patches/deepseek_v4_attention.py +++ b/vllm/patches/deepseek_v4_attention.py @@ -709,15 +709,15 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): self._swa_inv_scale_cache, swa_metadata, flashmla_metadata, - self.kv_cache if not swa_only else None, + self.mla_attn.kv_cache if not swa_only else None, self.compress_ratio, self.scale, self.window_size, self.nope_head_dim, self.rope_head_dim, self.rotary_emb.cos_sin_cache, - self.attn_sink, - self.max_model_len, + self.mla_attn.attn_sink, + self.mla_attn.max_model_len, ) # ── Prefill attention ───────────────────────────────────── @@ -731,20 +731,9 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer): q_prefill, kv_rope_prefill, self.scale, ) else: - # CSA/HCA prefill: sparse + SWA - o[num_decode_tokens:] = csa_sparse_prefill_attention( - q_prefill, kv_rope_prefill, - self.kv_cache if not swa_only else None, - flashmla_metadata, - swa_metadata, - self.compress_ratio, - self.scale, - self.window_size, - self.nope_head_dim, - self.rope_head_dim, - self.rotary_emb.cos_sin_cache, - self.attn_sink, - self.max_model_len, + # CSA/HCA prefill: sparse + SWA (fallback to full causal for now) + o[num_decode_tokens:] = causal_prefill_attention( + q_prefill, kv_rope_prefill, self.scale, ) # Write into the output buffer