[Kernel] Added flashinfer fp8 per-tensor gemms (#22895)

Signed-off-by: Julien Lin <jullin@nvidia.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
nvjullin
2025-08-26 21:54:04 +08:00
committed by GitHub
parent b78bed1bc5
commit f66673a39d
9 changed files with 198 additions and 36 deletions

View File

@@ -223,8 +223,7 @@ class Fp8LinearMethod(LinearMethodBase):
         self.fp8_linear = Fp8LinearOp(
             act_quant_static=self.act_q_static,
-            act_quant_group_shape=self.act_q_group_shape,
-            cutlass_fp8_supported=cutlass_fp8_supported())
+            act_quant_group_shape=self.act_q_group_shape)
def create_weights(
self,
@@ -376,6 +375,8 @@ class Fp8LinearMethod(LinearMethodBase):
# Update the layer with the new values.
layer.weight = Parameter(qweight.t(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
# layer.input_scale is None indicates dynamic quant and scale is
# computed from input.
layer.input_scale = None
# If checkpoint is fp8, handle that there are N scales for N

View File

@@ -97,8 +97,8 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
self.quant_config.is_checkpoint_fp8_serialized = False
         self.fp8_linear = Fp8LinearOp(
             act_quant_static=False,
-            cutlass_fp8_supported=False,
-            act_quant_group_shape=GroupShape.PER_TOKEN)
+            act_quant_group_shape=GroupShape.PER_TOKEN,
+            force_fp8_e4m3fnuz=True)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight = torch.nn.Parameter(layer.weight.data,

View File

@@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
@@ -157,6 +158,19 @@ def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
return output.view(*output_shape)
def flashinfer_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
                              out_dtype: torch.dtype, scale_a: torch.Tensor,
                              scale_b: torch.Tensor, bias: torch.Tensor,
                              output_shape: list, **kwargs) -> torch.Tensor:
    """Run a per-tensor FP8 scaled matmul through the FlashInfer kernel.

    ``output_shape`` and any extra keyword arguments are accepted only so
    this signature matches the other ``*_w8a8_scaled_mm`` backends picked by
    ``dispatch_w8a8_scaled_mm``; they are not consumed here.
    """
    # NOTE(review): unlike cutlass_w8a8_scaled_mm, the result is not reshaped
    # to output_shape — presumably flashinfer_scaled_fp8_mm already returns
    # the desired shape; confirm against its implementation.
    return flashinfer_scaled_fp8_mm(
        qinput,
        weight,
        out_dtype=out_dtype,
        scale_a=scale_a,
        scale_b=scale_b,
        bias=bias,
    )
def rocm_per_tensor_w8a8_scaled_mm_impl(
qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
@@ -231,8 +245,8 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
                                    out_dtype: torch.dtype,
                                    scale_a: torch.Tensor,
                                    scale_b: torch.Tensor, bias: torch.Tensor,
-                                   input_2d: torch.Tensor,
-                                   output_shape: list) -> torch.Tensor:
+                                   input_2d: torch.Tensor, output_shape: list,
+                                   **kwargs) -> torch.Tensor:
# Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
# when using it.
# For now it has only been validated on ROCm platform.
@@ -303,16 +317,22 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
 def dispatch_w8a8_scaled_mm(
-        cutlass_fp8_supported: bool, per_tensor_weights: bool,
+        preferred_backend: str, per_tensor_weights: bool,
         per_tensor_activations: bool) -> Callable[..., torch.Tensor]:
-    # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
-    if cutlass_fp8_supported:
-        return cutlass_w8a8_scaled_mm
     if per_tensor_weights and per_tensor_activations:
-        if current_platform.is_rocm():
+        if preferred_backend == "rocm":
             return rocm_per_tensor_w8a8_scaled_mm
+        if preferred_backend == "flashinfer":
+            return flashinfer_w8a8_scaled_mm
+        if preferred_backend == "cutlass":
+            return cutlass_w8a8_scaled_mm
         return torch_per_tensor_w8a8_scaled_mm
+    # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
+    if preferred_backend == "cutlass" or preferred_backend == "flashinfer":
+        return cutlass_w8a8_scaled_mm
     # If torch.scaled_mm supports per-channel (weights) per-token (inputs)
     if not per_tensor_weights and not per_tensor_activations \
             and USE_ROWWISE_TORCH_SCALED_MM:
@@ -334,10 +354,20 @@ class Fp8LinearOp:
     def __init__(self,
                  act_quant_static: bool,
-                 cutlass_fp8_supported: bool = cutlass_fp8_supported(),
                  act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR,
-                 pad_output: Optional[bool] = None):
-        self.cutlass_fp8_supported = cutlass_fp8_supported
+                 pad_output: Optional[bool] = None,
+                 force_fp8_e4m3fnuz: bool = False):
+        if current_platform.is_rocm():
+            self.preferred_backend = "rocm"
+        elif current_platform.is_cuda(
+        ) and not force_fp8_e4m3fnuz and cutlass_fp8_supported():
+            if has_flashinfer() and current_platform.has_device_capability(
+                    100):
+                self.preferred_backend = "flashinfer"
+            else:
+                self.preferred_backend = "cutlass"
+        else:
+            self.preferred_backend = "torch"
# Note: we pad the input because torch._scaled_mm is more performant
# for matrices with batch dimension > 16.
@@ -347,8 +377,7 @@ class Fp8LinearOp:
         if pad_output is None:
             config = get_current_vllm_config().compilation_config
             pad_output = config.level < CompilationLevel.PIECEWISE and \
-                not cutlass_fp8_supported and \
-                not current_platform.is_rocm()
+                self.preferred_backend == "torch"
self.output_padding = 17 if pad_output else None
self.act_quant_static = act_quant_static
@@ -393,9 +422,9 @@ class Fp8LinearOp:
per_tensor_activations = (x_scale.numel() == 1)
# TODO(luka) do this dispatch during init (after ScaledMM refactor)
-        w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
-            self.cutlass_fp8_supported, per_tensor_weights,
-            per_tensor_activations)
+        w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(self.preferred_backend,
+                                                      per_tensor_weights,
+                                                      per_tensor_activations)
return w8a8_scaled_mm_func(qinput=qinput,
weight=weight,