[Misc] Fixes and Optimizations for DeepEP + DeepGEMM combination. (#19298)
Signed-off-by: Varun <vsundarr@redhat.com>
Co-authored-by: Varun <vsundarr@redhat.com>
commit 5cf2daea9a, committed by GitHub
parent b8089195b4
@@ -234,8 +234,13 @@ def _per_token_group_quant_fp8(
     row = g_id // groups_per_row
     row_g_id = g_id % groups_per_row
 
-    y_ptr += (row * y_row_stride) + (row_g_id * group_size)
-    y_q_ptr += g_id * group_size
+    # Ensure offset calculations use int64 to prevent overflow
+    y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
+                                                        group_size)
+    y_ptr += y_ptr_offset
+
+    y_q_ptr_offset = g_id.to(tl.int64) * group_size
+    y_q_ptr += y_q_ptr_offset
     y_s_ptr += g_id
 
     cols = tl.arange(0, BLOCK)  # N <= BLOCK
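Why the widening matters: Triton evaluates these pointer offsets in int32 by default, so `g_id * group_size` wraps once the tensor holds more than 2**31 - 1 elements, a size large batches can reach. A minimal sketch of the failure mode in plain PyTorch (the shapes are hypothetical, chosen only to cross the int32 boundary; this mirrors the arithmetic, not the kernel itself):

import torch

# Hypothetical shape whose element count exceeds 2**31 - 1 = 2,147,483,647.
num_tokens, hidden, group_size = 20_000, 118_784, 128
last_g_id = num_tokens * hidden // group_size - 1  # flat index of last group

g_id = torch.tensor(last_g_id, dtype=torch.int32)
off32 = g_id * group_size                  # int32 product wraps, goes negative
off64 = g_id.to(torch.int64) * group_size  # widen first, as the kernel now does

print(off32.item())  # -1919287424 (wrapped)
print(off64.item())  # 2375679872 (correct element offset)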
@@ -282,15 +287,23 @@ def _per_token_group_quant_fp8_colmajor(
     row = g_id // groups_per_row
     row_g_id = g_id % groups_per_row
 
-    y_ptr += (row * y_row_stride) + (row_g_id * group_size)
-    y_q_ptr += g_id * group_size
+    # Ensure offset calculations use int64 to prevent overflow
+    y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
+                                                        group_size)
+    y_ptr += y_ptr_offset
+
+    y_q_ptr_offset = g_id.to(tl.int64) * group_size
+    y_q_ptr += y_q_ptr_offset
 
     # Convert g_id the flattened block coordinate to 2D so we can index
     # into the output y_scales matrix
     blocks_per_row = y_num_columns // group_size
     scale_col = g_id % blocks_per_row
     scale_row = g_id // blocks_per_row
-    y_s_ptr += scale_col * y_s_col_stride + scale_row
+    # Ensure offset calculation uses int64 for y_s_ptr
+    y_s_ptr_offset = (scale_col.to(tl.int64) * y_s_col_stride) + scale_row.to(
+        tl.int64)
+    y_s_ptr += y_s_ptr_offset
 
     cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
     mask = cols < group_size
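The column-major variant needs the same treatment for the scale pointer, whose offset mixes a column stride with a row index. A small sketch of that 2D index math (shapes are hypothetical; for a compact column-major buffer, `y_s_col_stride` would equal the number of token rows):

# Hypothetical shapes mirroring the kernel's scale indexing.
y_num_columns, group_size = 1024, 128
blocks_per_row = y_num_columns // group_size  # 8 groups (scale columns) per row

def scale_offset(g_id: int, y_s_col_stride: int) -> int:
    """Flat offset of group g_id's scale in a column-major scales buffer."""
    scale_col = g_id % blocks_per_row   # which group within the token row
    scale_row = g_id // blocks_per_row  # which token row
    return scale_col * y_s_col_stride + scale_row

# With 16 token rows (column stride 16): group 9 -> column 1, row 1 -> offset 17.
print(scale_offset(g_id=9, y_s_col_stride=16))  # 17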
@@ -311,6 +324,7 @@ def per_token_group_quant_fp8(
     eps: float = 1e-10,
     dtype: Optional[torch.dtype] = None,
     column_major_scales: bool = False,
+    out_q: Optional[torch.Tensor] = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the
@@ -321,6 +335,8 @@ def per_token_group_quant_fp8(
         eps: The minimum to avoid dividing zero.
         dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
             is supported for now.
+        column_major_scales: Outputs scales in column major.
+        out_q: Optional output tensor. If not provided, function will create.
     Returns:
         tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
             scaling factor for quantization.
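The new `out_q` argument lets callers supply the quantized-output buffer instead of having the function allocate one per call. A hedged usage sketch (the import path follows this file, vllm/model_executor/layers/quantization/utils/fp8_utils.py; the shapes are illustrative):

import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)

x = torch.randn(256, 1024, device="cuda", dtype=torch.bfloat16)
out = torch.empty_like(x, dtype=torch.float8_e4m3fn)

# Quantized values land in the caller's buffer; the returned x_q aliases it.
x_q, x_s = per_token_group_quant_fp8(x, group_size=128, out_q=out)
assert x_q is out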
@@ -335,7 +351,11 @@ def per_token_group_quant_fp8(
     fp8_min = finfo.min
     fp8_max = finfo.max
 
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    assert out_q is None or out_q.shape == x.shape
+    x_q = out_q
+    if x_q is None:
+        x_q = torch.empty_like(x, device=x.device, dtype=dtype)
 
     M = x.numel() // group_size
     N = group_size
     if column_major_scales:
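With the fallback above, `torch.empty_like` only runs when no destination is passed, so a caller that keeps a persistent buffer pays no per-call allocation, presumably the optimization intended for the DeepEP + DeepGEMM path. A sketch of that reuse pattern (buffer shape and loop are illustrative):

import torch
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    per_token_group_quant_fp8)

# Hypothetical persistent buffer reused across steps; with out_q supplied,
# the allocation fallback in the diff above never runs in steady state.
buf = torch.empty(256, 1024, device="cuda", dtype=torch.float8_e4m3fn)

for _ in range(3):
    x = torch.randn(256, 1024, device="cuda", dtype=torch.bfloat16)
    x_q, x_s = per_token_group_quant_fp8(x, group_size=128, out_q=buf)
    # x_q aliases buf each step; no new output tensor is allocated.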