[quantization] use channel scales for w4a8 + misc fixes (#23570)
Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
This commit is contained in:
@@ -79,7 +79,8 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
|
||||
act_type=torch.float8_e4m3fn, # always use fp8(e4m3)
|
||||
group_size=self.group_size,
|
||||
zero_points=not self.symmetric,
|
||||
has_g_idx=self.has_g_idx
|
||||
has_g_idx=self.has_g_idx,
|
||||
out_type=params_dtype
|
||||
)
|
||||
|
||||
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
|
||||
@@ -122,7 +123,7 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
|
||||
torch.empty(
|
||||
output_size_per_partition,
|
||||
scales_and_zp_size,
|
||||
dtype=params_dtype,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -140,9 +141,17 @@ class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
|
||||
dtype=torch.int64),
|
||||
weight_loader=weight_loader)
|
||||
|
||||
# per-channel scales
|
||||
weight_chan_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((output_size_per_partition, 1),
|
||||
dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("weight_packed", weight)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
layer.register_parameter("weight_shape", weight_shape)
|
||||
layer.register_parameter("weight_chan_scale", weight_chan_scale)
|
||||
|
||||
self.kernel = kernel_type(mp_linear_kernel_config,
|
||||
w_q_param_name="weight_packed",
|
||||
|
||||
@@ -20,6 +20,7 @@ class MPLinearLayerConfig:
|
||||
group_size: int
|
||||
zero_points: bool
|
||||
has_g_idx: bool
|
||||
out_type: Optional[torch.dtype] = None
|
||||
|
||||
|
||||
class MPLinearKernel(ABC):
|
||||
|
||||
@@ -60,13 +60,17 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
|
||||
if in_features % 128 or out_features % 128:
|
||||
return False, "K and N must be divisible by 128, got "\
|
||||
f"{c.partition_weight_shape}"
|
||||
|
||||
if c.out_type != torch.bfloat16:
|
||||
return False, "Only bfloat16 output type currently supported"\
|
||||
f"got {c.out_type=}"
|
||||
|
||||
return True, None
|
||||
|
||||
# note assumes that
|
||||
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
|
||||
# `weight_scale` is: {input_dim = 0, output_dim = 1}
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
c = self.config
|
||||
|
||||
# TODO(czhu): optimize speed/mem usage
|
||||
def transform_w_q(x):
|
||||
@@ -86,19 +90,15 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
|
||||
# Encode/reorder weights and pack scales
|
||||
self._transform_param(layer, self.w_q_name, transform_w_q)
|
||||
self._transform_param(layer, self.w_s_name, transform_w_s)
|
||||
|
||||
# TODO(czhu): support loading channel scales
|
||||
self.w_ch_s = torch.ones((c.partition_weight_shape[1], ),
|
||||
dtype=torch.float32,
|
||||
device='cuda')
|
||||
self._transform_param(layer, "weight_chan_scale", lambda x: x)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
assert bias is None, "bias not supported by CUTLASS W4A8"
|
||||
c = self.config
|
||||
w_q, w_s, _, _ = self._get_weight_params(layer)
|
||||
w_ch_s = layer.weight_chan_scale
|
||||
|
||||
x_2d = x.reshape(-1, x.shape[-1])
|
||||
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
|
||||
@@ -109,6 +109,9 @@ class CutlassW4A8LinearKernel(MPLinearKernel):
|
||||
b_group_scales=w_s,
|
||||
b_group_size=c.group_size,
|
||||
a_token_scales=act_scales,
|
||||
b_channel_scales=self.w_ch_s)
|
||||
b_channel_scales=w_ch_s)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output.reshape(out_shape)
|
||||
|
||||
Reference in New Issue
Block a user