Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -70,20 +70,19 @@ class ScalarType:
|
||||
"""
|
||||
|
||||
def _floating_point_max_int(self) -> int:
|
||||
assert (
|
||||
self.mantissa <= 52 and self.exponent <= 11
|
||||
), f"Cannot represent max/min as a double for type {self.__str__()}"
|
||||
assert self.mantissa <= 52 and self.exponent <= 11, (
|
||||
f"Cannot represent max/min as a double for type {self.__str__()}"
|
||||
)
|
||||
|
||||
max_mantissa = (1 << self.mantissa) - 1
|
||||
if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
|
||||
max_mantissa = max_mantissa - 1
|
||||
|
||||
max_exponent = (1 << self.exponent) - 2
|
||||
if (self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN
|
||||
or self.nan_repr == NanRepr.NONE):
|
||||
assert (
|
||||
self.exponent < 11
|
||||
), f"Cannot represent max/min as a double for type {self.__str__()}"
|
||||
if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN or self.nan_repr == NanRepr.NONE:
|
||||
assert self.exponent < 11, (
|
||||
f"Cannot represent max/min as a double for type {self.__str__()}"
|
||||
)
|
||||
max_exponent = max_exponent + 1
|
||||
|
||||
# adjust the exponent to match that of a double
|
||||
@@ -96,38 +95,39 @@ class ScalarType:
|
||||
exponent_bias = (1 << (self.exponent - 1)) - 1
|
||||
exponent_bias_double = (1 << 10) - 1 # double e = 11
|
||||
|
||||
max_exponent_double = (max_exponent - exponent_bias +
|
||||
exponent_bias_double)
|
||||
max_exponent_double = max_exponent - exponent_bias + exponent_bias_double
|
||||
|
||||
# shift the mantissa and exponent into the proper positions for an
|
||||
# IEEE double and bitwise-or them together.
|
||||
return (max_mantissa <<
|
||||
(52 - self.mantissa)) | (max_exponent_double << 52)
|
||||
return (max_mantissa << (52 - self.mantissa)) | (max_exponent_double << 52)
|
||||
|
||||
def _floating_point_max(self) -> float:
|
||||
double_raw = self._floating_point_max_int()
|
||||
return struct.unpack('!d', struct.pack('!Q', double_raw))[0]
|
||||
return struct.unpack("!d", struct.pack("!Q", double_raw))[0]
|
||||
|
||||
def _raw_max(self) -> Union[int, float]:
|
||||
if self.is_floating_point():
|
||||
return self._floating_point_max()
|
||||
else:
|
||||
assert (self.size_bits < 64 or self.size_bits == 64
|
||||
and self.is_signed()), "Cannot represent max as an int"
|
||||
assert self.size_bits < 64 or self.size_bits == 64 and self.is_signed(), (
|
||||
"Cannot represent max as an int"
|
||||
)
|
||||
return (1 << self.mantissa) - 1
|
||||
|
||||
def _raw_min(self) -> Union[int, float]:
|
||||
if self.is_floating_point():
|
||||
assert self.is_signed(
|
||||
), "We currently assume all floating point types are signed"
|
||||
assert self.is_signed(), (
|
||||
"We currently assume all floating point types are signed"
|
||||
)
|
||||
sign_bit_double = 1 << 63
|
||||
|
||||
max_raw = self._floating_point_max_int()
|
||||
min_raw = max_raw | sign_bit_double
|
||||
return struct.unpack('!d', struct.pack('!Q', min_raw))[0]
|
||||
return struct.unpack("!d", struct.pack("!Q", min_raw))[0]
|
||||
else:
|
||||
assert (not self.is_signed() or self.size_bits
|
||||
<= 64), "Cannot represent min as a int64_t"
|
||||
assert not self.is_signed() or self.size_bits <= 64, (
|
||||
"Cannot represent min as a int64_t"
|
||||
)
|
||||
|
||||
if self.is_signed():
|
||||
return -(1 << (self.size_bits - 1))
|
||||
@@ -158,8 +158,7 @@ class ScalarType:
|
||||
or_and_advance(self._finite_values_only, 1)
|
||||
or_and_advance(self.nan_repr.value, 8)
|
||||
|
||||
assert offset <= 64, \
|
||||
f"ScalarType fields too big {offset} to fit into an int64"
|
||||
assert offset <= 64, f"ScalarType fields too big {offset} to fit into an int64"
|
||||
|
||||
_SCALAR_TYPES_ID_MAP[val] = self
|
||||
|
||||
@@ -215,8 +214,7 @@ class ScalarType:
|
||||
If the type is a floating point type that follows IEEE 754
|
||||
conventions
|
||||
"""
|
||||
return self.nan_repr == NanRepr.IEEE_754.value and \
|
||||
not self._finite_values_only
|
||||
return self.nan_repr == NanRepr.IEEE_754.value and not self._finite_values_only
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
@@ -232,8 +230,14 @@ class ScalarType:
|
||||
- if bias is not present it means its zero
|
||||
"""
|
||||
if self.is_floating_point():
|
||||
ret = "float" + str(self.size_bits) + "_e" + str(
|
||||
self.exponent) + "m" + str(self.mantissa)
|
||||
ret = (
|
||||
"float"
|
||||
+ str(self.size_bits)
|
||||
+ "_e"
|
||||
+ str(self.exponent)
|
||||
+ "m"
|
||||
+ str(self.mantissa)
|
||||
)
|
||||
|
||||
if not self.is_ieee_754():
|
||||
if self._finite_values_only:
|
||||
@@ -261,41 +265,43 @@ class ScalarType:
|
||||
#
|
||||
|
||||
@classmethod
|
||||
def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
|
||||
def int_(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
|
||||
"Create a signed integer scalar type (size_bits includes sign-bit)."
|
||||
ret = cls(0, size_bits - 1, True, bias if bias else 0)
|
||||
ret.id # noqa B018: make sure the id is cached
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
|
||||
def uint(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
|
||||
"""Create an unsigned integer scalar type."""
|
||||
ret = cls(0, size_bits, False, bias if bias else 0)
|
||||
ret.id # noqa B018: make sure the id is cached
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
|
||||
def float_IEEE754(cls, exponent: int, mantissa: int) -> "ScalarType":
|
||||
"""
|
||||
Create a standard floating point type
|
||||
(i.e. follows IEEE 754 conventions).
|
||||
"""
|
||||
assert (mantissa > 0 and exponent > 0)
|
||||
assert mantissa > 0 and exponent > 0
|
||||
ret = cls(exponent, mantissa, True, 0)
|
||||
ret.id # noqa B018: make sure the id is cached
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
|
||||
nan_repr: NanRepr) -> 'ScalarType':
|
||||
def float_(
|
||||
cls, exponent: int, mantissa: int, finite_values_only: bool, nan_repr: NanRepr
|
||||
) -> "ScalarType":
|
||||
"""
|
||||
Create a non-standard floating point type
|
||||
(i.e. does not follow IEEE 754 conventions).
|
||||
"""
|
||||
assert (mantissa > 0 and exponent > 0)
|
||||
assert (nan_repr != NanRepr.IEEE_754), (
|
||||
assert mantissa > 0 and exponent > 0
|
||||
assert nan_repr != NanRepr.IEEE_754, (
|
||||
"use `float_IEEE754` constructor for floating point types that "
|
||||
"follow IEEE 754 conventions")
|
||||
"follow IEEE 754 conventions"
|
||||
)
|
||||
ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
|
||||
ret.id # noqa B018: make sure the id is cached
|
||||
return ret
|
||||
@@ -303,8 +309,7 @@ class ScalarType:
|
||||
@classmethod
|
||||
def from_id(cls, scalar_type_id: int):
|
||||
if scalar_type_id not in _SCALAR_TYPES_ID_MAP:
|
||||
raise ValueError(
|
||||
f"scalar_type_id {scalar_type_id} doesn't exists.")
|
||||
raise ValueError(f"scalar_type_id {scalar_type_id} doesn't exists.")
|
||||
return _SCALAR_TYPES_ID_MAP[scalar_type_id]
|
||||
|
||||
|
||||
@@ -327,8 +332,7 @@ class scalar_types:
|
||||
uint8 = ScalarType.uint(8, None)
|
||||
float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
|
||||
float8_e5m2 = ScalarType.float_IEEE754(5, 2)
|
||||
float8_e8m0fnu = ScalarType(8, 0, False, 0, True,
|
||||
NanRepr.EXTD_RANGE_MAX_MIN)
|
||||
float8_e8m0fnu = ScalarType(8, 0, False, 0, True, NanRepr.EXTD_RANGE_MAX_MIN)
|
||||
float16_e8m7 = ScalarType.float_IEEE754(8, 7)
|
||||
float16_e5m10 = ScalarType.float_IEEE754(5, 10)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user