[Bugfix] Fix unstable silu_mul+nvfp4 quant fusion test (#24370)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
This commit is contained in:
elvischenv
2025-09-07 04:39:34 +08:00
committed by GitHub
parent a3645ed94d
commit e68dc2f014
2 changed files with 38 additions and 16 deletions

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm._custom_ops import scaled_fp4_quant
from vllm.scalar_type import scalar_types
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
@@ -65,3 +66,10 @@ def break_fp4_bytes(a, dtype):
# Reshape to final form
return values.reshape(m, n * 2).to(dtype=dtype)
def quant_nvfp4_tensor(a: torch.Tensor):
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
torch.abs(a).max().to(torch.float32))
a_quant, a_block_scale = scaled_fp4_quant(a, a_global_scale)
return a_quant, a_block_scale, a_global_scale