[Bugfix] Fix FusedMoE LoRA kernel offs_token out-of-bounds value (#32279)

Signed-off-by: Xin Yang <xyangx@amazon.com>
Author: Xin Yang
Date: 2026-01-23 17:41:35 -08:00
Committed by: GitHub
Parent: 7e1f10d562
Commit: ecc3dd66cc

@@ -139,7 +139,9 @@ def _fused_moe_lora_kernel(
     offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
     token_ind = stride_tl * lora_id + offs_token_id
     offs_token = tl.load(
-        sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0
+        sorted_token_ids_ptr + token_ind,
+        mask=token_ind < max_loras * stride_tl,
+        other=num_valid_tokens,
     )
     token_mask = offs_token < num_valid_tokens
@@ -185,7 +187,7 @@ def _fused_moe_lora_kernel(
         b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
     if MUL_ROUTED_WEIGHT:
-        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
+        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0)
         accumulator = accumulator * moe_weight[:, None]
     accumulator = accumulator.to(c_ptr.dtype.element_ty)
     # Write back the block of the output
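
For context (not part of the commit): the old call passed the mask and the default value to tl.load positionally, so lanes whose token_ind falls past the end of the per-LoRA slice of sorted_token_ids loaded other=0. Since 0 is a valid token index, token_mask stayed true for those lanes and they scattered into token 0's output. Loading other=num_valid_tokens instead makes the subsequent token_mask check disable them. Below is a minimal NumPy sketch of that masking logic; the sizes are made up and masked_load is a hypothetical stand-in for Triton's tl.load masking semantics.

import numpy as np

def masked_load(arr, idx, mask, other):
    # Hypothetical stand-in for tl.load(ptr + idx, mask=mask, other=other):
    # lanes where mask is False return `other` instead of reading memory.
    safe = np.clip(idx, 0, len(arr) - 1)
    return np.where(mask, arr[safe], other)

# Made-up sizes: 2 LoRA adapters, a per-LoRA id-table stride of 6, 5 real tokens.
num_valid_tokens = 5
max_loras, stride_tl = 2, 6
sorted_token_ids = np.array([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1])

# One program instance whose block straddles the end of the id table.
lora_id, pid_m, BLOCK_SIZE_M = 1, 1, 4
offs_token_id = pid_m * BLOCK_SIZE_M + np.arange(BLOCK_SIZE_M)
token_ind = stride_tl * lora_id + offs_token_id  # [10, 11, 12, 13]
in_bounds = token_ind < max_loras * stride_tl    # [True, True, False, False]

# Before the fix: out-of-bounds lanes load other=0, a *valid* token id, so the
# validity check cannot tell them apart from real tokens.
offs_token = masked_load(sorted_token_ids, token_ind, in_bounds, other=0)
print(offs_token < num_valid_tokens)  # [ True  True  True  True] -- wrongly active

# After the fix: out-of-bounds lanes load other=num_valid_tokens, which fails
# the validity check, so token_mask correctly switches them off.
offs_token = masked_load(sorted_token_ids, token_ind, in_bounds, other=num_valid_tokens)
print(offs_token < num_valid_tokens)  # [ True  True False False]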