[Misc] Add channel-wise quantization support for w8a8 dynamic per token activation quantization (#5542)

This commit is contained in:
Dipika Sikka
2024-06-18 12:45:05 -04:00
committed by GitHub
parent 7879f24dcc
commit 95db455e7f
4 changed files with 45 additions and 32 deletions

View File

@@ -468,13 +468,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
"MergedColumnParallelLinear, assume the weight is "
"the same for all partitions.")
if fp8_scales_shard_indexer is None:
if len(param_data.shape) == 0:
param_data = param_data.reshape(1)
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
@@ -686,12 +679,6 @@ class QKVParallelLinear(ColumnParallelLinear):
"QKVParallelLinear, assume the weight is the same "
"for all partitions.")
if len(param_data.shape) == 0:
param_data = param_data.reshape(1)
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)