[Kernel] Added flashinfer fp8 per-tensor gemms (#22895)
Signed-off-by: Julien Lin <jullin@nvidia.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -223,8 +223,7 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=self.act_q_static,
|
||||
act_quant_group_shape=self.act_q_group_shape,
|
||||
cutlass_fp8_supported=cutlass_fp8_supported())
|
||||
act_quant_group_shape=self.act_q_group_shape)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
@@ -376,6 +375,8 @@ class Fp8LinearMethod(LinearMethodBase):
|
||||
# Update the layer with the new values.
|
||||
layer.weight = Parameter(qweight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
# layer.input_scale is None indicates dynamic quant and scale is
|
||||
# computed from input.
|
||||
layer.input_scale = None
|
||||
|
||||
# If checkpoint is fp8, handle that there are N scales for N
|
||||
|
||||
@@ -97,8 +97,8 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
|
||||
self.quant_config.is_checkpoint_fp8_serialized = False
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=False,
|
||||
cutlass_fp8_supported=False,
|
||||
act_quant_group_shape=GroupShape.PER_TOKEN)
|
||||
act_quant_group_shape=GroupShape.PER_TOKEN,
|
||||
force_fp8_e4m3fnuz=True)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.weight = torch.nn.Parameter(layer.weight.data,
|
||||
|
||||
@@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import direct_register_custom_op
|
||||
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
|
||||
|
||||
# Input scaling factors are no longer optional in _scaled_mm starting
|
||||
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
|
||||
@@ -157,6 +158,19 @@ def cutlass_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
|
||||
return output.view(*output_shape)
|
||||
|
||||
|
||||
def flashinfer_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
|
||||
out_dtype: torch.dtype, scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor, bias: torch.Tensor,
|
||||
output_shape: list, **kwargs) -> torch.Tensor:
|
||||
|
||||
return flashinfer_scaled_fp8_mm(qinput,
|
||||
weight,
|
||||
out_dtype=out_dtype,
|
||||
scale_a=scale_a,
|
||||
scale_b=scale_b,
|
||||
bias=bias)
|
||||
|
||||
|
||||
def rocm_per_tensor_w8a8_scaled_mm_impl(
|
||||
qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
|
||||
@@ -231,8 +245,8 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor, bias: torch.Tensor,
|
||||
input_2d: torch.Tensor,
|
||||
output_shape: list) -> torch.Tensor:
|
||||
input_2d: torch.Tensor, output_shape: list,
|
||||
**kwargs) -> torch.Tensor:
|
||||
# Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
|
||||
# when using it.
|
||||
# For now it has only been validated on ROCm platform.
|
||||
@@ -303,16 +317,22 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
|
||||
|
||||
|
||||
def dispatch_w8a8_scaled_mm(
|
||||
cutlass_fp8_supported: bool, per_tensor_weights: bool,
|
||||
preferred_backend: str, per_tensor_weights: bool,
|
||||
per_tensor_activations: bool) -> Callable[..., torch.Tensor]:
|
||||
|
||||
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
|
||||
if cutlass_fp8_supported:
|
||||
return cutlass_w8a8_scaled_mm
|
||||
if per_tensor_weights and per_tensor_activations:
|
||||
if current_platform.is_rocm():
|
||||
if preferred_backend == "rocm":
|
||||
return rocm_per_tensor_w8a8_scaled_mm
|
||||
if preferred_backend == "flashinfer":
|
||||
return flashinfer_w8a8_scaled_mm
|
||||
if preferred_backend == "cutlass":
|
||||
return cutlass_w8a8_scaled_mm
|
||||
return torch_per_tensor_w8a8_scaled_mm
|
||||
|
||||
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
|
||||
if preferred_backend == "cutlass" or preferred_backend == "flashinfer":
|
||||
return cutlass_w8a8_scaled_mm
|
||||
|
||||
# If torch.scaled_mm supports per-channel (weights) per-token (inputs)
|
||||
if not per_tensor_weights and not per_tensor_activations \
|
||||
and USE_ROWWISE_TORCH_SCALED_MM:
|
||||
@@ -334,10 +354,20 @@ class Fp8LinearOp:
|
||||
|
||||
def __init__(self,
|
||||
act_quant_static: bool,
|
||||
cutlass_fp8_supported: bool = cutlass_fp8_supported(),
|
||||
act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR,
|
||||
pad_output: Optional[bool] = None):
|
||||
self.cutlass_fp8_supported = cutlass_fp8_supported
|
||||
pad_output: Optional[bool] = None,
|
||||
force_fp8_e4m3fnuz: bool = False):
|
||||
if current_platform.is_rocm():
|
||||
self.preferred_backend = "rocm"
|
||||
elif current_platform.is_cuda(
|
||||
) and not force_fp8_e4m3fnuz and cutlass_fp8_supported():
|
||||
if has_flashinfer() and current_platform.has_device_capability(
|
||||
100):
|
||||
self.preferred_backend = "flashinfer"
|
||||
else:
|
||||
self.preferred_backend = "cutlass"
|
||||
else:
|
||||
self.preferred_backend = "torch"
|
||||
|
||||
# Note: we pad the input because torch._scaled_mm is more performant
|
||||
# for matrices with batch dimension > 16.
|
||||
@@ -347,8 +377,7 @@ class Fp8LinearOp:
|
||||
if pad_output is None:
|
||||
config = get_current_vllm_config().compilation_config
|
||||
pad_output = config.level < CompilationLevel.PIECEWISE and \
|
||||
not cutlass_fp8_supported and \
|
||||
not current_platform.is_rocm()
|
||||
self.preferred_backend == "torch"
|
||||
|
||||
self.output_padding = 17 if pad_output else None
|
||||
self.act_quant_static = act_quant_static
|
||||
@@ -393,9 +422,9 @@ class Fp8LinearOp:
|
||||
per_tensor_activations = (x_scale.numel() == 1)
|
||||
|
||||
# TODO(luka) do this dispatch during init (after ScaledMM refactor)
|
||||
w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
|
||||
self.cutlass_fp8_supported, per_tensor_weights,
|
||||
per_tensor_activations)
|
||||
w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(self.preferred_backend,
|
||||
per_tensor_weights,
|
||||
per_tensor_activations)
|
||||
|
||||
return w8a8_scaled_mm_func(qinput=qinput,
|
||||
weight=weight,
|
||||
|
||||
Reference in New Issue
Block a user