Fix compressor kv_score: use forward() for NVFP4 quantized weights

Raw torch.mm doesn't work with packed uint8 NVFP4 weights.
Use MergedColumnParallelLinear.forward() which handles dequantization.
This commit is contained in:
2026-05-19 00:29:43 +00:00
parent 10c14ddb49
commit d9dc042ff7

View File

@@ -366,9 +366,17 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
compressor = self.compressor
def compressor_kv_score() -> torch.Tensor:
# For NVFP4-quantized weights, we can't do a raw torch.mm
# with packed uint8 weights. Use the layer's forward()
# which handles dequantization properly.
wkv_wgate_weight = compressor.fused_wkv_wgate.weight
if wkv_wgate_weight.dtype == torch.uint8:
# NVFP4 packed weights — use forward() for dequant+matmul
score, _ = compressor.fused_wkv_wgate(hidden_states)
return score.to(torch.float32)
return torch.mm(
hidden_states,
compressor.fused_wkv_wgate.weight.T,
wkv_wgate_weight.T,
out_dtype=torch.float32,
)
@@ -383,9 +391,13 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
return weights
def indexer_compressor_kv_score() -> torch.Tensor:
wkv_wgate_weight = indexer.compressor.fused_wkv_wgate.weight
if wkv_wgate_weight.dtype == torch.uint8:
score, _ = indexer.compressor.fused_wkv_wgate(hidden_states)
return score.to(torch.float32)
return torch.mm(
hidden_states,
indexer.compressor.fused_wkv_wgate.weight.T,
wkv_wgate_weight.T,
out_dtype=torch.float32,
)