[Core] Support weight_loader_v2 for UnquantizedLinearMethod (#23036)
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
@@ -61,9 +61,24 @@ class BasevLLMParameter(Parameter):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
@property
|
||||
def weight_loader(self):
|
||||
def weight_loader(self) -> Callable:
|
||||
# NOTE(@ksayers) some models such as mamba_mixer2 override the
|
||||
# weight loader to support custom loading. In the future, model-specific
|
||||
# weight loading should be implemented via Model.load_weights. In the
|
||||
# meantime, support deleting and overriding `weight_loader`` attribute
|
||||
if self._weight_loader is None:
|
||||
raise AttributeError(f"{self.__class__.__name__} weight_loader "
|
||||
"attribute has been deleted")
|
||||
return self._weight_loader
|
||||
|
||||
@weight_loader.setter
|
||||
def weight_loader(self, value: Callable):
|
||||
self._weight_loader = value
|
||||
|
||||
@weight_loader.deleter
|
||||
def weight_loader(self):
|
||||
self._weight_loader = None # type: ignore[assignment]
|
||||
|
||||
def _is_1d_and_scalar(self, loaded_weight: torch.Tensor):
|
||||
cond1 = self.data.ndim == 1 and self.data.numel() == 1
|
||||
cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1
|
||||
@@ -97,6 +112,12 @@ class BasevLLMParameter(Parameter):
|
||||
assert shard_id in qkv_idxs
|
||||
return qkv_idxs[shard_id]
|
||||
|
||||
@classmethod
|
||||
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
return super().__torch_function__(func, types, args, kwargs)
|
||||
|
||||
|
||||
class _ColumnvLLMParameter(BasevLLMParameter):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user