Using max_loras + 1 to construct grid in fused_moe_lora (#32277)

Signed-off-by: Yu Gong <yu3.gong@gmail.com>
This commit is contained in:
yugong333
2026-01-24 09:39:30 -08:00
committed by GitHub
parent 203d0bc0c2
commit d4dbb7af63

View File

@@ -104,7 +104,10 @@ def _fused_moe_lora_kernel(
if moe_enabled == 0:
    # Early exit for the no moe lora case.
    return
-max_loras = tl.num_programs(axis=2)
+# The grid size on axis 2 is (max_loras + 1) to handle the no-lora case
+# (lora_id == -1), but sorted_token_ids and expert_ids are allocated with
+# shape (max_loras, ...). Use (num_programs - 1) for correct bounds checking.
+max_loras = tl.num_programs(axis=2) - 1
grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
# calculate pid_m,pid_n
@@ -255,7 +258,8 @@ def _fused_moe_lora_shrink(
* triton.cdiv(EM, META["BLOCK_SIZE_M"])
* triton.cdiv(N, META["BLOCK_SIZE_N"]),
len(lora_a_stacked),
-lora_a_stacked[0].shape[0],
+## max_loras + 1 to handle the no-lora case (lora_id == -1)
+lora_a_stacked[0].shape[0] + 1,
)
_fused_moe_lora_kernel[grid](
qcurr_hidden_states,
@@ -355,7 +359,8 @@ def _fused_moe_lora_expand(
grid = lambda META: (
triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
len(lora_b_stacked),
-lora_b_stacked[0].shape[0],
+## max_loras + 1 to handle the no-lora case (lora_id == -1)
+lora_b_stacked[0].shape[0] + 1,
)
_fused_moe_lora_kernel[grid](
a_intermediate_cache1,