[Bugfix] Fix unstable silu_mul+nvfp4 quant fusion test (#24370)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm._custom_ops import scaled_fp4_quant
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||
@@ -65,3 +66,10 @@ def break_fp4_bytes(a, dtype):
|
||||
|
||||
# Reshape to final form
|
||||
return values.reshape(m, n * 2).to(dtype=dtype)
|
||||
|
||||
|
||||
def quant_nvfp4_tensor(a: torch.Tensor):
|
||||
a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
|
||||
torch.abs(a).max().to(torch.float32))
|
||||
a_quant, a_block_scale = scaled_fp4_quant(a, a_global_scale)
|
||||
return a_quant, a_block_scale, a_global_scale
|
||||
|
||||
Reference in New Issue
Block a user