[Bugfix] Handle Asym W4A16 (ConchLinearKernel) for CT (#33200)
Signed-off-by: Matthias Gehre <matthias.gehre@amd.com> Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -56,6 +56,7 @@ class ConchLinearKernel(MPLinearKernel):
|
|||||||
# note assumes that
|
# note assumes that
|
||||||
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
|
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
|
||||||
# `weight_scale` is: {input_dim = 0, output_dim = 1}
|
# `weight_scale` is: {input_dim = 0, output_dim = 1}
|
||||||
|
# `weight_zero_point` is: {input_dim = 1, output_dim = 0, packed_dim = 0}
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
def transform_w_q(x):
|
def transform_w_q(x):
|
||||||
assert isinstance(x, BasevLLMParameter)
|
assert isinstance(x, BasevLLMParameter)
|
||||||
@@ -69,8 +70,49 @@ class ConchLinearKernel(MPLinearKernel):
|
|||||||
x.data = x.data.contiguous()
|
x.data = x.data.contiguous()
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def transform_w_zp(x):
    """Unpack the zero-point parameter and lay it out as [K//G, N].

    The zero points arrive bit-packed into 32-bit words along the output
    dimension (per the original notes: shape [N//pack_factor, K//G],
    dtype int32). Each word is split into its individual `size_bits`-wide
    fields, and the result is rearranged to an unpacked [K//G, N] uint8
    tensor, which is the layout the Conch kernel is stated to expect.

    The parameter is modified in place (its `.data` and, when present, its
    `_input_dim` / `_output_dim` / `_packed_factor` metadata) and returned.
    """
    assert isinstance(x, BasevLLMParameter)

    # Bit width of one quantized value decides how many values share a word
    # and which low bits to keep (e.g. 8 values / 0xF for 4-bit,
    # 4 values / 0xFF for 8-bit).
    bits = self.config.weight_type.size_bits
    values_per_word = 32 // bits
    low_bits_mask = (1 << bits) - 1

    words = x.data
    rows_packed, num_groups = words.shape
    num_outputs = rows_packed * values_per_word

    # One shift amount per field inside a word: 0, bits, 2*bits, ...
    bit_offsets = torch.arange(
        0, 32, bits, dtype=torch.int32, device=words.device
    )

    # Broadcast [N//pack_factor, K//G, 1] against [pack_factor] so every
    # word yields all of its fields along a new trailing axis. Masking
    # after the shift also discards any sign-extended high bits.
    shifted = words.unsqueeze(-1) >> bit_offsets
    fields = shifted & low_bits_mask

    # Bring the group axis to the front, then merge the (packed row, field)
    # axes into the full output dimension: -> [K//G, N].
    zeros = fields.transpose(0, 1).reshape(num_groups, num_outputs)

    x.data = zeros.to(torch.uint8).contiguous()

    # The tensor is no longer packed and its axes have been swapped, so
    # refresh whichever layout attributes this parameter carries.
    for attr_name, new_value in (
        ("_input_dim", 0),
        ("_output_dim", 1),
        ("_packed_factor", 1),
    ):
        if hasattr(x, attr_name):
            setattr(x, attr_name, new_value)

    return x
|
||||||
|
|
||||||
self._transform_param(layer, self.w_q_name, transform_w_q)
|
self._transform_param(layer, self.w_q_name, transform_w_q)
|
||||||
self._transform_param(layer, self.w_s_name, transform_w_s)
|
self._transform_param(layer, self.w_s_name, transform_w_s)
|
||||||
|
if self.config.zero_points:
|
||||||
|
self._transform_param(layer, self.w_zp_name, transform_w_zp)
|
||||||
|
|
||||||
def apply_weights(
|
def apply_weights(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Reference in New Issue
Block a user