[quantization] use channel scales for w4a8 + misc fixes (#23570)

Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
This commit is contained in:
czhu-cohere
2025-08-26 21:23:23 -04:00
committed by GitHub
parent c7c80af084
commit 2c2b140ae8
4 changed files with 63 additions and 14 deletions

View File

@@ -20,6 +20,7 @@ class MPLinearLayerConfig:
group_size: int
zero_points: bool
has_g_idx: bool
out_type: Optional[torch.dtype] = None
class MPLinearKernel(ABC):

View File

@@ -60,13 +60,17 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
if in_features % 128 or out_features % 128:
return False, "K and N must be divisible by 128, got "\
f"{c.partition_weight_shape}"
if c.out_type != torch.bfloat16:
return False, "Only bfloat16 output type currently supported"\
f"got {c.out_type=}"
return True, None
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module):
c = self.config
# TODO(czhu): optimize speed/mem usage
def transform_w_q(x):
@@ -86,19 +90,15 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
# Encode/reorder weights and pack scales
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)
# TODO(czhu): support loading channel scales
self.w_ch_s = torch.ones((c.partition_weight_shape[1], ),
dtype=torch.float32,
device='cuda')
self._transform_param(layer, "weight_chan_scale", lambda x: x)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
assert bias is None, "bias not supported by CUTLASS W4A8"
c = self.config
w_q, w_s, _, _ = self._get_weight_params(layer)
w_ch_s = layer.weight_chan_scale
x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
@@ -109,6 +109,9 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
b_group_scales=w_s,
b_group_size=c.group_size,
a_token_scales=act_scales,
b_channel_scales=self.w_ch_s)
b_channel_scales=w_ch_s)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)