Fix Fused MoE LoRA Triton kernel bug (#28450)
Signed-off-by: chaojun-zhang <chaojun.zhang@intel.com>
@@ -26,7 +26,7 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
     tensor_ptrs = []
     for lora_weight in lora_weights:
         tensor_ptrs.append(lora_weight.data_ptr())
-    ptr_tensor = torch.tensor(tensor_ptrs, device=device)
+    ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64)
 
     _LORA_PTR_DICT[key] = ptr_tensor
     return _LORA_PTR_DICT.get(key)
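The first hunk pins the pointer tensor's dtype. Without it, `torch.tensor` infers `int64` for a list of Python ints, so the raw addresses returned by `data_ptr()` were stored as signed values; the fix stores them as `torch.uint64`, matching how the kernel reinterprets them as pointers. A minimal sketch of the difference (illustrative only; the `weights` list is invented, and `torch.uint64` assumes a PyTorch build with unsigned 64-bit support, 2.3+):

```python
# Illustrative sketch, not from the patch: compare the inferred dtype for
# raw data_ptr() values against the explicit dtype used by the fix.
import torch

weights = [torch.empty(4, 4) for _ in range(2)]
ptrs = [w.data_ptr() for w in weights]  # Python ints holding addresses

inferred = torch.tensor(ptrs)                      # dtype inferred: torch.int64
explicit = torch.tensor(ptrs, dtype=torch.uint64)  # pinned, as in the fix
print(inferred.dtype, explicit.dtype)              # torch.int64 torch.uint64
```

Inside a Triton kernel, such integers are cast back to pointers (e.g. via `tl.load(...).to(tl.pointer_type(...))`), so the stored values must round-trip the full unsigned address range.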
@@ -85,6 +85,7 @@ def _fused_moe_lora_kernel(
     GROUP_SIZE_M: tl.constexpr,
     SPLIT_K: tl.constexpr,
     USE_GDC: tl.constexpr,
+    launch_pdl: tl.constexpr,
     IS_PRIMARY: tl.constexpr,
 ):
     pid = tl.program_id(axis=0)
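The second hunk threads a new `launch_pdl: tl.constexpr` parameter through the kernel signature, alongside the existing `USE_GDC` flag. As a toy illustration (the kernel below and the names `_toy_kernel`, `DOUBLE`, and `BLOCK` are invented, not vLLM code), a `tl.constexpr` argument is baked in when Triton specializes the kernel, so a flag like this selects a compiled code path rather than branching at runtime:

```python
# Toy example (not the vLLM kernel): a tl.constexpr flag is resolved at
# compile/specialization time, analogous to how launch_pdl gates a path.
import torch
import triton
import triton.language as tl

@triton.jit
def _toy_kernel(x_ptr, out_ptr, n, DOUBLE: tl.constexpr, BLOCK: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    if DOUBLE:  # dead-code-eliminated in specializations where DOUBLE=False
        x = x * 2
    tl.store(out_ptr + offs, x, mask=mask)

x = torch.arange(8, device="cuda", dtype=torch.float32)
out = torch.empty_like(x)
_toy_kernel[(1,)](x, out, x.numel(), DOUBLE=True, BLOCK=8)
```

A `tl.constexpr` parameter without a default must be supplied at every launch site, so callers of `_fused_moe_lora_kernel` need to pass `launch_pdl` as well.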