[Bugfix] Fix FusedMoE LoRA kernel offs_token out of bound value (#32279)
Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
@@ -139,7 +139,9 @@ def _fused_moe_lora_kernel(
|
|||||||
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
|
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
|
||||||
token_ind = stride_tl * lora_id + offs_token_id
|
token_ind = stride_tl * lora_id + offs_token_id
|
||||||
offs_token = tl.load(
|
offs_token = tl.load(
|
||||||
sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0
|
sorted_token_ids_ptr + token_ind,
|
||||||
|
mask=token_ind < max_loras * stride_tl,
|
||||||
|
other=num_valid_tokens,
|
||||||
)
|
)
|
||||||
token_mask = offs_token < num_valid_tokens
|
token_mask = offs_token < num_valid_tokens
|
||||||
|
|
||||||
@@ -185,7 +187,7 @@ def _fused_moe_lora_kernel(
|
|||||||
b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
|
b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
|
||||||
|
|
||||||
if MUL_ROUTED_WEIGHT:
|
if MUL_ROUTED_WEIGHT:
|
||||||
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
|
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0)
|
||||||
accumulator = accumulator * moe_weight[:, None]
|
accumulator = accumulator * moe_weight[:, None]
|
||||||
accumulator = accumulator.to(c_ptr.dtype.element_ty)
|
accumulator = accumulator.to(c_ptr.dtype.element_ty)
|
||||||
# Write back the block of the output
|
# Write back the block of the output
|
||||||
|
|||||||
Reference in New Issue
Block a user