[Perf] Use Triton instead of Torch for DeepGEMM Per Token Group Quant (#20841)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
@@ -13,9 +13,10 @@ import torch
 
 # vLLM fused-expert reference (Triton fallback + DeepGEMM option)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts
+from vllm.model_executor.layers.quantization.utils.fp8_utils import (
+    per_token_group_quant_fp8)
 from vllm.utils import has_deep_gemm
-from vllm.utils.deep_gemm import (calc_diff, per_block_cast_to_fp8,
-                                  per_token_group_cast_to_fp8)
+from vllm.utils.deep_gemm import calc_diff, per_block_cast_to_fp8
 
 BLOCK_SIZE = [128, 128]
 
@@ -81,7 +82,7 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
     """
     tokens_bf16 = torch.randn(
         m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1)
-    _, a1_scale = per_token_group_cast_to_fp8(tokens_bf16, block_size[1])
+    _, a1_scale = per_token_group_quant_fp8(tokens_bf16, block_size[1])
 
     # expert weight tensors
     w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(num_experts, n, k,
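As a rough usage sketch of the new call path (not code from the patch): the Triton-based kernel takes an activation tensor and a group size and returns the FP8 tensor together with per-token-group scales, which is the signature the benchmark relies on above. The sizes and printed shape checks below are illustrative assumptions.

import torch

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)

# Illustrative sizes; group_size = 128 matches BLOCK_SIZE[1] in the benchmark.
m, k, group_size = 32, 512, 128
tokens_bf16 = torch.randn(
    m, k, device="cuda", dtype=torch.bfloat16).clamp_min_(-1).clamp_max_(1)

# Triton path now used by the benchmark: returns the quantized tensor and
# one scale per (token, group). The exact scale layout is an assumption.
a_q, a_scale = per_token_group_quant_fp8(tokens_bf16, group_size)
print(a_q.shape, a_q.dtype)   # expected: (m, k), an FP8 dtype
print(a_scale.shape)          # expected: (m, k // group_size)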