[ Misc ] Refactor w8a8 to use process_weights_after_load (Simplify Weight Loading) (#5940)

Co-authored-by: Robert Shaw <rshaw@neuralmagic>
Robert Shaw
2024-06-30 19:06:27 -04:00
committed by GitHub
parent 7836fdcc11
commit af9ad46fca
10 changed files with 151 additions and 156 deletions
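In short, this refactor drops the FP8-specific scales_shard_indexer hook: weight-scale parameters are now tagged with a generic "needs_scalar_to_array" attribute, the shared weight loader writes each scalar checkpoint scale into a per-shard scale array, and process_weights_after_loading does the remaining fix-up. A minimal sketch of the loader-side idea, assuming a hypothetical load_scale helper (the name and signature are illustrative, not vLLM's actual API):

from typing import Optional, Union

import torch


def load_scale(param: torch.Tensor, loaded_scale: torch.Tensor,
               shard_id: Optional[Union[str, int]] = None) -> None:
    # Illustrative loader for a scale parameter flagged with
    # needs_scalar_to_array: the checkpoint stores one scalar scale per
    # logical weight, which is written into the N-entry scale array that
    # create_weights registered for the fused (QKV / MLP) layer.
    qkv_idxs = {"q": 0, "k": 1, "v": 2}
    if shard_id is None:
        # Fused checkpoint: a single scale covers the whole fused weight.
        idx = 0
    elif isinstance(shard_id, str):
        idx = qkv_idxs[shard_id]
    else:
        idx = shard_id
    param[idx] = loaded_scale

Keeping every shard's scale in one array is what lets process_weights_after_loading later collapse them into a single scale, as the last hunk below shows.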


@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 from torch.nn import Module
@@ -98,7 +98,6 @@ class Fp8LinearMethod(LinearMethodBase):
     """
 
     def __init__(self, quant_config: Fp8Config):
-        self.fused_module_in_checkpoint = False
         self.quant_config = quant_config
         self.cutlass_fp8_supported = cutlass_fp8_supported()
@@ -114,12 +113,10 @@
                           requires_grad=False)
         scale[:] = torch.finfo(torch.float8_e4m3fn).min
         layer.register_parameter(scale_name, scale)
-        set_weight_attrs(
-            scale, {
-                **extra_weight_attrs,
-                "fp8_scales_shard_indexer":
-                self.scales_shard_indexer,
-            })
+        set_weight_attrs(scale, {
+            **extra_weight_attrs,
+            "needs_scalar_to_array": True,
+        })
 
     def create_weights(
         self,
@@ -170,26 +167,6 @@
             output_partition_sizes=output_partition_sizes,
             **extra_weight_attrs)
 
-    def scales_shard_indexer(
-            self, param: torch.Tensor, loaded_weight: torch.Tensor,
-            shard_id: Optional[Union[str,
-                                     int]]) -> Tuple[torch.Tensor, torch.Tensor]:
-        qkv_idxs = {"q": 0, "k": 1, "v": 2}
-
-        if shard_id is None:
-            shard_id = 0
-            self.fused_module_in_checkpoint = True
-        elif isinstance(shard_id, int):
-            pass
-        elif isinstance(shard_id, str):
-            if shard_id not in qkv_idxs:
-                raise ValueError(f"Unknown shard_id: {shard_id}")
-            shard_id = qkv_idxs[shard_id]
-        else:
-            ValueError(f"Shard id must be int or str but got {type(shard_id)}")
-
-        return param[shard_id], loaded_weight
-
     def process_weights_after_loading(self, layer: Module) -> None:
         if (not hasattr(layer, "process_after_load")
                 or not layer.process_after_load):
@@ -212,7 +189,17 @@ class Fp8LinearMethod(LinearMethodBase):
             # Loop over logical weights, requantizing with single scale.
             max_w_scale = layer.weight_scale.max()
 
-            if not self.fused_module_in_checkpoint:
+            # QKV / MLP is fused in the on disk checkpoint if any of the
+            # weight scales are still set to the default since we initialize
+            # N weight scales for N shards but we only load 1 weight scale
+            # from disk in this case. As a result, we skip dequant -> requant
+            # since we already have quantized QKV together.
+            # Sample Model with fused checkpoint:
+            #   * nm-testing/Phi-3-mini-128k-instruct-FP8
+            unfused_module_in_checkpoint = (
+                layer.weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min)
+
+            if unfused_module_in_checkpoint:
                 start = 0
                 for idx, logical_width in enumerate(layer.logical_widths):
                     end = start + logical_width
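For reference, a standalone sketch of the dequant -> requant step that the new code performs; the weight layout and the clamp-to-fp8 details here are assumptions for illustration, not code copied from vLLM:

from typing import List, Tuple

import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MIN = torch.finfo(FP8_DTYPE).min
FP8_MAX = torch.finfo(FP8_DTYPE).max


def requantize_with_max_scale(
        weight: torch.Tensor,        # fused fp8 weight, [sum(logical_widths), K]
        weight_scale: torch.Tensor,  # one scale per logical shard
        logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
    # Collapse N per-shard scales into a single scale for the fused weight.
    max_w_scale = weight_scale.max()

    # Scales still at the fp8 minimum were never loaded from disk, which
    # means the checkpoint already stored the fused QKV / MLP weight with
    # one scale; in that case there is nothing to requantize.
    unfused_module_in_checkpoint = bool(weight_scale[-1] > FP8_MIN)
    if unfused_module_in_checkpoint:
        start = 0
        for idx, logical_width in enumerate(logical_widths):
            end = start + logical_width
            # Dequantize this shard with its own scale, then requantize
            # with the shared max scale.
            dequant = weight[start:end].to(torch.float32) * weight_scale[idx]
            requant = torch.clamp(dequant / max_w_scale, FP8_MIN, FP8_MAX)
            weight[start:end] = requant.to(FP8_DTYPE)
            start = end
    return weight, max_w_scale

In the diff above, this logic runs inside Fp8LinearMethod.process_weights_after_loading once every shard's weight and scale have been loaded.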