[Kernel] [FP8] Improve FP8 linear layer performance (#4691)

This PR improves the FP8 performance of linear layers, which had previously been lacking (see #4118 (comment) and #4118 (comment)).

We noticed that CUBLASLt can find a better algorithm if the first dimension of the matrix is greater than 16. This PR therefore pads the batch (first) dimension of the activation matrix to at least 17 during quantization and slices the padding back off after the GEMM. This improves FP8 performance and removes the performance regression relative to FP16; in many cases FP8 now exceeds FP16 performance.
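Concretely, the pad-then-narrow idea looks like this (a minimal sketch, not the PR's code: `padded_matmul` is a hypothetical helper, and a plain matmul stands in for the fused FP8 GEMM):

```python
import torch

def padded_matmul(x: torch.Tensor, w: torch.Tensor,
                  batch_dim_padding: int = 17) -> torch.Tensor:
    batch = x.shape[0]
    if batch < batch_dim_padding:
        # Zero-pad rows at the bottom so cuBLASLt sees a batch
        # dimension > 16 and can pick a faster algorithm.
        pad_rows = batch_dim_padding - batch
        x = torch.nn.functional.pad(x, (0, 0, 0, pad_rows))
    out = x @ w  # stand-in for the fused FP8 GEMM (torch._scaled_mm)
    # Drop the padded rows so the caller sees the original batch size.
    return torch.narrow(out, 0, 0, batch)
```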

Here are benchmarks on Llama 3 70B (ITL numbers for 1000 input and 50 output tokens, at fixed qps and TP 4); all FP8 measurements use dynamic quantization:

| qps | FP8, this PR | FP8, previous main | FP16 |
|-----|--------------|--------------------|------|
| 1   | 24 ms        | 32 ms              | 26 ms |
| 2   | 26 ms        | 34 ms              | 28 ms |
| 4   | 33 ms        | 44 ms              | 36 ms |
| 6   | 46 ms        | 56 ms              | 54 ms |
| 8   | 85 ms        | 85 ms              | 138 ms |
Commit 379da6dcb5 (parent ebce310b74) by Philipp Moritz, 2024-05-09 16:38:07 -07:00, committed by GitHub.
2 changed files with 35 additions and 4 deletions.


```diff
@@ -231,9 +231,14 @@ class Fp8LinearMethod(LinearMethodBase):
         # ops.scaled_fp8_quant supports both dynamic and static quant.
         # If dynamic, layer.act_scale is None and x_scale computed from x.
         # If static, layer.act_scale is scalar and x_scale set to act_scale.
-        qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale)
+        qinput, x_scale = ops.scaled_fp8_quant(x,
+                                               layer.act_scale,
+                                               batch_dim_padding=17)
 
-        # Fused GEMM_DQ
+        # Fused GEMM_DQ -- note we padded the input above because
+        # torch._scaled_mm is more performant for matrices with
+        # batch dimension > 16. Note that this could change
+        # in the future.
         output, _ = torch._scaled_mm(
             qinput,
             layer.weight,
@@ -243,7 +248,7 @@ class Fp8LinearMethod(LinearMethodBase):
             bias=bias,
         )
 
-        return output
+        return torch.narrow(output, 0, 0, x.shape[0])
 
 
 def all_close_1d(x: torch.Tensor) -> bool:
```
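For reference, here is a hedged sketch of what a `batch_dim_padding` argument could do inside a quantization wrapper. vLLM's actual `ops.scaled_fp8_quant` is backed by a CUDA kernel and may differ; `scaled_fp8_quant_sketch` and its internals are illustrative only:

```python
from typing import Optional

import torch

def scaled_fp8_quant_sketch(
    x: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    batch_dim_padding: Optional[int] = None,
):
    fp8 = torch.float8_e4m3fn
    fmax = torch.finfo(fp8).max
    shape = list(x.shape)
    if batch_dim_padding is not None:
        # Allocate extra rows up front; zero-filled so the padding
        # cannot poison the downstream GEMM with garbage values.
        shape[0] = max(shape[0], batch_dim_padding)
    out = torch.zeros(shape, dtype=fp8, device=x.device)
    if scale is None:
        # Dynamic quantization: derive the scale from the tensor itself.
        scale = x.abs().max().to(torch.float32) / fmax
    # Quantize only the real rows; the padded tail stays zero.
    out[: x.shape[0]] = (x / scale).clamp(-fmax, fmax).to(fp8)
    return out, scale
```

The key property is that the output already has the padded batch dimension, so the GEMM runs on the enlarged matrix and only the final `torch.narrow` is needed to restore the original shape.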