* Merge with private repo * Update README * Update README * Update README * Add PyTorch requirements * Fix sync scopes for MQA logits (#256) * Update README
47 lines
1.4 KiB
Python
47 lines
1.4 KiB
Python
from typing import Any
|
|
from tilelang import language as T
|
|
|
|
|
|
def ceil_div(x: int, y: int) -> int:
    """Return ``x / y`` rounded up to the next integer.

    Uses pure integer arithmetic (no float division), so the result is exact
    for arbitrarily large integers. Assumes ``y > 0``; for a non-positive
    divisor the result is not a ceiling.

    >>> ceil_div(7, 2)
    4
    """
    # Padding the numerator by (y - 1) turns floor division into ceiling division.
    padded = x + y - 1
    return padded // y
|
|
|
|
|
|
def align(x: int, y: int) -> int:
    """Round ``x`` up to the nearest multiple of ``y`` (assumes ``y > 0``).

    For example, ``align(5, 4) == 8`` and ``align(8, 4) == 8``.
    """
    # Number of y-sized chunks needed to cover x, then scale back up.
    num_chunks = ceil_div(x, y)
    return num_chunks * y
|
|
|
|
|
|
def get_sf_shape(
    num_tokens: int,
    hidden: int,
    num_per_channels: int,
    use_ue8m0: bool,
    use_col_major_sf: bool,
) -> tuple[int, int]:
    """Return the 2-D shape of the scaling-factor (SF) tensor.

    One scale is produced per ``num_per_channels`` elements of ``hidden``
    (rounded up), for each of ``num_tokens`` tokens.

    Args:
        num_tokens: Number of tokens.
        hidden: Size of the dimension that is split into scale groups.
        num_per_channels: Number of elements covered by one scale.
        use_ue8m0: If set, the scale axis is divided by 4 (rounded up) and the
            token axis is multiplied by 4, because 4 UE8M0 values are expanded
            into the inner (token) dimension.
        use_col_major_sf: If set, the token count is first aligned up to a
            multiple of 4, and the shape is returned as ``(scales, tokens)``
            instead of ``(tokens, scales)``.

    Returns:
        ``(num_scales, num_sf_tokens)`` when ``use_col_major_sf`` is set,
        otherwise ``(num_sf_tokens, num_scales)``.

    NOTE(review): UE8M0 is documented as requiring col-major SF, but the
    combination ``use_ue8m0=True, use_col_major_sf=False`` is not rejected here.
    """
    scale_count = ceil_div(hidden, num_per_channels)

    # For col-major SF, TMA must be aligned into 16 bytes; aligning the token
    # count to 4 presumably gives 16 bytes for 4-byte SF elements (confirm).
    token_count = align(num_tokens, 4) if use_col_major_sf else num_tokens

    if use_ue8m0:
        # 4 UE8M0 scales are folded out of the scale axis into the token axis.
        scale_count = ceil_div(scale_count, 4)
        token_count *= 4

    if use_col_major_sf:
        return scale_count, token_count
    return token_count, scale_count
|
|
|
|
|
|
def get_sf_and_inv(amax: float, round_sf: bool, use_ue8m0: bool) -> tuple[Any, Any]:
    """Compute a scaling factor from ``amax`` together with its inverse.

    The raw factor is ``amax / 448.0``; 448.0 is presumably the FP8 E4M3
    maximum magnitude (confirm against the quantization target).

    Without ``round_sf`` this returns ``(amax / 448.0, 448.0 / amax)``.

    With ``round_sf`` the factor is rounded *up* to a power of two ``2**e`` by
    reading its float32 bit pattern: the biased exponent is extracted and
    bumped by one whenever any mantissa bit is set. The inverse is then the
    exact power of two ``2**-e``. The factor itself is returned either as the
    float ``2**e`` or, when ``use_ue8m0`` is set, as the biased exponent
    ``e + 127`` cast to ``uint8``.

    ``round_sf`` and ``use_ue8m0`` drive plain Python ``if`` statements, so
    they must be Python bools rather than device expressions. When
    ``round_sf`` is set, ``amax`` presumably is a TileLang float32 expression,
    since ``T.reinterpret`` is applied to it (confirm at call sites).

    NOTE(review): nothing here guards ``amax == 0``; callers presumably clamp it.

    Returns:
        ``(sf, inv_sf)``.
    """
    raw_sf = amax / 448.0
    if round_sf:
        # Split the float32 bit pattern into biased exponent and mantissa.
        sf_bits = T.reinterpret("uint32", raw_sf)
        biased_exp = (sf_bits >> 23) & 0xFF
        mantissa = sf_bits & 0x7FFFFF
        # Unbiased exponent, rounded up when the value is not an exact power of
        # two. The subtraction happens in uint32 and is reinterpreted as int32
        # so that negative exponents come out signed.
        exponent = T.reinterpret("int32", biased_exp - 127 + (mantissa != 0))
        # 2**-exponent, built directly from its float32 exponent field.
        inv_sf = T.reinterpret("float", (127 - exponent) << 23)
        if use_ue8m0:
            # UE8M0 stores just the biased exponent byte.
            return T.Cast("uint8", exponent + 127), inv_sf
        # 2**exponent as a float32.
        return T.reinterpret("float", (127 + exponent) << 23), inv_sf
    return raw_sf, 448.0 / amax