[Feature] Integrate SM100 DeepGEMM support (#20087)
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
@@ -13,7 +14,7 @@ import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
scaled_dequantize)
|
||||
group_broadcast)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
CUTLASS_BLOCK_FP8_SUPPORTED)
|
||||
from vllm.platforms import current_platform
|
||||
@@ -235,7 +236,7 @@ def block_quant_to_tensor_quant(
|
||||
The outputs are tensor-wise quantization tensor and tensor-wise
|
||||
quantization scale. Note only float8 is supported for now.
|
||||
"""
|
||||
x_dq_block = scaled_dequantize(x_q_block, x_s)
|
||||
x_dq_block = group_broadcast(x_q_block, x_s)
|
||||
x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
|
||||
return x_q_tensor, scale
|
||||
|
||||
@@ -651,3 +652,124 @@ def w8a8_block_fp8_matmul(
|
||||
)
|
||||
|
||||
return C
|
||||
|
||||
|
||||
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
|
||||
# TODO(wentao): remove this function when DeepGEMM exposes this function
|
||||
def get_tma_aligned_size(x: int, element_size: int) -> int:
|
||||
"""
|
||||
Global memory address of TMA must be 16-byte aligned.
|
||||
Since we use column-major layout for the LHS scaling tensor,
|
||||
the M-axis of the LHS scaling tensor needs to be padded to a multiple of
|
||||
16 bytes.
|
||||
|
||||
Arguments:
|
||||
x: original M-axis shape of the LHS scaling tensor.
|
||||
element_size: element size of the LHS scaling tensor.
|
||||
|
||||
Returns:
|
||||
M-axis shape of the LHS scaling tensor after padding.
|
||||
"""
|
||||
tma_alignment_bytes = 16
|
||||
assert tma_alignment_bytes % element_size == 0
|
||||
alignment = tma_alignment_bytes // element_size
|
||||
return cdiv(x, alignment) * alignment
|
||||
|
||||
|
||||
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
|
||||
# TODO(wentao): remove this function when DeepGEMM exposes this function
|
||||
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Returns TMA-aligned transposed format of the input tensor. `torch.transpose`
|
||||
will be called if necessary.
|
||||
If the input tensor is already column-major layout and 16-byte aligned along
|
||||
the M axis (thus meets the requirement of LHS scaling tensor in
|
||||
DeepGEMM), this function will do nothing.
|
||||
|
||||
Arguments:
|
||||
x: usually the LHS scaling tensor in GEMM.
|
||||
|
||||
Returns:
|
||||
The LHS scaling tensor of TMA-aligned transposed format.
|
||||
"""
|
||||
# NOTES: for the extreme performance, you may rewrite/fuse this function in
|
||||
# CUDA
|
||||
assert x.dim() in (2, 3)
|
||||
remove_dim = False
|
||||
m, n = x.shape[-2], x.shape[-1]
|
||||
aligned_m = get_tma_aligned_size(m, x.element_size())
|
||||
if x.dim() == 2:
|
||||
if x.stride(0) == 1 and x.stride(1) == aligned_m:
|
||||
return x
|
||||
x, remove_dim = x.unsqueeze(0), True
|
||||
|
||||
b = x.shape[0]
|
||||
|
||||
# The last kernel gives a column-major TMA aligned layout
|
||||
if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(
|
||||
2) == aligned_m:
|
||||
return x.squeeze(0) if remove_dim else x
|
||||
|
||||
# Normal layout requires transposing
|
||||
aligned_x = torch.transpose(
|
||||
torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
|
||||
aligned_x[:, :m, :] = x
|
||||
aligned_x = aligned_x[:, :m, :]
|
||||
return aligned_x.squeeze(0) if remove_dim else aligned_x
|
||||
|
||||
|
||||
def requant_weight_ue8m0_inplace(
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
block_size: Sequence[int] = (128, 128),
|
||||
) -> None:
|
||||
"""Re-quantise *weight* so that its per-block scaling factors are in the
|
||||
UE8M0 (power-of-two) format expected by the new DeepGEMM kernels inplace.
|
||||
|
||||
Args:
|
||||
weight: Block-quantised weight tensor stored in ``torch.float8_e4m3fn``.
|
||||
Expected shape ``(..., M, K)``.
|
||||
weight_scale: Corresponding per-block scale tensor (``torch.float32``)
|
||||
with shape ``(..., M // block_size[0], K // block_size[1])``.
|
||||
block_size: 2-element iterable ``[block_m, block_k]`` describing the
|
||||
block quantisation granularity.
|
||||
"""
|
||||
if weight.numel() == 0:
|
||||
return
|
||||
|
||||
if weight.dtype != torch.float8_e4m3fn:
|
||||
raise ValueError("Expected *weight* to be torch.float8_e4m3fn, got "
|
||||
f"{weight.dtype} instead.")
|
||||
|
||||
from vllm.utils.deep_gemm import per_block_cast_to_fp8
|
||||
|
||||
block_m, block_k = int(block_size[0]), int(block_size[1])
|
||||
|
||||
# Flatten leading dimensions so we can iterate over the last two dims.
|
||||
leading_shape = weight.shape[:-2]
|
||||
if len(leading_shape) == 0:
|
||||
w_view = weight.unsqueeze(0)
|
||||
s_view = weight_scale.unsqueeze(0)
|
||||
else:
|
||||
w_view = weight.reshape(-1, weight.shape[-2], weight.shape[-1])
|
||||
s_view = weight_scale.reshape(-1, *weight_scale.shape[-2:])
|
||||
|
||||
num_mats = w_view.size(0)
|
||||
for idx in range(num_mats):
|
||||
w_q = w_view[idx]
|
||||
s_old = s_view[idx]
|
||||
|
||||
# De-quantise with the *old* scaling factors (float32).
|
||||
m_cur, k_cur = w_q.shape
|
||||
s_float = s_old.to(torch.float32)
|
||||
# Expand scales along rows and cols by block size, then crop.
|
||||
s_exp_r = torch.repeat_interleave(s_float, block_m, dim=0)
|
||||
s_exp = torch.repeat_interleave(s_exp_r, block_k, dim=1)
|
||||
s_exp = s_exp[:m_cur, :k_cur]
|
||||
w_dq = w_q.to(torch.float32) * s_exp
|
||||
# Re-quantise using power-of-two scaling (UE8M0).
|
||||
w_requant, s_requant = per_block_cast_to_fp8(w_dq, [block_m, block_k])
|
||||
|
||||
# Write back the results in-place.
|
||||
w_q.copy_(w_requant)
|
||||
s_old.copy_(s_requant)
|
||||
|
||||
Reference in New Issue
Block a user