diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index 771035691..9376b4e6d 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -139,7 +139,9 @@ def _fused_moe_lora_kernel( offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) token_ind = stride_tl * lora_id + offs_token_id offs_token = tl.load( - sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0 + sorted_token_ids_ptr + token_ind, + mask=token_ind < max_loras * stride_tl, + other=num_valid_tokens, ) token_mask = offs_token < num_valid_tokens @@ -185,7 +187,7 @@ def _fused_moe_lora_kernel( b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk if MUL_ROUTED_WEIGHT: - moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0) accumulator = accumulator * moe_weight[:, None] accumulator = accumulator.to(c_ptr.dtype.element_ty) # Write back the block of the output