Bump Flashinfer to v0.4.0 (#26326)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
This commit is contained in:
elvischenv
2025-10-09 14:58:44 +08:00
committed by GitHub
parent 0d7c3cb51d
commit 5e49c3e777
7 changed files with 25 additions and 23 deletions

View File

@@ -66,9 +66,11 @@ def break_fp4_bytes(a, dtype):
return values.reshape(m, n * 2).to(dtype=dtype)
def get_nvfp4_global_scale(a: torch.Tensor):
    """Return the global scale used when quantizing ``a`` to NVFP4.

    The scale is ``(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / max(|a|)``, where the
    maximum absolute value is taken over every element of ``a`` and cast to
    ``torch.float32`` before the division.

    Args:
        a: Tensor whose dynamic range determines the scale.

    Returns:
        The result of dividing the constant product by the float32 abs-max
        (a 0-dim tensor, since ``.max()`` is called without a ``dim``).

    NOTE(review): no guard for an all-zero ``a`` — the divisor would be 0;
    confirm callers never pass such a tensor.
    """
    # Largest magnitude in the tensor, promoted to float32 before dividing.
    abs_max = torch.abs(a).max().to(torch.float32)
    scale_numerator = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX
    return scale_numerator / abs_max
def quant_nvfp4_tensor(a: torch.Tensor):
    """Quantize ``a`` with ``scaled_fp4_quant`` using a tensor-wide global scale.

    The global scale is obtained from ``get_nvfp4_global_scale(a)`` and passed,
    together with ``a``, to ``scaled_fp4_quant``.

    Args:
        a: Tensor to quantize.

    Returns:
        A 3-tuple ``(a_quant, a_block_scale, a_global_scale)``: the two values
        returned by ``scaled_fp4_quant`` followed by the global scale that was
        used to produce them.
    """
    # Compute the global scale once via the shared helper rather than
    # duplicating the formula inline.
    a_global_scale = get_nvfp4_global_scale(a)
    a_quant, a_block_scale = scaled_fp4_quant(a, a_global_scale)
    return a_quant, a_block_scale, a_global_scale