[Feature] Integrate new deepgemm (#19820)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
@@ -114,6 +114,10 @@ def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor):
             and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
 
 
+def ceil_div(x: int, y: int) -> int:
+    return (x + y - 1) // y
+
+
 # TODO fix ROCm->Triton custom path:
 # https://github.com/vllm-project/vllm/issues/14397
 def apply_w8a8_block_fp8_linear(
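This hunk hoists ceil_div to module scope so the later hunks can reuse it. A minimal sketch of how the 128-divisibility gate in should_use_deepgemm relates to DeepGEMM's 128x128 block-quantization granularity (the shape and the float8 dtype here are illustrative, not taken from the diff):

import torch

def ceil_div(x: int, y: int) -> int:
    return (x + y - 1) // y

# Illustrative weight: both dims are multiples of the 128 block size,
# so the shape gate in should_use_deepgemm above would pass.
weight = torch.empty(7168, 2048, dtype=torch.float8_e4m3fn)
print(weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)  # True
print(ceil_div(weight.shape[0], 128))  # 56 scale blocks along dim 0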
@@ -158,9 +162,6 @@ def apply_w8a8_block_fp8_linear(
     if current_platform.is_cuda():
         if current_platform.has_device_capability(100):
-
-            def ceil_div(x: int, y: int) -> int:
-                return (x + y - 1) // y
 
             use_cutlass = cutlass_block_fp8_supported and (
                 ceil_div(weight.shape[0], 128) == weight_scale.shape[0]
                 and ceil_div(weight.shape[1], 128) == weight_scale.shape[1])
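The use_cutlass condition verifies that weight_scale carries exactly one scale per 128x128 weight block. A small sanity check under that assumption (the N and K values are made up for illustration):

import torch

def ceil_div(x: int, y: int) -> int:
    return (x + y - 1) // y

# For an (N, K) fp8 weight quantized in 128x128 blocks, the per-block
# scale tensor has shape (ceil(N / 128), ceil(K / 128)).
N, K = 4096, 7168
weight = torch.empty(N, K, dtype=torch.float8_e4m3fn)
weight_scale = torch.empty(ceil_div(N, 128), ceil_div(K, 128))
assert ceil_div(weight.shape[0], 128) == weight_scale.shape[0]
assert ceil_div(weight.shape[1], 128) == weight_scale.shape[1]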
@@ -655,3 +656,67 @@ def w8a8_block_fp8_matmul(
     )
 
     return C
+
+
+# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
+# TODO(wentao): remove this function when DeepGEMM exposes this function
+def get_tma_aligned_size(x: int, element_size: int) -> int:
+    """
+    Global memory address of TMA must be 16-byte aligned.
+    Since we use column-major layout for the LHS scaling tensor,
+    the M-axis of the LHS scaling tensor needs to be padded to a multiple of
+    16 bytes.
+
+    Arguments:
+        x: original M-axis shape of the LHS scaling tensor.
+        element_size: element size of the LHS scaling tensor.
+
+    Returns:
+        M-axis shape of the LHS scaling tensor after padding.
+    """
+    tma_alignment_bytes = 16
+    assert tma_alignment_bytes % element_size == 0
+    alignment = tma_alignment_bytes // element_size
+    return ceil_div(x, alignment) * alignment
+
+
+# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
+# TODO(wentao): remove this function when DeepGEMM exposes this function
+def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
+    """
+    Returns TMA-aligned transposed format of the input tensor. `torch.transpose`
+    will be called if necessary.
+    If the input tensor is already column-major layout and 16-byte aligned along
+    the M axis (thus meets the requirement of LHS scaling tensor in
+    DeepGEMM), this function will do nothing.
+
+    Arguments:
+        x: usually the LHS scaling tensor in GEMM.
+
+    Returns:
+        The LHS scaling tensor of TMA-aligned transposed format.
+    """
+    # NOTES: for the extreme performance, you may rewrite/fuse this function in
+    # CUDA
+    assert x.dim() in (2, 3)
+    remove_dim = False
+    m, n = x.shape[-2], x.shape[-1]
+    aligned_m = get_tma_aligned_size(m, x.element_size())
+    if x.dim() == 2:
+        if x.stride(0) == 1 and x.stride(1) == aligned_m:
+            return x
+        x, remove_dim = x.unsqueeze(0), True
+
+    b = x.shape[0]
+
+    # The last kernel gives a column-major TMA aligned layout
+    if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(
+            2) == aligned_m:
+        return x.squeeze(0) if remove_dim else x
+
+    # Normal layout requires transposing
+    aligned_x = torch.transpose(
+        torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
+    aligned_x[:, :m, :] = x
+    aligned_x = aligned_x[:, :m, :]
+    return aligned_x.squeeze(0) if remove_dim else aligned_x
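A quick way to see what the two helpers do, assuming both are in scope (the shapes are arbitrary): for float32 scales element_size is 4, so the M axis pads to a multiple of 16 / 4 = 4 elements, and the returned tensor keeps its logical shape but becomes column-major with the padded stride.

import torch

m, n = 130, 7
x = torch.randn(m, n, dtype=torch.float32)

aligned_m = get_tma_aligned_size(m, x.element_size())
print(aligned_m)   # 132: next multiple of 4 above 130

y = get_col_major_tma_aligned_tensor(x)
print(y.shape)     # torch.Size([130, 7]): logical shape unchanged
print(y.stride())  # (1, 132): column-major, M padded to aligned_m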