Using max_loras + 1 to construct grid in fused_moe_lora (#32277)
Signed-off-by: Yu Gong <yu3.gong@gmail.com>
This commit is contained in:
@@ -104,7 +104,10 @@ def _fused_moe_lora_kernel(
|
|||||||
if moe_enabled == 0:
|
if moe_enabled == 0:
|
||||||
# Early exit for the no moe lora case.
|
# Early exit for the no moe lora case.
|
||||||
return
|
return
|
||||||
max_loras = tl.num_programs(axis=2)
|
# The grid size on axis 2 is (max_loras + 1) to handle the no-lora case
|
||||||
|
# (lora_id == -1), but sorted_token_ids and expert_ids are allocated with
|
||||||
|
# shape (max_loras, ...). Use (num_programs - 1) for correct bounds checking.
|
||||||
|
max_loras = tl.num_programs(axis=2) - 1
|
||||||
grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
|
grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
|
||||||
|
|
||||||
# calculate pid_m,pid_n
|
# calculate pid_m,pid_n
|
||||||
@@ -255,7 +258,8 @@ def _fused_moe_lora_shrink(
|
|||||||
* triton.cdiv(EM, META["BLOCK_SIZE_M"])
|
* triton.cdiv(EM, META["BLOCK_SIZE_M"])
|
||||||
* triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
* triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
||||||
len(lora_a_stacked),
|
len(lora_a_stacked),
|
||||||
lora_a_stacked[0].shape[0],
|
## max_loras + 1 to handle the no-lora case (lora_id == -1)
|
||||||
|
lora_a_stacked[0].shape[0] + 1,
|
||||||
)
|
)
|
||||||
_fused_moe_lora_kernel[grid](
|
_fused_moe_lora_kernel[grid](
|
||||||
qcurr_hidden_states,
|
qcurr_hidden_states,
|
||||||
@@ -355,7 +359,8 @@ def _fused_moe_lora_expand(
|
|||||||
grid = lambda META: (
|
grid = lambda META: (
|
||||||
triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
|
||||||
len(lora_b_stacked),
|
len(lora_b_stacked),
|
||||||
lora_b_stacked[0].shape[0],
|
## max_loras + 1 to handle the no-lora case (lora_id == -1)
|
||||||
|
lora_b_stacked[0].shape[0] + 1,
|
||||||
)
|
)
|
||||||
_fused_moe_lora_kernel[grid](
|
_fused_moe_lora_kernel[grid](
|
||||||
a_intermediate_cache1,
|
a_intermediate_cache1,
|
||||||
|
|||||||
Reference in New Issue
Block a user