[ Misc ] Refactor w8a8 to use process_weights_after_load (Simplify Weight Loading) (#5940)
Co-authored-by: Robert Shaw <rshaw@neuralmagic>
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 from torch.nn import Module
@@ -98,7 +98,6 @@ class Fp8LinearMethod(LinearMethodBase):
     """
 
     def __init__(self, quant_config: Fp8Config):
-        self.fused_module_in_checkpoint = False
         self.quant_config = quant_config
         self.cutlass_fp8_supported = cutlass_fp8_supported()
 
@@ -114,12 +113,10 @@ class Fp8LinearMethod(LinearMethodBase):
                           requires_grad=False)
         scale[:] = torch.finfo(torch.float8_e4m3fn).min
         layer.register_parameter(scale_name, scale)
-        set_weight_attrs(
-            scale, {
-                **extra_weight_attrs,
-                "fp8_scales_shard_indexer":
-                self.scales_shard_indexer,
-            })
+        set_weight_attrs(scale, {
+            **extra_weight_attrs,
+            "needs_scalar_to_array": True,
+        })
 
     def create_weights(
         self,
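The declarative `needs_scalar_to_array` flag replaces the per-parameter `fp8_scales_shard_indexer` callback removed in the next hunk: instead of each scale carrying its own indexing hook, the generic weight loader can route scalar scales into the per-shard array itself. A minimal sketch of what that generic path might look like, reusing the q/k/v index mapping from the removed indexer; `load_fused_scale` is an illustrative name, not vLLM's actual loader API:

from typing import Optional, Union

import torch


def load_fused_scale(param: torch.Tensor, loaded_weight: torch.Tensor,
                     shard_id: Optional[Union[str, int]]) -> None:
    """Copy one scalar scale from disk into the slot for `shard_id`.

    `param` is the per-shard scale created above, holding one fp32
    entry per logical shard (len(output_partition_sizes) entries).
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}
    if shard_id is None:
        # Fused checkpoint: a single scale covers the whole fused weight.
        shard_id = 0
    elif isinstance(shard_id, str):
        shard_id = qkv_idxs[shard_id]
    param[shard_id] = loaded_weight.reshape(())

Because the flag is declarative, every scale parameter can share the loader's one scalar-to-array path instead of registering its own indexer callback.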
@@ -170,26 +167,6 @@ class Fp8LinearMethod(LinearMethodBase):
             output_partition_sizes=output_partition_sizes,
             **extra_weight_attrs)
 
-    def scales_shard_indexer(
-            self, param: torch.Tensor, loaded_weight: torch.Tensor,
-            shard_id: Optional[Union[str,
-                                     int]]) -> Tuple[torch.Tensor, torch.Tensor]:
-        qkv_idxs = {"q": 0, "k": 1, "v": 2}
-
-        if shard_id is None:
-            shard_id = 0
-            self.fused_module_in_checkpoint = True
-        elif isinstance(shard_id, int):
-            pass
-        elif isinstance(shard_id, str):
-            if shard_id not in qkv_idxs:
-                raise ValueError(f"Unknown shard_id: {shard_id}")
-            shard_id = qkv_idxs[shard_id]
-        else:
-            ValueError(f"Shard id must be int or str but got {type(shard_id)}")
-
-        return param[shard_id], loaded_weight
-
     def process_weights_after_loading(self, layer: Module) -> None:
         if (not hasattr(layer, "process_after_load")
                 or not layer.process_after_load):
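With the shard indexer gone, all fused-module fixups move into `process_weights_after_loading`, which runs once per layer after the checkpoint has been fully streamed in. A minimal sketch of the calling pattern this relies on, assuming each module exposes its quantization method under a `quant_method` attribute (illustrative, not vLLM's exact loader code):

from torch.nn import Module


def finalize_quantized_weights(model: Module) -> None:
    # After every shard of every weight has been loaded, give each
    # quantization method one post-load pass over its layer. One-shot
    # fixups (scale fusion, requantization) belong here rather than in
    # per-shard weight loaders, which may run many times per layer.
    for module in model.modules():
        quant_method = getattr(module, "quant_method", None)
        if quant_method is not None:
            quant_method.process_weights_after_loading(module)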
@@ -212,7 +189,17 @@ class Fp8LinearMethod(LinearMethodBase):
         # Loop over logical weights, requantizing with single scale.
         max_w_scale = layer.weight_scale.max()
 
-        if not self.fused_module_in_checkpoint:
+        # QKV / MLP is fused in the on-disk checkpoint if any of the
+        # weight scales are still set to the default: we initialize N
+        # weight scales for N shards, but a fused checkpoint supplies
+        # only one, so the trailing slots keep their defaults. In that
+        # case we skip dequant -> requant, since the fused QKV was
+        # already quantized together. Sample model with a fused
+        # checkpoint: nm-testing/Phi-3-mini-128k-instruct-FP8
+        unfused_module_in_checkpoint = (
+            layer.weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min)
+
+        if unfused_module_in_checkpoint:
             start = 0
             for idx, logical_width in enumerate(layer.logical_widths):
                 end = start + logical_width
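The hunk above is truncated mid-loop; for context, a self-contained sketch of the requantization it leads into, treating per-tensor dequantization as a plain cast-and-scale (the helpers in the real file may differ):

from typing import List

import torch


def requantize_with_max_scale(weight: torch.Tensor,
                              weight_scale: torch.Tensor,
                              logical_widths: List[int]) -> torch.Tensor:
    """Requantize N stacked logical weights to share one max scale."""
    max_w_scale = weight_scale.max()
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    start = 0
    for idx, logical_width in enumerate(logical_widths):
        end = start + logical_width
        # Dequantize this shard with its original per-shard scale ...
        weight_dq = weight[start:end, :].to(torch.float32) * weight_scale[idx]
        # ... and requantize it with the single shared scale.
        weight[start:end, :] = (weight_dq / max_w_scale).clamp(
            -fp8_max, fp8_max).to(torch.float8_e4m3fn)
        start = end
    return weight

After this pass the layer needs only `max_w_scale` as its per-tensor scale, which is what lets the fused GEMM run over the whole stacked weight with a single scale.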