[Perf] Use Triton instead of Torch for DeepGEMM Per Token Group Quant (#20841)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Author: Wentao Ye
Date: 2025-07-12 22:38:45 -04:00
Committed by: GitHub
Parent: f45a332886
Commit: 42d440c22b
6 changed files with 26 additions and 42 deletions
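
For context, per-token-group FP8 quantization gives each contiguous group of group_size values along a token's hidden dimension its own scale, chosen so the group's max-abs value maps onto the FP8 range; this commit routes that step through vLLM's Triton kernel instead of the Torch-based DeepGEMM helper. Below is a minimal PyTorch reference sketch of the operation (the function name, eps, and the e4m3 format are illustrative assumptions, not vLLM's implementation):

import torch

def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int,
                                  eps: float = 1e-10):
    # Reference per-token-group FP8 quantization (illustrative only).
    # Each contiguous group of group_size values along the last dim gets
    # its own scale so that the group's max-abs value maps to the FP8 max.
    assert x.shape[-1] % group_size == 0
    finfo = torch.finfo(torch.float8_e4m3fn)
    x_grouped = x.reshape(-1, group_size).to(torch.float32)
    amax = x_grouped.abs().amax(dim=-1, keepdim=True).clamp(min=eps)
    scale = amax / finfo.max
    x_q = (x_grouped / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    # Quantized tensor in the original shape plus one scale per group.
    return x_q.reshape(x.shape), scale.reshape(*x.shape[:-1], -1)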


@@ -15,8 +15,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     w8a8_block_fp8_matmul)
 from vllm.platforms import current_platform
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import (fp8_gemm_nt, per_block_cast_to_fp8,
-                                  per_token_group_cast_to_fp8)
+from vllm.utils.deep_gemm import fp8_gemm_nt, per_block_cast_to_fp8
 
 if current_platform.get_device_capability() < (9, 0):
     pytest.skip("FP8 Triton requires CUDA 9.0 or higher",
@@ -117,7 +116,7 @@ def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed):
     A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
     B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max
 
-    A_fp8, As_fp8 = per_token_group_cast_to_fp8(A_fp32, block_size[1])
+    A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_size[1])
     B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32)
     As = As_fp8.to(torch.float32)
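
A minimal usage sketch of the updated call pattern from the test above; the tensor shapes, device, and the import location of per_token_group_quant_fp8 are assumptions (only the call sites appear in this diff):

import torch

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)  # Triton-backed helper; import path assumed
from vllm.utils.deep_gemm import per_block_cast_to_fp8

M, N, K = 8, 512, 512
block_size = [128, 128]
A_fp32 = torch.randn(M, K, dtype=torch.float32, device="cuda")
B_fp32 = torch.randn(N, K, dtype=torch.float32, device="cuda")

# Activations: one FP8 scale per group of block_size[1] values per token.
A_fp8, As_fp8 = per_token_group_quant_fp8(A_fp32, block_size[1])
# Weights: one FP8 scale per block, via the DeepGEMM helper as before.
B_fp8, Bs_fp8 = per_block_cast_to_fp8(B_fp32)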