[Kernel] Support Fp8 Checkpoints (Dynamic + Static) (#4332)
Co-authored-by: Philipp Moritz <pcmoritz@gmail.com> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: mgoin <michael@neuralmagic.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
@@ -246,6 +246,10 @@ class ColumnParallelLinear(LinearBase):
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||
# Special case for Fp8 scales.
|
||||
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
|
||||
None)
|
||||
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
param_data = param.data
|
||||
@@ -254,6 +258,12 @@ class ColumnParallelLinear(LinearBase):
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||
shard_size)
|
||||
# Special case for Fp8 scales.
|
||||
elif fp8_scales_shard_indexer is not None:
|
||||
param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
|
||||
loaded_weight,
|
||||
shard_id=0)
|
||||
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
@@ -317,7 +327,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
|
||||
param_data = param.data
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
# Special case for AQLM codebooks.
|
||||
is_metadata = getattr(param, "is_metadata", False)
|
||||
# Special case for Fp8 scales.
|
||||
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
|
||||
None)
|
||||
|
||||
if loaded_shard_id is None:
|
||||
# Loaded weight is already packed.
|
||||
if output_dim is None:
|
||||
@@ -331,14 +346,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
current_shard_offset += output_size
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
for shard_id, shard_offset, shard_size in shard_offsets:
|
||||
# Special case for Quantization.
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
if packed_dim == output_dim:
|
||||
shard_size = shard_size // param.pack_factor
|
||||
shard_offset = shard_offset // param.pack_factor
|
||||
|
||||
# If marlin, we need to adjust the offset and size to
|
||||
# account for the tiling.
|
||||
# Special case for Marlin.
|
||||
shard_size, shard_offset = adjust_marlin_shard(
|
||||
param, shard_size, shard_offset)
|
||||
|
||||
@@ -353,15 +367,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
if output_dim is not None:
|
||||
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
|
||||
shard_size = self.output_sizes[loaded_shard_id] // tp_size
|
||||
# Special case for quantization.
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
if packed_dim == output_dim:
|
||||
shard_size = shard_size // param.pack_factor
|
||||
shard_offset = shard_offset // param.pack_factor
|
||||
|
||||
# If marlin, we need to adjust the offset and size to
|
||||
# account for the tiling.
|
||||
# Special case for Marlin.
|
||||
shard_size, shard_offset = adjust_marlin_shard(
|
||||
param, shard_size, shard_offset)
|
||||
|
||||
@@ -370,11 +383,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||
shard_size)
|
||||
# Special case for AQLM codebooks.
|
||||
elif is_metadata:
|
||||
# metadata indicates fixed size concatenated along dim 0
|
||||
shard_size = loaded_weight.shape[0]
|
||||
shard_offset = loaded_shard_id * shard_size
|
||||
param_data = param_data.narrow(0, shard_offset, shard_size)
|
||||
# Special case for Fp8 scales.
|
||||
elif fp8_scales_shard_indexer is not None:
|
||||
param_data, loaded_weight = fp8_scales_shard_indexer(
|
||||
param_data, loaded_weight, loaded_shard_id)
|
||||
|
||||
else:
|
||||
ignore_warning = getattr(param, "ignore_warning", False)
|
||||
if not ignore_warning:
|
||||
@@ -455,7 +474,11 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
loaded_shard_id: Optional[str] = None):
|
||||
param_data = param.data
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
# Special case for AQLM codebooks.
|
||||
is_metadata = getattr(param, "is_metadata", False)
|
||||
# Special case for Fp8 scales.
|
||||
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
|
||||
None)
|
||||
|
||||
if loaded_shard_id is None:
|
||||
# Loaded weight is already packed.
|
||||
@@ -473,14 +496,14 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
]
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
for shard_id, shard_offset, shard_size in shard_offsets:
|
||||
# Special case for Quantized Weights.
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
if packed_dim == output_dim:
|
||||
shard_size = shard_size // param.pack_factor
|
||||
shard_offset = shard_offset // param.pack_factor
|
||||
|
||||
# If marlin, we need to adjust the offset and size to
|
||||
# account for the tiling.
|
||||
# Special case for Marlin.
|
||||
shard_size, shard_offset = adjust_marlin_shard(
|
||||
param, shard_size, shard_offset)
|
||||
|
||||
@@ -502,6 +525,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
shard_offset = (self.num_heads +
|
||||
self.num_kv_heads) * self.head_size
|
||||
shard_size = self.num_kv_heads * self.head_size
|
||||
# Special case for Quantized Weights.
|
||||
# If quantized, we need to adjust the offset and size to account
|
||||
# for the packing.
|
||||
packed_dim = getattr(param, "packed_dim", None)
|
||||
@@ -509,8 +533,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
shard_size = shard_size // param.pack_factor
|
||||
shard_offset = shard_offset // param.pack_factor
|
||||
|
||||
# If marlin, we need to adjust the offset and size to
|
||||
# account for the tiling.
|
||||
# Special case for Marlin.
|
||||
shard_size, shard_offset = adjust_marlin_shard(
|
||||
param, shard_size, shard_offset)
|
||||
|
||||
@@ -523,12 +546,17 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
start_idx = shard_id * shard_size
|
||||
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
|
||||
shard_size)
|
||||
# Special case for for AQLM codebooks.
|
||||
elif is_metadata:
|
||||
# metadata indicates fixed size concatenated along dim 0
|
||||
shard_size = loaded_weight.shape[0]
|
||||
shard_index = ["q", "k", "v"].index(loaded_shard_id)
|
||||
param_data = param_data.narrow(0, shard_index * shard_size,
|
||||
shard_size)
|
||||
# Special case for Fp8 scales.
|
||||
elif fp8_scales_shard_indexer is not None:
|
||||
param_data, loaded_weight = fp8_scales_shard_indexer(
|
||||
param_data, loaded_weight, loaded_shard_id)
|
||||
else:
|
||||
ignore_warning = getattr(param, "ignore_warning", False)
|
||||
if not ignore_warning:
|
||||
@@ -611,6 +639,10 @@ class RowParallelLinear(LinearBase):
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
|
||||
# Special case for Fp8 scales.
|
||||
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
|
||||
None)
|
||||
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
input_dim = getattr(param, "input_dim", None)
|
||||
param_data = param.data
|
||||
@@ -619,6 +651,12 @@ class RowParallelLinear(LinearBase):
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
|
||||
shard_size)
|
||||
# Special case for Fp8 scales.
|
||||
elif fp8_scales_shard_indexer is not None:
|
||||
param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
|
||||
loaded_weight,
|
||||
shard_id=0)
|
||||
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user