[Misc] update fp8 to use vLLMParameter (#7437)

commit 955b5191c9
parent 55d63b1211
Author: Dipika Sikka
Date: 2024-08-22 08:36:18 -04:00
Committed by: GitHub

4 changed files with 51 additions and 17 deletions

vllm/model_executor/parameter.py

@@ -208,10 +208,17 @@ class PerTensorScaleParameter(BasevLLMParameter):
         if isinstance(shard_id, int):
             return shard_id
 
+        # if not int, assume shard_id for qkv
+        # map to int and return
         assert isinstance(shard_id, str)
         assert shard_id in self.qkv_idxs
         return self.qkv_idxs[shard_id]
 
+    # For row parallel layers, no sharding needed
+    # load weight into parameter as is
+    def load_row_parallel_weight(self, *args, **kwargs):
+        super().load_row_parallel_weight(*args, **kwargs)
+
     def load_merged_column_weight(self, *args, **kwargs):
         self._load_into_shard_id(*args, **kwargs)
 
@@ -219,7 +226,7 @@ class PerTensorScaleParameter(BasevLLMParameter):
         self._load_into_shard_id(*args, **kwargs)
 
     def load_column_parallel_weight(self, *args, **kwargs):
-        super().load_row_parallel_weight(*args, **kwargs)
+        self._load_into_shard_id(*args, **kwargs)
 
     def _load_into_shard_id(self, loaded_weight: torch.Tensor,
                             shard_id: Union[str, int], **kwargs):
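
For reference, a minimal sketch of how these load hooks are driven, assuming BasevLLMParameter's (data, weight_loader) constructor and 0-dim scalar scales in the checkpoint. The dummy weight_loader, the literal scale values, and shard_id=0 for the unfused column-parallel case are illustrative, not taken from this diff.

import torch

from vllm.model_executor.parameter import PerTensorScaleParameter

# One per-tensor scale per logical matrix of the fused QKV projection.
qkv_scale = PerTensorScaleParameter(data=torch.empty(3, dtype=torch.float32),
                                    weight_loader=lambda *args, **kwargs: None)

# String shard ids are mapped to integer slots through qkv_idxs
# ({"q": 0, "k": 1, "v": 2}) before indexing into the parameter data.
qkv_scale.load_qkv_weight(loaded_weight=torch.tensor(0.02), shard_id="q")
qkv_scale.load_qkv_weight(loaded_weight=torch.tensor(0.03), shard_id="k")
qkv_scale.load_qkv_weight(loaded_weight=torch.tensor(0.05), shard_id="v")

# Unfused column-parallel layer: after this change the scale also goes
# through _load_into_shard_id instead of being loaded as-is.
col_scale = PerTensorScaleParameter(data=torch.empty(1, dtype=torch.float32),
                                    weight_loader=lambda *args, **kwargs: None)
col_scale.load_column_parallel_weight(loaded_weight=torch.tensor(0.01),
                                      shard_id=0)

The second hunk thus gives column-parallel per-tensor scales the same shard-indexed path already used for merged-column and QKV loading, while the new load_row_parallel_weight override keeps the as-is path for row-parallel layers, where no sharding of the scale is needed.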