[Kernel] Optimization of the mm_k operator. (#28280)
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -23,6 +23,7 @@ def mm_k(
|
|||||||
CAST_TYPE: tl.constexpr,
|
CAST_TYPE: tl.constexpr,
|
||||||
b_dtype: tl.constexpr,
|
b_dtype: tl.constexpr,
|
||||||
USE_GDC: tl.constexpr,
|
USE_GDC: tl.constexpr,
|
||||||
|
base_k,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
|
Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
|
||||||
@@ -47,32 +48,62 @@ def mm_k(
|
|||||||
matrix dtype.
|
matrix dtype.
|
||||||
b_dtype: datatype of the B matrix
|
b_dtype: datatype of the B matrix
|
||||||
USE_GDC: Whether to use PDL. True indicates use.
|
USE_GDC: Whether to use PDL. True indicates use.
|
||||||
|
base_k: Base offset along K dimension for current SPLIT_K group
|
||||||
"""
|
"""
|
||||||
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||||
for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)):
|
|
||||||
|
# Step size along K for each iteration
|
||||||
|
STEP_K = BLOCK_K * SPLIT_K
|
||||||
|
|
||||||
|
# Total number of iterations (compile-time constant)
|
||||||
|
num_iters = tl.cdiv(K, STEP_K)
|
||||||
|
|
||||||
|
for k in range(num_iters):
|
||||||
|
# Current iteration's global K offset
|
||||||
|
iter_k = k * STEP_K + base_k
|
||||||
|
|
||||||
|
# Check if this iteration is completely valid (no masking needed)
|
||||||
|
block_end = iter_k + BLOCK_K
|
||||||
|
|
||||||
if EVEN_K:
|
if EVEN_K:
|
||||||
# pre-fetech lora weight
|
# K is divisible by BLOCK_K, no masking ever needed
|
||||||
|
# pre-fetch lora weight
|
||||||
tiled_b = tl.load(b_ptr)
|
tiled_b = tl.load(b_ptr)
|
||||||
if USE_GDC:
|
if USE_GDC:
|
||||||
tl.extra.cuda.gdc_wait()
|
tl.extra.cuda.gdc_wait()
|
||||||
tiled_a = tl.load(a_ptr)
|
tiled_a = tl.load(a_ptr)
|
||||||
|
if CAST_TYPE:
|
||||||
|
tiled_a = tiled_a.to(b_dtype)
|
||||||
|
accumulator += tl.dot(tiled_a, tiled_b)
|
||||||
else:
|
else:
|
||||||
tiled_b = tl.load(
|
# Check if we need element-wise masking
|
||||||
b_ptr, mask=offset_k[:, None] < K - k * (BLOCK_K * SPLIT_K), other=0
|
if iter_k >= K:
|
||||||
)
|
# Entire block out of range, skip
|
||||||
if USE_GDC:
|
pass
|
||||||
tl.extra.cuda.gdc_wait()
|
elif block_end <= K:
|
||||||
tiled_a = tl.load(
|
# Entire block in range, no masking needed (fast path)
|
||||||
a_ptr, mask=offset_k[None, :] < K - k * (BLOCK_K * SPLIT_K), other=0
|
tiled_b = tl.load(b_ptr)
|
||||||
)
|
if USE_GDC:
|
||||||
if CAST_TYPE:
|
tl.extra.cuda.gdc_wait()
|
||||||
tiled_a = tiled_a.to(b_dtype)
|
tiled_a = tl.load(a_ptr)
|
||||||
accumulator += tl.dot(
|
if CAST_TYPE:
|
||||||
tiled_a,
|
tiled_a = tiled_a.to(b_dtype)
|
||||||
tiled_b,
|
accumulator += tl.dot(tiled_a, tiled_b)
|
||||||
)
|
else:
|
||||||
a_ptr += BLOCK_K * SPLIT_K * ak_stride
|
# Partial block, need masking (only last iteration)
|
||||||
b_ptr += BLOCK_K * SPLIT_K * bk_stride
|
k_offsets = tl.arange(0, BLOCK_K)
|
||||||
|
mask = iter_k + k_offsets < K
|
||||||
|
tiled_b = tl.load(b_ptr, mask=mask[:, None], other=0.0)
|
||||||
|
if USE_GDC:
|
||||||
|
tl.extra.cuda.gdc_wait()
|
||||||
|
tiled_a = tl.load(a_ptr, mask=mask[None, :], other=0.0)
|
||||||
|
if CAST_TYPE:
|
||||||
|
tiled_a = tiled_a.to(b_dtype)
|
||||||
|
accumulator += tl.dot(tiled_a, tiled_b)
|
||||||
|
|
||||||
|
a_ptr += STEP_K * ak_stride
|
||||||
|
b_ptr += STEP_K * bk_stride
|
||||||
|
|
||||||
return accumulator
|
return accumulator
|
||||||
|
|
||||||
|
|
||||||
@@ -178,6 +209,7 @@ def do_expand_kernel(
|
|||||||
CAST_TYPE,
|
CAST_TYPE,
|
||||||
cur_lora_ptr.dtype.element_ty,
|
cur_lora_ptr.dtype.element_ty,
|
||||||
USE_GDC,
|
USE_GDC,
|
||||||
|
base_k=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
|
tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
|
||||||
@@ -284,6 +316,7 @@ def do_shrink_kernel(
|
|||||||
False,
|
False,
|
||||||
cur_lora_ptr.dtype.element_ty,
|
cur_lora_ptr.dtype.element_ty,
|
||||||
False, # USE_GDC is always False in shrink kernel
|
False, # USE_GDC is always False in shrink kernel
|
||||||
|
base_k=pid_sk * BLOCK_K,
|
||||||
)
|
)
|
||||||
# GDC launch dependents hints the runtime system to launch dependent kernels.
|
# GDC launch dependents hints the runtime system to launch dependent kernels.
|
||||||
if USE_GDC:
|
if USE_GDC:
|
||||||
|
|||||||
Reference in New Issue
Block a user