Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -14,14 +14,15 @@ FLOAT8_E8M0_MAX_EXP = 127
|
||||
FLOAT4_EXP_BIAS = 1
|
||||
FLOAT4_MANTISSA_BITS = 1
|
||||
|
||||
FLOAT16_VAL_TO_ADD = (1 << (FLOAT16_MANTISSA_BITS - FLOAT4_MANTISSA_BITS - 1))
|
||||
FLOAT16_SIGN_EXPONENT_MASK = ((
|
||||
(1 << (FLOAT16_EXP_BITS + 1)) - 1) << FLOAT16_MANTISSA_BITS)
|
||||
FLOAT16_VAL_TO_ADD = 1 << (FLOAT16_MANTISSA_BITS - FLOAT4_MANTISSA_BITS - 1)
|
||||
FLOAT16_SIGN_EXPONENT_MASK = (
|
||||
(1 << (FLOAT16_EXP_BITS + 1)) - 1
|
||||
) << FLOAT16_MANTISSA_BITS
|
||||
|
||||
BFLOAT16_VAL_TO_ADD = (1 <<
|
||||
(BFLOAT16_MANTISSA_BITS - FLOAT4_MANTISSA_BITS - 1))
|
||||
BFLOAT16_SIGN_EXPONENT_MASK = ((
|
||||
(1 << (BFLOAT16_EXP_BITS + 1)) - 1) << BFLOAT16_MANTISSA_BITS)
|
||||
BFLOAT16_VAL_TO_ADD = 1 << (BFLOAT16_MANTISSA_BITS - FLOAT4_MANTISSA_BITS - 1)
|
||||
BFLOAT16_SIGN_EXPONENT_MASK = (
|
||||
(1 << (BFLOAT16_EXP_BITS + 1)) - 1
|
||||
) << BFLOAT16_MANTISSA_BITS
|
||||
|
||||
|
||||
def e8m0_to_half(scale, half_dtype: torch.dtype):
|
||||
@@ -30,19 +31,19 @@ def e8m0_to_half(scale, half_dtype: torch.dtype):
|
||||
scale_exp = scale.to(torch.int16) - 127
|
||||
|
||||
# This can be implemented with bitwise operations in a proper kernel.
|
||||
scale_half = 2.0**(scale_exp.to(torch.float))
|
||||
scale_half = 2.0 ** (scale_exp.to(torch.float))
|
||||
|
||||
return scale_half.to(half_dtype)
|
||||
|
||||
|
||||
def upcast_fp4_to_fp16_or_bf16(val, float_dtype: torch.dtype,
|
||||
half_exp_bias: int, half_mantissa_bits: int):
|
||||
def upcast_fp4_to_fp16_or_bf16(
|
||||
val, float_dtype: torch.dtype, half_exp_bias: int, half_mantissa_bits: int
|
||||
):
|
||||
assert val.dtype == torch.uint8
|
||||
|
||||
unpacked = torch.zeros(*val.shape[:-1],
|
||||
val.shape[-1] * 2,
|
||||
dtype=torch.uint8,
|
||||
device=val.device)
|
||||
unpacked = torch.zeros(
|
||||
*val.shape[:-1], val.shape[-1] * 2, dtype=torch.uint8, device=val.device
|
||||
)
|
||||
unpacked[..., 1::2] = (val >> 4) & 0x0F # Extract high 4 bits.
|
||||
unpacked[..., ::2] = val & 0x0F # Extract low 4 bits.
|
||||
|
||||
@@ -72,8 +73,11 @@ def upcast_fp4_to_fp16_or_bf16(val, float_dtype: torch.dtype,
|
||||
new_exp = new_exp.to(torch.int32)
|
||||
sign = sign.to(torch.int32)
|
||||
|
||||
qdq_val = (sign << 15) + (new_exp << half_mantissa_bits) + (
|
||||
new_mantissa << (half_mantissa_bits - 1))
|
||||
qdq_val = (
|
||||
(sign << 15)
|
||||
+ (new_exp << half_mantissa_bits)
|
||||
+ (new_mantissa << (half_mantissa_bits - 1))
|
||||
)
|
||||
|
||||
assert qdq_val.max() <= 65535
|
||||
assert qdq_val.min() >= 0
|
||||
@@ -84,8 +88,9 @@ def upcast_fp4_to_fp16_or_bf16(val, float_dtype: torch.dtype,
|
||||
return result
|
||||
|
||||
|
||||
def dq_mxfp4_torch(x: torch.Tensor, scale: torch.Tensor,
|
||||
float_dtype: torch.dtype) -> torch.Tensor:
|
||||
def dq_mxfp4_torch(
|
||||
x: torch.Tensor, scale: torch.Tensor, float_dtype: torch.dtype
|
||||
) -> torch.Tensor:
|
||||
assert x.dtype == torch.uint8
|
||||
assert scale.dtype == torch.uint8
|
||||
|
||||
@@ -98,10 +103,12 @@ def dq_mxfp4_torch(x: torch.Tensor, scale: torch.Tensor,
|
||||
|
||||
scale_half = e8m0_to_half(scale, half_dtype=float_dtype)
|
||||
|
||||
x_half = upcast_fp4_to_fp16_or_bf16(x,
|
||||
float_dtype=float_dtype,
|
||||
half_exp_bias=half_exp_bias,
|
||||
half_mantissa_bits=half_mantissa_bits)
|
||||
x_half = upcast_fp4_to_fp16_or_bf16(
|
||||
x,
|
||||
float_dtype=float_dtype,
|
||||
half_exp_bias=half_exp_bias,
|
||||
half_mantissa_bits=half_mantissa_bits,
|
||||
)
|
||||
|
||||
x_half = x_half.reshape(*x_half.shape[:-1], -1, 32)
|
||||
x_half = x_half * scale_half[..., None]
|
||||
@@ -110,8 +117,9 @@ def dq_mxfp4_torch(x: torch.Tensor, scale: torch.Tensor,
|
||||
return x_half
|
||||
|
||||
|
||||
def fp16_to_fp4_simulate(val, half_mantissa_bits: int, half_exp_bits: int,
|
||||
half_exp_bias: int):
|
||||
def fp16_to_fp4_simulate(
|
||||
val, half_mantissa_bits: int, half_exp_bits: int, half_exp_bias: int
|
||||
):
|
||||
# Casts an fp16/bf16 input to the restricted values of float4_e2m1,
|
||||
# that is to say [0., 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0,
|
||||
# -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0].
|
||||
@@ -119,7 +127,7 @@ def fp16_to_fp4_simulate(val, half_mantissa_bits: int, half_exp_bits: int,
|
||||
float_type = val.dtype
|
||||
|
||||
# "rshift_cuda" not implemented for 'UInt16'
|
||||
val_view = val.view(torch.int16) #.to(torch.int32)
|
||||
val_view = val.view(torch.int16) # .to(torch.int32)
|
||||
|
||||
exp = val_view >> half_mantissa_bits
|
||||
exp = exp & ((1 << half_exp_bits) - 1)
|
||||
@@ -147,23 +155,15 @@ def fp16_to_fp4_simulate(val, half_mantissa_bits: int, half_exp_bits: int,
|
||||
|
||||
tail = mantissa_plus_one & ((1 << tail_bits) - 1)
|
||||
|
||||
round_close = (tail < half) # round towards 0
|
||||
round_away = (tail > half) # round away from 0
|
||||
round_close = tail < half # round towards 0
|
||||
round_away = tail > half # round away from 0
|
||||
tie = tail == half
|
||||
|
||||
new_mantissa_close = torch.zeros(val.shape,
|
||||
device=val.device,
|
||||
dtype=torch.bool)
|
||||
new_exp_close = torch.zeros(val.shape,
|
||||
device=val.device,
|
||||
dtype=torch.uint16)
|
||||
new_mantissa_close = torch.zeros(val.shape, device=val.device, dtype=torch.bool)
|
||||
new_exp_close = torch.zeros(val.shape, device=val.device, dtype=torch.uint16)
|
||||
|
||||
new_mantissa_away = torch.zeros(val.shape,
|
||||
device=val.device,
|
||||
dtype=torch.bool)
|
||||
new_exp_away = torch.zeros(val.shape,
|
||||
device=val.device,
|
||||
dtype=torch.uint16)
|
||||
new_mantissa_away = torch.zeros(val.shape, device=val.device, dtype=torch.bool)
|
||||
new_exp_away = torch.zeros(val.shape, device=val.device, dtype=torch.uint16)
|
||||
|
||||
new_exp_tie = torch.zeros(val.shape, device=val.device, dtype=torch.uint16)
|
||||
|
||||
@@ -202,27 +202,29 @@ def fp16_to_fp4_simulate(val, half_mantissa_bits: int, half_exp_bits: int,
|
||||
new_exp_tie = (exp > (half_exp_bias - 2)) * (exp + (mantissa_last == 1))
|
||||
|
||||
# Gather round up, round down and tie.
|
||||
new_exp = round_away * new_exp_away \
|
||||
+ round_close * new_exp_close \
|
||||
+ tie * new_exp_tie
|
||||
new_exp = (
|
||||
round_away * new_exp_away + round_close * new_exp_close + tie * new_exp_tie
|
||||
)
|
||||
|
||||
new_mantissa = round_away * new_mantissa_away \
|
||||
+ round_close * new_mantissa_close
|
||||
new_mantissa = round_away * new_mantissa_away + round_close * new_mantissa_close
|
||||
|
||||
# if new_exp > 3:
|
||||
# new_mantissa = 1
|
||||
new_mantissa = new_mantissa + (new_exp >
|
||||
(2 + half_exp_bias)) * (new_mantissa == 0)
|
||||
new_mantissa = new_mantissa + (new_exp > (2 + half_exp_bias)) * (new_mantissa == 0)
|
||||
|
||||
# Clamp the exponent to acceptable values.
|
||||
new_exp = (new_exp >= (half_exp_bias - 2)) * torch.clamp(
|
||||
new_exp, half_exp_bias - 2, half_exp_bias + 2)
|
||||
new_exp, half_exp_bias - 2, half_exp_bias + 2
|
||||
)
|
||||
|
||||
sign = sign.to(torch.int32)
|
||||
new_mantissa = new_mantissa.to(torch.int32)
|
||||
|
||||
qdq_val = (sign << 15) + (new_exp << half_mantissa_bits) + (
|
||||
new_mantissa << (half_mantissa_bits - 1))
|
||||
qdq_val = (
|
||||
(sign << 15)
|
||||
+ (new_exp << half_mantissa_bits)
|
||||
+ (new_mantissa << (half_mantissa_bits - 1))
|
||||
)
|
||||
|
||||
assert qdq_val.max() <= 65535
|
||||
assert qdq_val.min() >= 0
|
||||
@@ -233,8 +235,9 @@ def fp16_to_fp4_simulate(val, half_mantissa_bits: int, half_exp_bits: int,
|
||||
return result
|
||||
|
||||
|
||||
def qdq_mxfp4_torch(x: torch.Tensor,
|
||||
scale_calculation_mode: str = "even") -> torch.Tensor:
|
||||
def qdq_mxfp4_torch(
|
||||
x: torch.Tensor, scale_calculation_mode: str = "even"
|
||||
) -> torch.Tensor:
|
||||
half_dtype = x.dtype
|
||||
|
||||
if half_dtype == torch.float16:
|
||||
@@ -258,8 +261,7 @@ def qdq_mxfp4_torch(x: torch.Tensor,
|
||||
|
||||
block_max = block_max.view(torch.uint16).to(torch.int32)
|
||||
|
||||
block_max_uint = torch.bitwise_and(block_max + val_to_add,
|
||||
sign_exponent_mask)
|
||||
block_max_uint = torch.bitwise_and(block_max + val_to_add, sign_exponent_mask)
|
||||
|
||||
assert block_max_uint.max() <= 65535
|
||||
assert block_max_uint.min() >= 0
|
||||
@@ -268,20 +270,23 @@ def qdq_mxfp4_torch(x: torch.Tensor,
|
||||
|
||||
block_max = block_max_uint.view(half_dtype)
|
||||
|
||||
scale_exp = FLOAT8_E8M0_MAX_EXP + torch.floor(torch.log2(block_max)).to(
|
||||
torch.int32) - 2
|
||||
scale_exp = (
|
||||
FLOAT8_E8M0_MAX_EXP + torch.floor(torch.log2(block_max)).to(torch.int32) - 2
|
||||
)
|
||||
|
||||
scale_exp = torch.clamp(scale_exp, 0, 2 * FLOAT8_E8M0_MAX_EXP)
|
||||
|
||||
scale = 2.0**(scale_exp - FLOAT8_E8M0_MAX_EXP)
|
||||
scale = 2.0 ** (scale_exp - FLOAT8_E8M0_MAX_EXP)
|
||||
scale = scale.to(half_dtype)
|
||||
|
||||
x = x / scale[..., None]
|
||||
|
||||
x_fp4 = fp16_to_fp4_simulate(x,
|
||||
half_exp_bits=half_exp_bits,
|
||||
half_mantissa_bits=half_mantissa_bits,
|
||||
half_exp_bias=half_exp_bias)
|
||||
x_fp4 = fp16_to_fp4_simulate(
|
||||
x,
|
||||
half_exp_bits=half_exp_bits,
|
||||
half_mantissa_bits=half_mantissa_bits,
|
||||
half_exp_bias=half_exp_bias,
|
||||
)
|
||||
|
||||
x_fp4 = x_fp4 * scale[..., None]
|
||||
return x_fp4.reshape(*x_fp4.shape[:-2], -1)
|
||||
|
||||
Reference in New Issue
Block a user