[Core] Support weight_loader_v2 for UnquantizedLinearMethod (#23036)

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
Kyle Sayers
2025-09-24 01:30:26 +01:00
committed by GitHub
parent 1983609239
commit de94289a98
3 changed files with 70 additions and 12 deletions

View File

@@ -61,9 +61,24 @@ class BasevLLMParameter(Parameter):
self.tp_size = get_tensor_model_parallel_world_size()
@property
# NOTE(review): diff render artifact — the unannotated signature below is the
# removed line and the `-> Callable` one is its replacement; the leading
# +/- diff markers were stripped by the page renderer.
def weight_loader(self):
def weight_loader(self) -> Callable:
    # NOTE(@ksayers) some models such as mamba_mixer2 override the
    # weight loader to support custom loading. In the future, model-specific
    # weight loading should be implemented via Model.load_weights. In the
    # meantime, support deleting and overriding the `weight_loader` attribute.
    if self._weight_loader is None:
        # Deleted via the deleter below; surface a descriptive error instead
        # of returning None to a caller expecting a callable.
        raise AttributeError(f"{self.__class__.__name__} weight_loader "
                             "attribute has been deleted")
    return self._weight_loader

@weight_loader.setter
def weight_loader(self, value: Callable):
    # Allow models to install a custom loader callable after construction.
    self._weight_loader = value

@weight_loader.deleter
def weight_loader(self):
    # Mark as deleted rather than removing the backing attribute so the
    # getter can raise the descriptive AttributeError above on later access.
    self._weight_loader = None  # type: ignore[assignment]
def _is_1d_and_scalar(self, loaded_weight: torch.Tensor):
    # cond1: this parameter's data is a 1-D tensor holding a single element;
    # cond2: the incoming weight is a 0-D (scalar) tensor with one element.
    # NOTE(review): the diff hunk is truncated here — the statement combining
    # cond1/cond2 (presumably `return cond1 and cond2`) is not visible; verify
    # against the full file.
    cond1 = self.data.ndim == 1 and self.data.numel() == 1
    cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1
@@ -97,6 +112,12 @@ class BasevLLMParameter(Parameter):
assert shard_id in qkv_idxs
return qkv_idxs[shard_id]
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
    """Delegate torch-function dispatch to the parent class.

    The torch override protocol may hand us ``kwargs=None``; substitute an
    empty dict so the superclass always receives a mapping.
    """
    return super().__torch_function__(
        func, types, args, {} if kwargs is None else kwargs)
class _ColumnvLLMParameter(BasevLLMParameter):
"""