Fix compressor kv_score: use forward() for NVFP4 quantized weights
Raw torch.mm doesn't work with packed uint8 NVFP4 weights. Use MergedColumnParallelLinear.forward() which handles dequantization.
This commit is contained in:
@@ -366,9 +366,17 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
compressor = self.compressor
|
||||
|
||||
def compressor_kv_score() -> torch.Tensor:
|
||||
# For NVFP4-quantized weights, we can't do a raw torch.mm
|
||||
# with packed uint8 weights. Use the layer's forward()
|
||||
# which handles dequantization properly.
|
||||
wkv_wgate_weight = compressor.fused_wkv_wgate.weight
|
||||
if wkv_wgate_weight.dtype == torch.uint8:
|
||||
# NVFP4 packed weights — use forward() for dequant+matmul
|
||||
score, _ = compressor.fused_wkv_wgate(hidden_states)
|
||||
return score.to(torch.float32)
|
||||
return torch.mm(
|
||||
hidden_states,
|
||||
compressor.fused_wkv_wgate.weight.T,
|
||||
wkv_wgate_weight.T,
|
||||
out_dtype=torch.float32,
|
||||
)
|
||||
|
||||
@@ -383,9 +391,13 @@ class DeepseekV4MultiHeadLatentAttentionWrapper(PluggableLayer):
|
||||
return weights
|
||||
|
||||
def indexer_compressor_kv_score() -> torch.Tensor:
|
||||
wkv_wgate_weight = indexer.compressor.fused_wkv_wgate.weight
|
||||
if wkv_wgate_weight.dtype == torch.uint8:
|
||||
score, _ = indexer.compressor.fused_wkv_wgate(hidden_states)
|
||||
return score.to(torch.float32)
|
||||
return torch.mm(
|
||||
hidden_states,
|
||||
indexer.compressor.fused_wkv_wgate.weight.T,
|
||||
wkv_wgate_weight.T,
|
||||
out_dtype=torch.float32,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user