[Misc] Add channel-wise quantization support for w8a8 dynamic per token activation quantization (#5542)
This commit is contained in:
@@ -468,13 +468,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
"MergedColumnParallelLinear, assume the weight is "
|
||||
"the same for all partitions.")
|
||||
|
||||
if fp8_scales_shard_indexer is None:
|
||||
if len(param_data.shape) == 0:
|
||||
param_data = param_data.reshape(1)
|
||||
|
||||
if len(loaded_weight.shape) == 0:
|
||||
loaded_weight = loaded_weight.reshape(1)
|
||||
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
@@ -686,12 +679,6 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
"QKVParallelLinear, assume the weight is the same "
|
||||
"for all partitions.")
|
||||
|
||||
if len(param_data.shape) == 0:
|
||||
param_data = param_data.reshape(1)
|
||||
|
||||
if len(loaded_weight.shape) == 0:
|
||||
loaded_weight = loaded_weight.reshape(1)
|
||||
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user