[Misc] update fp8 to use vLLMParameter (#7437)

This commit is contained in:
Dipika Sikka
2024-08-22 08:36:18 -04:00
committed by GitHub
parent 55d63b1211
commit 955b5191c9
4 changed files with 51 additions and 17 deletions

View File

@@ -22,7 +22,7 @@ logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
- "AWQLinearMethod", "GPTQMarlinLinearMethod"
+ "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod"
]
@@ -349,6 +349,11 @@ class ColumnParallelLinear(LinearBase):
param_data.copy_(loaded_weight)
def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
param.load_column_parallel_weight(loaded_weight=loaded_weight)
def forward(self, input_):
@@ -1021,6 +1026,13 @@ class RowParallelLinear(LinearBase):
def weight_loader_v2(self, param: BasevLLMParameter,
loaded_weight: torch.Tensor):
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
param.load_row_parallel_weight(loaded_weight=loaded_weight)
def forward(self, input_):