[Perf] Use Triton instead of Torch for DeepGEMM Per Token Group Quant (#20841)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2025-07-12 22:38:45 -04:00
committed by GitHub
parent f45a332886
commit 42d440c22b
6 changed files with 26 additions and 42 deletions

View File

@@ -13,9 +13,10 @@ import torch
# vLLM fused-expert reference (Triton fallback + DeepGEMM option)
from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import (calc_diff, per_block_cast_to_fp8,
per_token_group_cast_to_fp8)
from vllm.utils.deep_gemm import calc_diff, per_block_cast_to_fp8
BLOCK_SIZE = [128, 128]
@@ -81,7 +82,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
"""
tokens_bf16 = torch.randn(
m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1)
_, a1_scale = per_token_group_cast_to_fp8(tokens_bf16, block_size[1])
_, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1])
# expert weight tensors
w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k,