[Misc] update fp8 to use vLLMParameter (#7437)
This commit is contained in:
@@ -208,10 +208,17 @@ class PerTensorScaleParameter(BasevLLMParameter):
|
||||
if isinstance(shard_id, int):
|
||||
return shard_id
|
||||
|
||||
# if not int, assume shard_id for qkv
|
||||
# map to int and return
|
||||
assert isinstance(shard_id, str)
|
||||
assert shard_id in self.qkv_idxs
|
||||
return self.qkv_idxs[shard_id]
|
||||
|
||||
# Row-parallel layers carry a single replicated scale, so there is no
# shard bookkeeping to do here — hand the tensor straight to the base
# class, which copies it into the parameter unchanged.
def load_row_parallel_weight(self, *args, **kwargs):
    """Load a row-parallel weight with no shard-id remapping."""
    super().load_row_parallel_weight(*args, **kwargs)
|
||||
|
||||
def load_merged_column_weight(self, *args, **kwargs):
    """Load one shard of a merged-column weight.

    All arguments are forwarded to ``_load_into_shard_id``, which places
    the value in the slot selected by the caller's ``shard_id``.
    """
    self._load_into_shard_id(*args, **kwargs)
|
||||
|
||||
@@ -219,7 +226,7 @@ class PerTensorScaleParameter(BasevLLMParameter):
|
||||
self._load_into_shard_id(*args, **kwargs)
|
||||
|
||||
def load_column_parallel_weight(self, *args, **kwargs):
    """Load a column-parallel weight.

    Routes the value through ``_load_into_shard_id`` and then performs the
    base-class row-parallel copy with the same arguments.
    """
    # NOTE(review): the source shows both a shard-id load AND a base-class
    # copy back to back; this may be a fused removed/added pair from the
    # diff this text was extracted from — confirm against the upstream
    # file before relying on the double write.
    self._load_into_shard_id(*args, **kwargs)
    super().load_row_parallel_weight(*args, **kwargs)
|
||||
|
||||
def _load_into_shard_id(self, loaded_weight: torch.Tensor,
|
||||
shard_id: Union[str, int], **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user