Remove weight_scale.T special case for SM90 Block FP8 CUTLASS kernel (#28431)

Signed-off-by: mgoin <mgoin64@gmail.com>
Author: Michael Goin
Date: 2025-11-11 09:46:04 -07:00
Committer: GitHub
Parent: 287bbbeb06
Commit: f9a4087182
5 changed files with 36 additions and 36 deletions
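In short: before this change, the SM90 CUTLASS block-FP8 path pre-transposed the block-wise weight scale into a contiguous row-major tensor during weight post-processing and passed it to the kernel untransposed, while every other path passed `Bs.T`; after the change the scale is left as loaded and the wrapper always passes the transposed view. A minimal sketch of that layout relationship in plain PyTorch (shapes and names are illustrative, not vLLM code):

```python
import torch

# Illustrative block-wise weight scale: a 512x1024 weight with 128x128 quant
# blocks gives a 4x8 scale tensor (shapes made up for the sketch).
weight_scale = torch.rand(4, 8)

# Old SM90 special case: transpose ahead of time into a contiguous row-major
# tensor and hand that to the kernel as-is.
pre_transposed = weight_scale.T.contiguous()

# New behavior: keep the scale as loaded; the wrapper always passes the
# transposed view (scale_b=Bs.T).
transposed_view = weight_scale.T

# Same values either way; only the memory layout differs.
assert torch.equal(pre_transposed, transposed_view)
print(pre_transposed.is_contiguous(), transposed_view.is_contiguous())  # True False
```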


@@ -173,7 +173,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
             layer.input_scale = None
 
         if self.strategy == QuantizationStrategy.BLOCK:
-            maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)
+            maybe_post_process_fp8_weight_block(layer)
 
     def apply_weights(
         self,


@@ -540,7 +540,7 @@ class Fp8LinearMethod(LinearMethodBase):
             return
 
         if self.block_quant:
-            maybe_post_process_fp8_weight_block(layer, self.cutlass_block_fp8_supported)
+            maybe_post_process_fp8_weight_block(layer)
 
     def apply(
         self,


@@ -55,17 +55,13 @@ def cutlass_scaled_mm(
     Bs: torch.Tensor,
     block_size: list[int],
     output_dtype: torch.dtype = torch.float16,
-    is_hopper: bool | None = None,
 ) -> torch.Tensor:
-    if is_hopper is None:
-        is_hopper = current_platform.is_device_capability(90)
     return ops.cutlass_scaled_mm(
         A,
         B.T,
         out_dtype=output_dtype,
         scale_a=As,
-        # SM90 block FP8 requires row-major scale_b, which we do ahead of time
-        scale_b=Bs if block_size is not None and is_hopper else Bs.T,
+        scale_b=Bs.T,
     )
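Reading the hunk above end to end, the simplified wrapper comes out roughly as below. The leading parameters `A`, `B`, `As` and the `ops` import are assumed from context (they are not shown in this diff), so treat this as a sketch rather than the exact file contents:

```python
import torch

from vllm import _custom_ops as ops  # assumed import; not shown in the hunk


def cutlass_scaled_mm(
    A: torch.Tensor,   # quantized activations (parameter assumed from usage)
    B: torch.Tensor,   # quantized weight (parameter assumed from usage)
    As: torch.Tensor,  # activation scales (parameter assumed from usage)
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    # No is_hopper flag and no pre-transposed scale: scale_b is always the
    # transposed view, on SM90 and elsewhere alike.
    return ops.cutlass_scaled_mm(
        A,
        B.T,
        out_dtype=output_dtype,
        scale_a=As,
        scale_b=Bs.T,
    )
```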
@@ -130,7 +126,7 @@ def _padded_cutlass(
     padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)
     output = cutlass_scaled_mm(
-        padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype, True
+        padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype
     )
     return output[0 : qx.shape[0], ...]
@@ -303,7 +299,6 @@ class W8A8BlockFp8LinearOp:
             weight_scale,
             list(self.weight_group_shape),
             input_2d.dtype,
-            False,
         )
 
     def _run_aiter(
@@ -1125,9 +1120,7 @@ def process_fp8_weight_block_strategy(
     return weight, weight_scale
 
 
-def maybe_post_process_fp8_weight_block(
-    layer: torch.nn.Module, cutlass_block_fp8_supported: bool
-):
+def maybe_post_process_fp8_weight_block(layer: torch.nn.Module):
     assert layer.weight_block_size is not None
 
     from vllm.utils.deep_gemm import (
@@ -1146,15 +1139,6 @@ def maybe_post_process_fp8_weight_block(
         requant_weight_ue8m0_inplace(
             layer.weight.data, layer.weight_scale.data, block_sz
         )
-    # SM90 Block FP8 CUTLASS requires row-major weight scales
-    elif (
-        current_platform.is_device_capability(90)
-        and cutlass_block_fp8_supported
-        and not should_use_deepgemm
-    ):
-        layer.weight_scale = torch.nn.Parameter(
-            layer.weight_scale.data.T.contiguous(), requires_grad=False
-        )
 
 
 def expert_weight_is_col_major(x: torch.Tensor) -> bool:
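For reference, the deleted `elif` branch amounted to the following operation on the layer, shown here as a standalone sketch with a dummy module (the `current_platform` capability check and the DeepGEMM guard are omitted, and `DummyLayer` is not a vLLM class). With the branch gone, the two call sites earlier in this commit simply call `maybe_post_process_fp8_weight_block(layer)`.

```python
import torch


class DummyLayer(torch.nn.Module):
    """Illustrative stand-in for a block-FP8 linear layer."""

    def __init__(self) -> None:
        super().__init__()
        self.weight_block_size = [128, 128]
        # Block-wise weight scale with an illustrative shape.
        self.weight_scale = torch.nn.Parameter(torch.rand(8, 4), requires_grad=False)


layer = DummyLayer()

# What the removed branch did on SM90 when the CUTLASS block-FP8 kernel was
# selected and DeepGEMM was not: re-wrap the scale as its contiguous transpose.
layer.weight_scale = torch.nn.Parameter(
    layer.weight_scale.data.T.contiguous(), requires_grad=False
)
print(layer.weight_scale.shape)  # torch.Size([4, 8])
```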