From 73419abfae97c0b797c7e1fe997913e9ebfeff68 Mon Sep 17 00:00:00 2001
From: Matthias Gehre
Date: Sat, 31 Jan 2026 07:21:51 +0100
Subject: [PATCH] [Bugfix] Handle Asym W4A16 (ConchLinearKernel) for CT
 (#33200)

Signed-off-by: Matthias Gehre
Co-authored-by: Cursor
---
 .../kernels/mixed_precision/conch.py | 42 +++++++++++++++++++
 1 file changed, 42 insertions(+)

diff --git a/vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py b/vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py
index 53b2e15df..e98676e01 100644
--- a/vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py
+++ b/vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py
@@ -56,6 +56,7 @@ class ConchLinearKernel(MPLinearKernel):
     # note assumes that
     # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
     # `weight_scale` is: {input_dim = 0, output_dim = 1}
+    # `weight_zero_point` is: {input_dim = 1, output_dim = 0, packed_dim = 0}
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         def transform_w_q(x):
             assert isinstance(x, BasevLLMParameter)
@@ -69,8 +70,49 @@ class ConchLinearKernel(MPLinearKernel):
             x.data = x.data.contiguous()
             return x
 
+        def transform_w_zp(x):
+            # Zero points are stored PACKED as [N//pack_factor, K//G]
+            # The Conch kernel expects UNPACKED zeros: [K//G, N]
+            # We need to unpack and reorder
+            assert isinstance(x, BasevLLMParameter)
+            packed = x.data  # shape: [N//pack_factor, K//G], dtype: int32
+
+            # Determine packing based on weight bit width
+            size_bits = self.config.weight_type.size_bits
+            pack_factor = 32 // size_bits  # 8 for 4-bit, 4 for 8-bit
+            mask = (1 << size_bits) - 1  # 0xF for 4-bit, 0xFF for 8-bit
+
+            n_packed, k_groups = packed.shape
+            n_full = n_packed * pack_factor
+
+            # Unpack using vectorized bitwise ops
+            # shifts = [0, size_bits, 2*size_bits, ...] for each packed position
+            shifts = torch.arange(
+                0, 32, size_bits, dtype=torch.int32, device=packed.device
+            )
+            # packed: [N//pack_factor, K//G] -> [N//pack_factor, K//G, 1]
+            # shifts: [pack_factor] -> [1, 1, pack_factor]
+            # Result: [N//pack_factor, K//G, pack_factor]
+            unpacked = (packed.unsqueeze(-1) >> shifts) & mask
+
+            # Permute to [K//G, N//pack_factor, pack_factor] then reshape to [K//G, N]
+            unpacked = unpacked.permute(1, 0, 2).reshape(k_groups, n_full)
+
+            x.data = unpacked.to(torch.uint8).contiguous()
+
+            # Update metadata - zeros are no longer packed
+            if hasattr(x, "_input_dim"):
+                x._input_dim = 0
+            if hasattr(x, "_output_dim"):
+                x._output_dim = 1
+            if hasattr(x, "_packed_factor"):
+                x._packed_factor = 1
+            return x
+
         self._transform_param(layer, self.w_q_name, transform_w_q)
         self._transform_param(layer, self.w_s_name, transform_w_s)
+        if self.config.zero_points:
+            self._transform_param(layer, self.w_zp_name, transform_w_zp)
 
     def apply_weights(
         self,
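
Reviewer note (not part of the patch): below is a minimal standalone sketch of
the nibble unpacking that transform_w_zp performs, run on a made-up 4-bit case
with N=8 output channels and K//G=2 scale groups. The literal packed values
and variable names are illustrative only; only the bit layout described in the
patch ([N//pack_factor, K//G] int32 in, [K//G, N] uint8 out, least-significant
bits first) is assumed here.

    import torch

    size_bits = 4
    pack_factor = 32 // size_bits   # 8 zero points per int32 word
    mask = (1 << size_bits) - 1     # 0xF

    # Packed zeros: [N//pack_factor, K//G] = [1, 2].
    # Column 0 packs the zero points 0..7 (LSB first); column 1 packs 15, 0, ..., 0.
    packed = torch.tensor([[0x76543210, 0x0000000F]], dtype=torch.int32)

    # Same vectorized unpack as in the patch: shift out each nibble, mask, reorder.
    shifts = torch.arange(0, 32, size_bits, dtype=torch.int32)
    unpacked = (packed.unsqueeze(-1) >> shifts) & mask    # [1, 2, 8]
    unpacked = unpacked.permute(1, 0, 2).reshape(2, 8)    # [K//G, N] = [2, 8]
    print(unpacked.to(torch.uint8))
    # tensor([[ 0,  1,  2,  3,  4,  5,  6,  7],
    #         [15,  0,  0,  0,  0,  0,  0,  0]], dtype=torch.uint8)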