[Kernel] Expand FP8 support to Ampere GPUs using FP8 Marlin (#5975)
@@ -14,13 +14,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
     get_pack_factor, quantize_weights, sort_weights)
 from vllm.platforms import current_platform
 
-__cuda_arch = current_platform.get_device_capability()
-
 MARLIN_TILE = 16
 
 
 def is_marlin_supported():
-    return __cuda_arch[0] >= 8
+    capability = current_platform.get_device_capability()
+    return capability[0] >= 8
 
 
 def marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE):
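The first hunk moves the compute-capability probe from import time into is_marlin_supported() itself, so merely importing the module no longer touches the device. A minimal sketch of the resulting behavior (not part of the commit; torch.cuda.get_device_capability() stands in here for vLLM's current_platform.get_device_capability(), and both return a (major, minor) tuple):

    import torch

    def is_marlin_supported() -> bool:
        # Queried lazily at call time: Marlin kernels need SM 8.0+
        # (Ampere or newer), which is what "capability[0] >= 8" checks.
        capability = torch.cuda.get_device_capability()
        return capability[0] >= 8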
@@ -223,3 +222,26 @@ class MarlinWorkspace:
         self.scratch = torch.zeros(max_workspace_size,
                                    dtype=torch.int,
                                    device="cuda")
+
+
+def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
+    """
+    Repack FP8 weights to gptq format (packed int32 elements)
+    """
+    assert fp8_tensor.dtype == torch.float8_e4m3fn
+    assert fp8_tensor.shape[0] % 4 == 0
+
+    # Reshape to prepare for packing
+    reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])
+
+    # Convert fp8 to uint8 (byte) representation
+    byte_tensor = reshaped.view(torch.uint8)
+
+    # Pack 4 uint8 values into one int32
+    packed = (byte_tensor[:, 0].to(torch.int32) |
+              (byte_tensor[:, 1].to(torch.int32) << 8) |
+              (byte_tensor[:, 2].to(torch.int32) << 16) |
+              (byte_tensor[:, 3].to(torch.int32) << 24))
+
+    return packed.view(fp8_tensor.shape[0] // 4,
+                       *fp8_tensor.shape[1:]).contiguous()
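pack_fp8_to_int32() reinterprets the FP8 bit patterns as raw bytes and packs each group of four rows into one row of int32 words, byte 0 in bits 0-7 through byte 3 in bits 24-31, i.e. little-endian within each word; since only bits are moved, the packing is lossless. A small self-check sketch (not part of the commit; the import path is an assumption, as the diff does not show the filename, and float8_e4m3fn requires PyTorch 2.1 or newer):

    import torch

    from vllm.model_executor.layers.quantization.utils.marlin_utils import (
        pack_fp8_to_int32)

    # Rows must be divisible by 4; the packing itself runs fine on CPU.
    w = torch.randn(8, 16).to(torch.float8_e4m3fn)
    packed = pack_fp8_to_int32(w)
    assert packed.dtype == torch.int32
    assert packed.shape == (w.shape[0] // 4, w.shape[1])

    # Unpack: byte i of each int32 word holds original row 4*r + i,
    # mirroring the shift-or packing in the function above.
    unpacked = torch.stack([(packed >> (8 * i)) & 0xFF for i in range(4)],
                           dim=1).to(torch.uint8)
    assert torch.equal(unpacked.reshape(w.shape), w.view(torch.uint8))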