Remove weight_scale.T special case for SM90 Block FP8 CUTLASS kernel (#28431)
Signed-off-by: mgoin <mgoin64@gmail.com>
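This drops the SM90-only pre-processing of block-FP8 weight scales: `maybe_post_process_fp8_weight_block` loses its `cutlass_block_fp8_supported` parameter and no longer transposes `weight_scale` to row-major at weight-load time on Hopper, and `cutlass_scaled_mm` loses its `is_hopper` parameter and now always passes `scale_b=Bs.T` through to the kernel.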
@@ -173,7 +173,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         layer.input_scale = None
 
         if self.strategy == QuantizationStrategy.BLOCK:
-            maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)
+            maybe_post_process_fp8_weight_block(layer)
 
     def apply_weights(
         self,
@@ -540,7 +540,7 @@ class Fp8LinearMethod(LinearMethodBase):
             return
 
         if self.block_quant:
-            maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)
+            maybe_post_process_fp8_weight_block(layer)
 
     def apply(
         self,
@@ -55,17 +55,13 @@ def cutlass_scaled_mm(
     Bs: torch.Tensor,
     block_size: list[int],
     output_dtype: torch.dtype = torch.float16,
-    is_hopper: bool | None = None,
 ) -> torch.Tensor:
-    if is_hopper is None:
-        is_hopper = current_platform.is_device_capability(90)
     return ops.cutlass_scaled_mm(
         A,
         B.T,
         out_dtype=output_dtype,
         scale_a=As,
-        # SM90 block FP8 requires row-major scale_b, which we do ahead of time
-        scale_b=Bs if block_size is not None and is_hopper else Bs.T,
+        scale_b=Bs.T,
     )
 
 
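With the `is_hopper` branch gone there is a single call path through the wrapper. A minimal usage sketch, assuming the `cutlass_scaled_mm` wrapper above is in scope on a CUDA device with CUTLASS block-FP8 support; the shapes, dtypes, and 128x128 block size are illustrative, not taken from the diff:

import torch

M, K, N = 16, 512, 1024
block = [128, 128]

A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)  # fp8 activations, (M, K)
B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)  # fp8 weight, (N, K)
As = torch.rand(M, K // block[1], device="cuda")              # per-token-group scales
Bs = torch.rand(N // block[0], K // block[1], device="cuda")  # per-block weight scales

# No is_hopper argument anymore; scale_b always reaches the op as Bs.T.
out = cutlass_scaled_mm(A, B, As, Bs, block, torch.bfloat16)
assert out.shape == (M, N)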
@@ -130,7 +126,7 @@ def _padded_cutlass(
     padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
 
     output = cutlass_scaled_mm(
-        padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype, True
+        padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype
     )
     return output[0 : qx.shape[0], ...]
 
@@ -303,7 +299,6 @@ class W8A8BlockFp8LinearOp:
             weight_scale,
             list(self.weight_group_shape),
             input_2d.dtype,
-            False,
         )
 
     def _run_aiter(
@@ -1125,9 +1120,7 @@ def process_fp8_weight_block_strategy(
     return weight, weight_scale
 
 
-def maybe_post_process_fp8_weight_block(
-    layer: torch.nn.Module, cutlass_block_fp8_supported: bool
-):
+def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
     assert layer.weight_block_size is not None
 
     from vllm.utils.deep_gemm import (
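After the signature change, callers hand over only the layer; the attributes the function reads are the ones visible in these hunks (`weight`, `weight_scale`, `weight_block_size`). A hedged sketch of that minimal contract, with every value illustrative:

import torch

layer = torch.nn.Module()  # stand-in for a real linear layer
layer.weight = torch.nn.Parameter(
    torch.randn(1024, 512).to(torch.float8_e4m3fn), requires_grad=False
)
layer.weight_scale = torch.nn.Parameter(torch.rand(8, 4), requires_grad=False)
layer.weight_block_size = [128, 128]  # asserted non-None by the function

maybe_post_process_fp8_weight_block(layer)  # capability flag no longer needed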
@@ -1146,15 +1139,6 @@ def maybe_post_process_fp8_weight_block(
         requant_weight_ue8m0_inplace(
             layer.weight.data, layer.weight_scale.data, block_sz
         )
-    # SM90 Block FP8 CUTLASS requires row-major weight scales
-    elif (
-        current_platform.is_device_capability(90)
-        and cutlass_block_fp8_supported
-        and not should_use_deepgemm
-    ):
-        layer.weight_scale = torch.nn.Parameter(
-            layer.weight_scale.data.T.contiguous(), requires_grad=False
-        )
 
 
 def expert_weight_is_col_major(x: torch.Tensor) -> bool:
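The deleted branch materialized `weight_scale.data.T.contiguous()` once at load time so that the old `scale_b=Bs` path saw row-major scales; after this change the transposed view is taken at call time instead. Both hand the kernel the same values and differ only in strides, which a CPU-only sketch (shapes illustrative) makes concrete:

import torch

# Block scales for a (N=1024, K=512) weight with 128x128 blocks: (N/128, K/128).
ws = torch.rand(8, 4)

old_scale_b = ws.T.contiguous()  # removed path: transpose materialized at load time
new_scale_b = ws.T               # current path: transposed view taken at call time

assert torch.equal(old_scale_b, new_scale_b)  # identical values
assert old_scale_b.is_contiguous()            # row-major copy
assert not new_scale_b.is_contiguous()        # strided (column-major) view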