Upstream triton fp4 weight preshuffle (#28888)
Signed-off-by: Aleksandr Malyshev <maleksan@amd.com>
Co-authored-by: Aleksandr Malyshev <maleksan@amd.com>
parent 30b44a1598
commit b7f1f490a6
@@ -948,6 +948,31 @@ class rocm_aiter_ops:
             (8192, 32768),
         ]
 
+    @staticmethod
+    def is_triton_gemm_afp4wfp4_presh_ws_tuned(n: int, k: int) -> bool:
+        return (n, k) in [
+            (8192, 4096),
+            (1280, 8192),
+            (16384, 53248),
+            (106496, 16384),
+            (57344, 8192),
+            (8192, 2048),
+            (2560, 8192),
+            (10240, 8192),
+            (16384, 16384),
+            (8192, 28672),
+            (28672, 8192),
+            (18432, 16384),
+            (8192, 1024),
+            (7168, 8192),
+            (5120, 8192),
+            (8192, 8192),
+            (8192, 7168),
+            (14336, 8192),
+            (8192, 14336),
+            (8192, 3584),
+        ]
+
     @staticmethod
     def shuffle_weight(
         self, tensor: torch.Tensor, layout: tuple[int, int] = (16, 16)
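The new predicate is a plain membership check over the (N, K) weight shapes for which the preshuffled-weight Triton kernel has tuned configs. As a rough sketch of how a caller might combine it with the small-batch cutoff used later in this change (the pick_gemm_path helper and the returned labels are illustrative, not part of the diff):

    from vllm._aiter_ops import rocm_aiter_ops

    def pick_gemm_path(m: int, n: int, k: int) -> str:
        # The preshuffled Triton path is only taken for small batches and
        # for weight shapes that have tuned configs; everything else falls
        # back to the existing asm gemm path.
        if m <= 64 and rocm_aiter_ops.is_triton_gemm_afp4wfp4_presh_ws_tuned(n, k):
            return "triton_preshuffled"
        return "asm_gemm_a4w4"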
@@ -10,6 +10,7 @@ import torch
 import torch.nn.functional as F
 
 from vllm import envs
+from vllm._aiter_ops import rocm_aiter_ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
     dequant_mxfp4,
@@ -49,7 +50,10 @@ def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
 
 try:
     from aiter.ops.shuffle import shuffle_weight
-    from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
+    from aiter.ops.triton.gemm_afp4wfp4 import (
+        gemm_afp4wfp4,
+        gemm_afp4wfp4_preshuffled_weight_scales,
+    )
     from aiter.ops.triton.quant import dynamic_mxfp4_quant
 
     from vllm.utils.torch_utils import direct_register_custom_op
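Keeping the new import inside the existing try block means an aiter build that predates gemm_afp4wfp4_preshuffled_weight_scales disables the whole Triton fp4 path instead of failing at call time. A minimal sketch of that guarded-optional-import pattern (the HAS_AITER_FP4 flag is illustrative, not what this file does with the exception):

    try:
        from aiter.ops.triton.gemm_afp4wfp4 import (
            gemm_afp4wfp4,
            gemm_afp4wfp4_preshuffled_weight_scales,
        )
        HAS_AITER_FP4 = True
    except ImportError:
        # Older aiter or a non-ROCm build: callers check the flag and
        # route to a non-aiter fallback instead of crashing here.
        HAS_AITER_FP4 = False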
@@ -66,7 +70,37 @@ try:
         x_scales: torch.Tensor | None = None,
     ) -> torch.Tensor:
         M = x.shape[0]
+        N = weight.shape[0]
+        K = weight.shape[1]
         if rocm_use_aiter_fp4_asm_gemm:
+            if M <= 64 and rocm_aiter_ops.is_triton_gemm_afp4wfp4_presh_ws_tuned(N, K):
+                if x_scales is None:
+                    # use hip quant kernel for performance
+                    if M >= 32:
+                        x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
+                    else:
+                        x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=False)
+                else:
+                    x_q = x
+                    x_s = x_scales
+
+                if M >= 32:
+                    x_s = x_s.view(torch.uint8).view(x_s.shape[0] // 32, -1)
+                else:
+                    x_s = x_s[:M, ...].view(torch.uint8)
+
+                y = torch.empty(M, N, device=x_q.device, dtype=out_dtype)
+                gemm_afp4wfp4_preshuffled_weight_scales(
+                    x_q.view(torch.uint8),
+                    weight.view(torch.uint8).view(weight.shape[0] // 16, -1),
+                    x_s,
+                    weight_scale.view(torch.uint8).view(
+                        weight_scale.shape[0] // 32, -1
+                    ),
+                    out_dtype,
+                    y,
+                )
+            else:
                 if x_scales is None:
                     # use hip quant kernel for performance
                     x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
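The small-M path reinterprets the packed fp4 weight and its per-32-element scales as raw uint8 and folds dim0 (by 16 for the weight, by 32 for the scales) into the tile layout the preshuffled kernel expects; the views are metadata-only and copy nothing. A quick sketch of the shape arithmetic on dummy tensors (the concrete N and K are illustrative, and the uint8 reinterpret is already done here so only the fold remains):

    import torch

    n, k = 8192, 4096
    weight = torch.zeros(n, k // 2, dtype=torch.uint8)         # fp4 packs 2 values per byte
    weight_scale = torch.zeros(n, k // 32, dtype=torch.uint8)  # one scale per 32 values

    w_folded = weight.view(n // 16, -1)        # 16 weight rows folded per tile row
    s_folded = weight_scale.view(n // 32, -1)  # 32 scale rows folded per tile row
    assert w_folded.numel() == weight.numel()
    assert s_folded.numel() == weight_scale.numel()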
@@ -77,7 +111,10 @@ try:
             # 32 alignment is enough for dim0 padding of output for
             # gemm_a4w4 kernel
             y = torch.empty(
-                (M + 31) // 32 * 32, weight.shape[0], device=x_q.device, dtype=out_dtype
+                (M + 31) // 32 * 32,
+                weight.shape[0],
+                device=x_q.device,
+                dtype=out_dtype,
             )
 
             gemm_a4w4(
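The (M + 31) // 32 * 32 expression rounds the batch dimension up to the next multiple of 32, since (per the comment) the gemm_a4w4 kernel needs dim0 of the output padded to 32; the hunk itself only reflows the torch.empty call across lines. The round-up idiom in isolation:

    def round_up_to_multiple(m: int, mult: int = 32) -> int:
        # Smallest multiple of `mult` that is >= m, e.g. 1 -> 32, 33 -> 64.
        return (m + mult - 1) // mult * mult

    assert round_up_to_multiple(32) == 32
    assert round_up_to_multiple(33) == 64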