[Misc][Quark] Upstream Quark format to VLLM (#10765)
Signed-off-by: kewang-xlnx <kewang@xilinx.com>
Signed-off-by: kewang2 <kewang2@amd.com>
Co-authored-by: kewang2 <kewang2@amd.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
@@ -56,8 +56,14 @@ class BasevLLMParameter(Parameter):
|
||||
def weight_loader(self):
    """Return the object stored in ``self._weight_loader``."""
    loader = self._weight_loader
    return loader
|
||||
|
||||
def _is_1d_and_scalar(self, loaded_weight: torch.Tensor):
    """Tell whether a 0-d value is being loaded into a one-element 1-d parameter.

    Returns True only when ``self.data`` has exactly one dimension holding a
    single element AND ``loaded_weight`` is zero-dimensional with a single
    element; False otherwise.
    """
    target = self.data
    # Both halves are evaluated up front, then combined.
    target_is_single_1d = (target.ndim, target.numel()) == (1, 1)
    source_is_0d_scalar = (loaded_weight.ndim, loaded_weight.numel()) == (0, 1)
    return target_is_single_1d and source_is_0d_scalar
|
||||
|
||||
def _assert_and_load(self, loaded_weight: torch.Tensor):
    """Copy ``loaded_weight`` into ``self.data`` after a shape check.

    The load is accepted when the two shapes are equal, or when
    ``self._is_1d_and_scalar(loaded_weight)`` reports that a 0-d value is
    being loaded into a one-element 1-d parameter.

    Raises:
        AssertionError: if neither condition holds.
    """
    # A single relaxed check only: an unconditional shape-equality assert
    # placed before it would reject the scalar case and make the ``or``
    # branch unreachable.
    assert (self.data.shape == loaded_weight.shape
            or self._is_1d_and_scalar(loaded_weight)), (
        f"shape mismatch: parameter {tuple(self.data.shape)} vs "
        f"loaded weight {tuple(loaded_weight.shape)}")
    # copy_ is in-place; in the scalar case the 0-d source fills the
    # one-element destination.
    self.data.copy_(loaded_weight)
|
||||
|
||||
def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
|
||||
|
||||
Reference in New Issue
Block a user