[Kernel] Optimization of the mm_k operator. (#28280)

Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
commit 40e2eeeb92
parent b06b9470ca
Author: caozuoba
Date:   2025-11-11 00:03:46 +08:00
Committed by: GitHub


@@ -23,6 +23,7 @@ def mm_k(
     CAST_TYPE: tl.constexpr,
     b_dtype: tl.constexpr,
     USE_GDC: tl.constexpr,
+    base_k,
 ):
     """
     Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
@@ -47,32 +48,62 @@ def mm_k(
     matrix dtype.
     b_dtype: datatype of the B matrix
     USE_GDC: Whether to use PDL. True indicates use.
+    base_k: Base offset along K dimension for current SPLIT_K group
     """
     accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-    for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)):
+    # Step size along K for each iteration
+    STEP_K = BLOCK_K * SPLIT_K
+    # Total number of iterations (compile-time constant)
+    num_iters = tl.cdiv(K, STEP_K)
+    for k in range(num_iters):
+        # Current iteration's global K offset
+        iter_k = k * STEP_K + base_k
+        # Check if this iteration is completely valid (no masking needed)
+        block_end = iter_k + BLOCK_K
         if EVEN_K:
-            # pre-fetech lora weight
+            # K is divisible by BLOCK_K, no masking ever needed
+            # pre-fetch lora weight
             tiled_b = tl.load(b_ptr)
             if USE_GDC:
                 tl.extra.cuda.gdc_wait()
             tiled_a = tl.load(a_ptr)
+            if CAST_TYPE:
+                tiled_a = tiled_a.to(b_dtype)
+            accumulator += tl.dot(tiled_a, tiled_b)
         else:
-            tiled_b = tl.load(
-                b_ptr, mask=offset_k[:, None] < K - k * (BLOCK_K * SPLIT_K), other=0
-            )
-            if USE_GDC:
-                tl.extra.cuda.gdc_wait()
-            tiled_a = tl.load(
-                a_ptr, mask=offset_k[None, :] < K - k * (BLOCK_K * SPLIT_K), other=0
-            )
-        if CAST_TYPE:
-            tiled_a = tiled_a.to(b_dtype)
-        accumulator += tl.dot(
-            tiled_a,
-            tiled_b,
-        )
-        a_ptr += BLOCK_K * SPLIT_K * ak_stride
-        b_ptr += BLOCK_K * SPLIT_K * bk_stride
+            # Check if we need element-wise masking
+            if iter_k >= K:
+                # Entire block out of range, skip
+                pass
+            elif block_end <= K:
+                # Entire block in range, no masking needed (fast path)
+                tiled_b = tl.load(b_ptr)
+                if USE_GDC:
+                    tl.extra.cuda.gdc_wait()
+                tiled_a = tl.load(a_ptr)
+                if CAST_TYPE:
+                    tiled_a = tiled_a.to(b_dtype)
+                accumulator += tl.dot(tiled_a, tiled_b)
+            else:
+                # Partial block, need masking (only last iteration)
+                k_offsets = tl.arange(0, BLOCK_K)
+                mask = iter_k + k_offsets < K
+                tiled_b = tl.load(b_ptr, mask=mask[:, None], other=0.0)
+                if USE_GDC:
+                    tl.extra.cuda.gdc_wait()
+                tiled_a = tl.load(a_ptr, mask=mask[None, :], other=0.0)
+                if CAST_TYPE:
+                    tiled_a = tiled_a.to(b_dtype)
+                accumulator += tl.dot(tiled_a, tiled_b)
+        a_ptr += STEP_K * ak_stride
+        b_ptr += STEP_K * bk_stride
     return accumulator
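
What changed, in short: in the uneven-K path the old loop issued masked loads on every iteration, recomputing offset_k < K - k * (BLOCK_K * SPLIT_K) even for tiles lying entirely inside the matrix. The new loop classifies each tile once by its global offset iter_k = k * STEP_K + base_k: wholly past K (skip), wholly inside K (unmasked fast path), or straddling the boundary (masked tail, at most one iteration per program). The new base_k argument is what lets that classification see each split-K program's true starting offset. A minimal pure-Python sketch of the same control flow, with NumPy slicing standing in for Triton's block loads (function name and shapes are illustrative, not from the PR):

import numpy as np

def mm_k_sketch(A, B, BLOCK_K, SPLIT_K, base_k):
    # Accumulate this split-K program's share of A @ B, mirroring the
    # skip / fast-path / masked-tail classification in the Triton kernel.
    K = A.shape[1]
    acc = np.zeros((A.shape[0], B.shape[1]), dtype=np.float32)
    STEP_K = BLOCK_K * SPLIT_K           # stride between this program's tiles
    num_iters = -(-K // STEP_K)          # ceil-div, like tl.cdiv(K, STEP_K)
    for k in range(num_iters):
        iter_k = k * STEP_K + base_k     # global K offset of this tile
        if iter_k >= K:
            continue                     # entire tile out of range: skip
        end = min(iter_k + BLOCK_K, K)   # full BLOCK_K on the fast path,
                                         # shorter only for the masked tail
        acc += A[:, iter_k:end] @ B[iter_k:end, :]
    return acc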
@@ -178,6 +209,7 @@ def do_expand_kernel(
         CAST_TYPE,
         cur_lora_ptr.dtype.element_ty,
         USE_GDC,
+        base_k=0,
     )
     tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
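
The expand path passes base_k=0, so every tile offset is just k * STEP_K; if the expand kernel does not split K across programs (an assumption here, this hunk alone doesn't show it), the classification degenerates to "is this the final, possibly partial, tile?". In terms of the sketch above:

rng = np.random.default_rng(0)
A = rng.standard_normal((8, 100), dtype=np.float32)
B = rng.standard_normal((100, 8), dtype=np.float32)

# Single program covering all of K: base_k=0, SPLIT_K=1 (illustrative values).
out = mm_k_sketch(A, B, BLOCK_K=16, SPLIT_K=1, base_k=0)
assert np.allclose(out, A @ B, atol=1e-4)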
@@ -284,6 +316,7 @@ def do_shrink_kernel(
         False,
         cur_lora_ptr.dtype.element_ty,
         False,  # USE_GDC is always False in shrink kernel
+        base_k=pid_sk * BLOCK_K,
     )
     # GDC launch dependents hints the runtime system to launch dependent kernels.
     if USE_GDC:
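
The shrink path is where base_k matters: each split-K program starts at pid_sk * BLOCK_K and then strides by STEP_K = BLOCK_K * SPLIT_K, so the SPLIT_K programs tile the K axis exactly once between them. Continuing the sketch, with the cross-program reduction (which the real kernel performs separately) emulated by a plain sum:

# Four split-K programs, each seeded with its own base offset; summing their
# partial accumulators recovers the full product.
SPLIT_K, BLOCK_K = 4, 16
partials = [
    mm_k_sketch(A, B, BLOCK_K, SPLIT_K, base_k=pid_sk * BLOCK_K)
    for pid_sk in range(SPLIT_K)
]
assert np.allclose(sum(partials), A @ B, atol=1e-4)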