[Misc] update fp8 to use vLLMParameter (#7437)
This commit is contained in:
@@ -22,7 +22,7 @@ logger = init_logger(__name__)
 
 WEIGHT_LOADER_V2_SUPPORTED = [
     "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
-    "AWQLinearMethod", "GPTQMarlinLinearMethod"
+    "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod"
 ]
 
 
@@ -349,6 +349,11 @@ class ColumnParallelLinear(LinearBase):
         param_data.copy_(loaded_weight)
 
     def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor):
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
+            assert loaded_weight.numel() == 1
+            loaded_weight = loaded_weight.reshape(1)
         param.load_column_parallel_weight(loaded_weight=loaded_weight)
 
     def forward(self, input_):
@@ -1021,6 +1026,13 @@ class RowParallelLinear(LinearBase):
 
     def weight_loader_v2(self, param: BasevLLMParameter,
                          loaded_weight: torch.Tensor):
+
+        # Special case for loading scales off disk, which often do not
+        # have a shape (such as in the case of AutoFP8).
+        if len(loaded_weight.shape) == 0:
+            assert loaded_weight.numel() == 1
+            loaded_weight = loaded_weight.reshape(1)
+
         param.load_row_parallel_weight(loaded_weight=loaded_weight)
 
     def forward(self, input_):
Reference in New Issue
Block a user