[Transform] [Quantization] Add transforms to compressed tensors (#22486)
@@ -35,6 +35,7 @@ logger = init_logger(__name__)

WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod",
+    "CompressedTensorsLinearTransformMethod",
    "BitBLASLinearMethod",
    "GPTQBitBLASLinearMethod",
    "AWQMarlinLinearMethod",
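For orientation, a minimal sketch (not code from this diff) of how an allow-list like WEIGHT_LOADER_V2_SUPPORTED is typically consulted: the layer's quantization method is matched by class name, and only listed methods take the weight_loader_v2 path. The helper name below is invented for illustration.

    # Hypothetical helper, for illustration only: listed method classes load
    # parameters through weight_loader_v2; everything else stays on the v1 loader.
    def uses_weight_loader_v2(quant_method) -> bool:
        return type(quant_method).__name__ in WEIGHT_LOADER_V2_SUPPORTED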
@@ -199,6 +200,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
        set_weight_attrs(weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # special postprocessing for CPU SGL
        if current_platform.is_cpu() and envs.VLLM_CPU_SGL_KERNEL:
            from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
            N, K = layer.weight.size()
@@ -1470,7 +1472,7 @@ class QKVCrossParallelLinear(LinearBase):
            self.bias = torch.nn.Parameter()
            set_weight_attrs(self.bias, {
                "output_dim": 0,
-                "weight_loader": self.weight_loader,
+                "weight_loader": self.weight_loader_v1,
            })
        else:
            self.bias = None
@@ -1580,6 +1582,18 @@ class QKVCrossParallelLinear(LinearBase):
        k, v = kv_enc.split(self.kv_size, dim=-1)
        return q, k, v
+
+    def weight_loader_v1(self,
+                         param: torch.nn.Parameter,
+                         loaded_weight: torch.Tensor,
+                         loaded_shard_id: Optional[str] = None):
+        # just like all other parameters, does not yet
+        # support loading bias with weight_loader_v2
+        layer = (self.q_proj_decoder
+                 if loaded_shard_id == "q" else self.kv_proj_encoder)
+        target_param = self.select_proj_params(layer, param)
+        shard_id_args = (loaded_shard_id, ) if loaded_shard_id != "q" else ()
+        layer.weight_loader(target_param, loaded_weight, *shard_id_args)

    def weight_loader(self,
                      param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor,
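The new weight_loader_v1 keeps bias loading on the v1 path (bias parameters are not yet handled by weight_loader_v2): the "q" shard is routed to q_proj_decoder, the "k"/"v" shards to kv_proj_encoder, and the shard id is forwarded only in the k/v case. A rough usage sketch, with the bias tensors invented for illustration:

    # Hypothetical usage only: q_bias, k_bias and v_bias stand in for tensors read
    # from a checkpoint; layer is a QKVCrossParallelLinear instance.
    for shard_id, bias_tensor in (("q", q_bias), ("k", k_bias), ("v", v_bias)):
        layer.weight_loader_v1(layer.bias, bias_tensor, loaded_shard_id=shard_id)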