[LoRA] LoRA PDL improvement (#31660)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -163,15 +163,17 @@ def _fused_moe_lora_kernel(
|
||||
# accumulator
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
# GDC wait waits for ALL programs in the prior kernel to complete
|
||||
# before continuing.
|
||||
if USE_GDC and not IS_PRIMARY:
|
||||
tl.extra.cuda.gdc_wait()
|
||||
|
||||
for k in range(0, grid_k):
|
||||
k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K)
|
||||
# GDC wait waits for ALL programs in the prior kernel to complete
|
||||
# before continuing.
|
||||
# pre-fetch lora weight
|
||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
|
||||
if USE_GDC and not IS_PRIMARY:
|
||||
tl.extra.cuda.gdc_wait()
|
||||
a = tl.load(
|
||||
a_ptrs,
|
||||
mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
|
||||
|
||||
Reference in New Issue
Block a user