[Kernel] Expand FP8 support to Ampere GPUs using FP8 Marlin (#5975)

Author: Michael Goin
Date: 2024-07-03 13:38:00 -04:00
Committed by: GitHub
Parent: 7cd2ebb025
Commit: 47f0954af0
11 changed files with 1585 additions and 42 deletions

@@ -14,13 +14,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     get_pack_factor, quantize_weights, sort_weights)
 from vllm.platforms import current_platform
 
-__cuda_arch = current_platform.get_device_capability()
-
 MARLIN_TILE = 16
 
 
 def is_marlin_supported():
-    return __cuda_arch[0] >= 8
+    capability = current_platform.get_device_capability()
+    return capability[0] >= 8
 
 
 def marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE):
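
Note: the hunk above moves the compute-capability query inside is_marlin_supported(), so the device is inspected at call time rather than at module import. A minimal sketch of the resulting behavior, not part of the diff; it assumes only what the hunk shows (current_platform.get_device_capability() returns an indexable (major, minor) pair), and the call site is illustrative:

    from vllm.platforms import current_platform

    def is_marlin_supported() -> bool:
        # Capability is read when the function is called, so importing the
        # module no longer touches the GPU. Marlin kernels (including the
        # FP8 Marlin path this PR adds) require SM 8.x+ (Ampere or newer).
        capability = current_platform.get_device_capability()
        return capability[0] >= 8

    # Hypothetical gating at a call site:
    if is_marlin_supported():
        print("FP8 Marlin kernels can be used on this GPU")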
@@ -223,3 +222,26 @@ class MarlinWorkspace:
         self.scratch = torch.zeros(max_workspace_size,
                                    dtype=torch.int,
                                    device="cuda")
+
+
+def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
+    """
+    Repack FP8 weights to gptq format (packed int32 elements)
+    """
+    assert fp8_tensor.dtype == torch.float8_e4m3fn
+    assert fp8_tensor.shape[0] % 4 == 0
+
+    # Reshape to prepare for packing
+    reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
+
+    # Convert fp8 to uint8 (byte) representation
+    byte_tensor = reshaped.view(torch.uint8)
+
+    # Pack 4 uint8 values into one int32
+    packed = (byte_tensor[:, 0].to(torch.int32) |
+              (byte_tensor[:, 1].to(torch.int32) << 8) |
+              (byte_tensor[:, 2].to(torch.int32) << 16) |
+              (byte_tensor[:, 3].to(torch.int32) << 24))
+
+    return packed.view(fp8_tensor.shape[0] // 4,
+                       *fp8_tensor.shape[1:]).contiguous()
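
Note: a hedged usage sketch of the new helper, not part of the diff. It packs an FP8 weight tensor whose leading dimension is divisible by 4 into int32 words, four FP8 bytes per word; the shapes below are illustrative and the import is assumed to come from the module this diff modifies:

    import torch

    # Assumes pack_fp8_to_int32 from the hunk above is in scope.
    # Requires a PyTorch build with torch.float8_e4m3fn support.
    weight_fp8 = torch.randn(128, 256).to(torch.float8_e4m3fn)
    packed = pack_fp8_to_int32(weight_fp8)

    assert packed.dtype == torch.int32
    # Four FP8 values are packed along dim 0, so it shrinks by 4x.
    assert packed.shape == (128 // 4, 256)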