[Kernel] Add Exllama as a backend for compressed-tensors (#9395)
This commit is contained in:
@@ -20,9 +20,9 @@ FUSED_LAYER_NAME_MAPPING = {
|
||||
}
|
||||
|
||||
|
||||
def pack_weights_into_int32(w_q: torch.Tensor,
|
||||
wtype: ScalarType,
|
||||
packed_dim: int = 0):
|
||||
def pack_quantized_values_into_int32(w_q: torch.Tensor,
|
||||
wtype: ScalarType,
|
||||
packed_dim: int = 0):
|
||||
# move dim to pack to the end
|
||||
perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
|
||||
inv_perm = tuple(perm.index(i) for i in range(len(perm)))
|
||||
@@ -42,9 +42,9 @@ def pack_weights_into_int32(w_q: torch.Tensor,
|
||||
return res.permute(inv_perm)
|
||||
|
||||
|
||||
def unpack_weights_into_int32(w_q: torch.Tensor,
|
||||
wtype: ScalarType,
|
||||
packed_dim: int = 0):
|
||||
def unpack_quantized_values_into_int32(w_q: torch.Tensor,
|
||||
wtype: ScalarType,
|
||||
packed_dim: int = 0):
|
||||
# move dim to pack to the end
|
||||
perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
|
||||
inv_perm = tuple(perm.index(i) for i in range(len(perm)))
|
||||
|
||||
Reference in New Issue
Block a user