Revert "[Feature] Integrate new deepgemm (#19820)" (#20049)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2025-06-24 22:45:22 -04:00
committed by GitHub
parent 1afa9948f5
commit a6c4b87fbc
8 changed files with 254 additions and 230 deletions

View File

@@ -114,10 +114,6 @@ def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor):
and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0)
def ceil_div(x: int, y: int) -> int:
return (x + y - 1) // y
# TODO fix ROCm->Triton custom path:
# https://github.com/vllm-project/vllm/issues/14397
def apply_w8a8_block_fp8_linear(
@@ -162,6 +158,9 @@ def apply_w8a8_block_fp8_linear(
if current_platform.is_cuda():
if current_platform.has_device_capability(100):
def ceil_div(x: int, y: int) -> int:
return (x + y - 1) // y
use_cutlass = cutlass_block_fp8_supported and (
ceil_div(weight.shape[0], 128) == weight_scale.shape[0]
and ceil_div(weight.shape[1], 128) == weight_scale.shape[1])
@@ -656,67 +655,3 @@ def w8a8_block_fp8_matmul(
)
return C
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
# TODO(wentao): remove this function when DeepGEMM exposes this function
def get_tma_aligned_size(x: int, element_size: int) -> int:
"""
Global memory address of TMA must be 16-byte aligned.
Since we use column-major layout for the LHS scaling tensor,
the M-axis of the LHS scaling tensor needs to be padded to a multiple of
16 bytes.
Arguments:
x: original M-axis shape of the LHS scaling tensor.
element_size: element size of the LHS scaling tensor.
Returns:
M-axis shape of the LHS scaling tensor after padding.
"""
tma_alignment_bytes = 16
assert tma_alignment_bytes % element_size == 0
alignment = tma_alignment_bytes // element_size
return ceil_div(x, alignment) * alignment
# Taken from https://github.com/deepseek-ai/DeepGEMM/blob/0c88cd01392c1073c7049a97d6328c7bba9b3947
# TODO(wentao): remove this function when DeepGEMM exposes this function
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
"""
Returns TMA-aligned transposed format of the input tensor. `torch.transpose`
will be called if necessary.
If the input tensor is already column-major layout and 16-byte aligned along
the M axis (thus meets the requirement of LHS scaling tensor in
DeepGEMM), this function will do nothing.
Arguments:
x: usually the LHS scaling tensor in GEMM.
Returns:
The LHS scaling tensor of TMA-aligned transposed format.
"""
# NOTES: for the extreme performance, you may rewrite/fuse this function in
# CUDA
assert x.dim() in (2, 3)
remove_dim = False
m, n = x.shape[-2], x.shape[-1]
aligned_m = get_tma_aligned_size(m, x.element_size())
if x.dim() == 2:
if x.stride(0) == 1 and x.stride(1) == aligned_m:
return x
x, remove_dim = x.unsqueeze(0), True
b = x.shape[0]
# The last kernel gives a column-major TMA aligned layout
if x.stride(0) == aligned_m * n and x.stride(1) == 1 and x.stride(
2) == aligned_m:
return x.squeeze(0) if remove_dim else x
# Normal layout requires transposing
aligned_x = torch.transpose(
torch.empty((b, n, aligned_m), device=x.device, dtype=x.dtype), 1, 2)
aligned_x[:, :m, :] = x
aligned_x = aligned_x[:, :m, :]
return aligned_x.squeeze(0) if remove_dim else aligned_x