NOTE(review): line breaks and indentation below were reconstructed from a
whitespace-collapsed copy of this patch and the added lines were edited, so
the index hashes are stale. Apply with `git apply --ignore-whitespace`, or
regenerate the diff against the real file. TODO confirm indentation.
diff --git a/vllm/patches/deepseek_v4_attention.py b/vllm/patches/deepseek_v4_attention.py
index ca565a85..501898d1 100644
--- a/vllm/patches/deepseek_v4_attention.py
+++ b/vllm/patches/deepseek_v4_attention.py
@@ -366,9 +366,17 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
         compressor = self.compressor
 
         def compressor_kv_score() -> torch.Tensor:
+            # torch.mm needs both operands in the same dtype, so a
+            # quantized weight (e.g. NVFP4 packed into uint8) cannot be
+            # multiplied raw; the layer's forward() handles dequantization.
+            wkv_wgate_weight = compressor.fused_wkv_wgate.weight
+            if wkv_wgate_weight.dtype != hidden_states.dtype:
+                # Quantized weights: use forward() for dequant+matmul.
+                score, _ = compressor.fused_wkv_wgate(hidden_states)
+                return score.to(torch.float32)
             return torch.mm(
                 hidden_states,
-                compressor.fused_wkv_wgate.weight.T,
+                wkv_wgate_weight.T,
                 out_dtype=torch.float32,
             )
 
@@ -383,9 +391,14 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
             return weights
 
         def indexer_compressor_kv_score() -> torch.Tensor:
+            # Same dtype-mismatch fallback as compressor_kv_score above.
+            wkv_wgate_weight = indexer.compressor.fused_wkv_wgate.weight
+            if wkv_wgate_weight.dtype != hidden_states.dtype:
+                score, _ = indexer.compressor.fused_wkv_wgate(hidden_states)
+                return score.to(torch.float32)
             return torch.mm(
                 hidden_states,
-                indexer.compressor.fused_wkv_wgate.weight.T,
+                wkv_wgate_weight.T,
                 out_dtype=torch.float32,
             )
 