Bump Flashinfer to v0.4.0 (#26326)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
This commit is contained in:
@@ -66,9 +66,11 @@ def break_fp4_bytes(a, dtype):
|
||||
return values.reshape(m, n * 2).to(dtype=dtype)
|
||||
|
||||
|
||||
def get_nvfp4_global_scale(a: torch.Tensor):
    """Compute the NVFP4 global scale for tensor *a*.

    Returns (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) divided by the maximum
    absolute value of *a*, as a float32 scalar tensor.
    """
    # Largest magnitude in the tensor, promoted to float32 before dividing.
    amax = torch.abs(a).max().to(torch.float32)
    return (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / amax
|
||||
|
||||
|
||||
def quant_nvfp4_tensor(a: torch.Tensor):
    """Quantize tensor *a* to NVFP4.

    Returns a 3-tuple ``(a_quant, a_block_scale, a_global_scale)`` as
    produced by ``scaled_fp4_quant`` plus the global scale used.
    """
    # Fix: the global scale was previously computed inline and then
    # immediately overwritten by the helper call — the first computation
    # was dead code. Keep only the shared helper so both stay consistent.
    a_global_scale = get_nvfp4_global_scale(a)
    a_quant, a_block_scale = scaled_fp4_quant(a, a_global_scale)
    return a_quant, a_block_scale, a_global_scale
|
||||
|
||||
Reference in New Issue
Block a user