[Bugfix] Fix FusedMoE LoRA kernel offs_token out of bound value (#32279)
Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
@@ -139,7 +139,9 @@ def _fused_moe_lora_kernel(
|
||||
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
|
||||
token_ind = stride_tl * lora_id + offs_token_id
|
||||
offs_token = tl.load(
|
||||
sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0
|
||||
sorted_token_ids_ptr + token_ind,
|
||||
mask=token_ind < max_loras * stride_tl,
|
||||
other=num_valid_tokens,
|
||||
)
|
||||
token_mask = offs_token < num_valid_tokens
|
||||
|
||||
@@ -185,7 +187,7 @@ def _fused_moe_lora_kernel(
|
||||
b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
|
||||
|
||||
if MUL_ROUTED_WEIGHT:
|
||||
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
|
||||
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0)
|
||||
accumulator = accumulator * moe_weight[:, None]
|
||||
accumulator = accumulator.to(c_ptr.dtype.element_ty)
|
||||
# Write back the block of the output
|
||||
|
||||
Reference in New Issue
Block a user