[lora/moe] Improve fused MoE‑LoRA kernel indexing and memory access (#32770)
Signed-off-by: 陈建华 <1647430658@qq.com> Signed-off-by: Yanwen Lin <lyw1124278064@gmail.com> Signed-off-by: kimheesu <wlskaka4@gmail.com> Signed-off-by: Divakar Verma <divakar.verma@amd.com> Signed-off-by: Robert Shaw <robshaw@redhat.com> Signed-off-by: ganyi <ygan@amd.com> Signed-off-by: whx-sjtu <2952154980@qq.com> Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com> Signed-off-by: Nick Hill <nickhill123@gmail.com> Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com> Signed-off-by: Yanan Cao <gmagogsfm@gmail.com> Signed-off-by: Xin Yang <xyangx@amazon.com> Signed-off-by: yewentao256 <zhyanwentao@126.com> Signed-off-by: Matthew Wong <Matthew.Wong2@amd.com> Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com> Signed-off-by: Ifta Khairul Alam Adil <ikaadil007@gmail.com> Signed-off-by: Ifta khairul Alam Adil <25082512+ikaadil@users.noreply.github.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com> Signed-off-by: Huy Do <huydhn@gmail.com> Signed-off-by: Micah Williamson <micah.williamson@amd.com> Signed-off-by: Andreas Karatzas <akaratza@amd.com> Signed-off-by: Kebe <mail@kebe7jun.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Alex Sun <alex.s@amd.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: Liran Schour <lirans@il.ibm.com> Signed-off-by: liranschour <liranschour@users.noreply.github.com> Signed-off-by: wang.yuqi <yuqi.wang@daocloud.io> Signed-off-by: NickLucche <nlucches@redhat.com> Signed-off-by: Shengqi Chen <harry-chen@outlook.com> Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com> Signed-off-by: Or Ozeri <oro@il.ibm.com> Signed-off-by: Lucas Kabela <lucaskabela@meta.com> Signed-off-by: Richard Zou <zou3519@gmail.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Signed-off-by: Max de Bayser <maxdebayser@gmail.com> Signed-off-by: AuYang <459461160@qq.com> 
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com> Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Signed-off-by: rickychen-infinirc <ricky.chen@infinirc.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Fadi Arafeh <fadi.arafeh@arm.com> Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com> Signed-off-by: eldarkurtic <8884008+eldarkurtic@users.noreply.github.com> Signed-off-by: Bill Nell <bnell@redhat.com> Signed-off-by: RishabhSaini <rishabhsaini01@gmail.com> Signed-off-by: Luka Govedič <lgovedic@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: Karan Bansal <karanb192@gmail.com> Signed-off-by: jiang1.li <jiang1.li@intel.com> Signed-off-by: Li, Jiang <bigpyj64@gmail.com> Signed-off-by: wang.yuqi <noooop@126.com> Signed-off-by: Tianshu Yu <tianshuyu.formal@gmail.com> Signed-off-by: raushan <raushan@huggingface.co> Signed-off-by: baonudesifeizhai <baonudesifeizhai@gmail.com> Signed-off-by: Mark McLoughlin <markmc@redhat.com> Signed-off-by: sangbumlikeagod <oironese@naver.com> Signed-off-by: sangbumlikeagod <98077576+sangbumlikeagod@users.noreply.github.com> Signed-off-by: Matteo Fari <matteofari06@gmail.com> Signed-off-by: huanghaoyan.hhy <huanghaoyan.hhy@alibaba-inc.com> Signed-off-by: Chen Zhang <zhangch99@outlook.com> Signed-off-by: Orion Reblitz-Richardson <orionr@meta.com> Signed-off-by: Orion Reblitz-Richardson <orionr@gmail.com> Signed-off-by: marksverdhei <marksverdhei@hotmail.com> Signed-off-by: Markus / Mark <46672778+marksverdhei@users.noreply.github.com> Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: Randall Smith <ransmith@amd.com> Signed-off-by: jon <joninco@bullpoint.org> Signed-off-by: dolpm <34420038+dolpm@users.noreply.github.com> Signed-off-by: ElizaWszola <ewszola@redhat.com> Signed-off-by: Luka Govedič <luka.govedic@gmail.com> Signed-off-by: Joe Runde <Joseph.Runde@ibm.com> Signed-off-by: mohammad najafi <mohammad.najafi@amd.com> 
Signed-off-by: Michael Goin <mgoin64@gmail.com> Signed-off-by: 7. Sun <jhao.sun@gmail.com> Signed-off-by: esmeetu <jasonailu87@gmail.com> Signed-off-by: Roger Wang <hey@rogerw.io> Signed-off-by: Reagan <reaganjlee@gmail.com> Signed-off-by: Reagan Lee <96998476+reaganjlee@users.noreply.github.com> Signed-off-by: Hongjian Zhang <zhanghongjian@xiaohongshu.com> Signed-off-by: Xingran Wang <wangxingran123456@outlook.com> Signed-off-by: Hiroken. <105287758+HirokenOvo@users.noreply.github.com> Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com> Signed-off-by: Tsai, Louie <louie.tsai@intel.com> Signed-off-by: Louie Tsai <louie.tsai@intel.com> Signed-off-by: Maryam Tahhan <mtahhan@redhat.com> Signed-off-by: Joshua Deng <joshuakdeng@gmail.com> Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com> Signed-off-by: LopezCastroRoberto <rocastro@redhat.com> Signed-off-by: JJJYmmm <92386084+JJJYmmm@users.noreply.github.com> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: cwazai <38356712+cwazai@users.noreply.github.com> Co-authored-by: Yanwen Lin <lyw1124278064@gmail.com> Co-authored-by: Kim Hee Su <wlskaka4@gmail.com> Co-authored-by: Divakar Verma <137818590+divakar-amd@users.noreply.github.com> Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com> Co-authored-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Pleaplusone <ygan@amd.com> Co-authored-by: whx <56632993+whx-sjtu@users.noreply.github.com> Co-authored-by: elvischenv <219235043+elvischenv@users.noreply.github.com> Co-authored-by: Nick Hill <nhill@redhat.com> Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu> Co-authored-by: danisereb <daserebrenik@nvidia.com> Co-authored-by: Yanan Cao <gmagogsfm@users.noreply.github.com> Co-authored-by: Xin Yang <105740670+xyang16@users.noreply.github.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Co-authored-by: Matt <156021403+mawong-amd@users.noreply.github.com> Co-authored-by: knlnguyen1802 
<knlnguyen1802@gmail.com> Co-authored-by: Lucain <lucainp@gmail.com> Co-authored-by: Ifta khairul Alam Adil <25082512+ikaadil@users.noreply.github.com> Co-authored-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Huy Do <huydhn@gmail.com> Co-authored-by: Micah Williamson <micah.williamson@amd.com> Co-authored-by: Andreas Karatzas <akaratza@amd.com> Co-authored-by: Matthew Wong <Matthew.Wong2@amd.com> Co-authored-by: Kebe <mail@kebe7jun.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk> Co-authored-by: Alex Sun <minchsun@amd.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: liranschour <liranschour@users.noreply.github.com> Co-authored-by: Or Ozeri <or@ozery.com> Co-authored-by: wang.yuqi <yuqi.wang@daocloud.io> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com> Co-authored-by: Shengqi Chen <harry-chen@outlook.com> Co-authored-by: Chauncey <chaunceyjiang@gmail.com> Co-authored-by: Or Ozeri <oro@il.ibm.com> Co-authored-by: Lucas Kabela <lucaskabela@meta.com> Co-authored-by: Richard Zou <zou3519@users.noreply.github.com> Co-authored-by: Maximilien de Bayser <maxdebayser@gmail.com> Co-authored-by: Xu Jinyang <72930776+AuYang261@users.noreply.github.com> Co-authored-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com> Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com> Co-authored-by: David Ramon Prados <davidramon3@hotmail.es> Co-authored-by: RickyChen / 陳昭儒 <ricky.chen@infinirc.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Fadi Arafeh <115173828+fadara01@users.noreply.github.com> Co-authored-by: Eldar Kurtić <8884008+eldarkurtic@users.noreply.github.com> Co-authored-by: bnellnm <49004751+bnellnm@users.noreply.github.com> Co-authored-by: Rishabh Saini <rishabhsaini01@gmail.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Michael Goin 
<mgoin64@gmail.com> Co-authored-by: Karan Bansal <karanb192@users.noreply.github.com> Co-authored-by: Li, Jiang <jiang1.li@intel.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: tianshu-Michael-yu <101950379+tianshu-Michael-yu@users.noreply.github.com> Co-authored-by: Raushan Turganbay <raushan@huggingface.co> Co-authored-by: baonudesifeizhai <85092850+baonudesifeizhai@users.noreply.github.com> Co-authored-by: Mark McLoughlin <markmc@redhat.com> Co-authored-by: sangbumlikeagod <98077576+sangbumlikeagod@users.noreply.github.com> Co-authored-by: Matteo Fari <matteofari06@gmail.com> Co-authored-by: Harry Huang <vastrockhuang162@gmail.com> Co-authored-by: Chen Zhang <zhangch99@outlook.com> Co-authored-by: Orion Reblitz-Richardson <orionr@gmail.com> Co-authored-by: Kevin H. Luu <khluu000@gmail.com> Co-authored-by: Markus / Mark <46672778+marksverdhei@users.noreply.github.com> Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com> Co-authored-by: rasmith <Randall.Smith@amd.com> Co-authored-by: Randall Smith <ransmith@amd.com> Co-authored-by: joninco <joninco@bullpoint.org> Co-authored-by: dolpm <34420038+dolpm@users.noreply.github.com> Co-authored-by: ElizaWszola <ewszola@redhat.com> Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com> Co-authored-by: Luka Govedič <luka.govedic@gmail.com> Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: Joe Runde <Joseph.Runde@ibm.com> Co-authored-by: monajafi-amd <mohammad.najafi@amd.com> Co-authored-by: ruizcrp <ruiz.crp@gmail.com> Co-authored-by: Shengqi Chen <i@harrychen.xyz> Co-authored-by: 7. Sun <jhao.sun@gmail.com> Co-authored-by: Roy Wang <jasonailu87@gmail.com> Co-authored-by: Roger Wang <hey@rogerw.io> Co-authored-by: Reagan Lee <96998476+reaganjlee@users.noreply.github.com> Co-authored-by: Hiroken. 
<105287758+HirokenOvo@users.noreply.github.com> Co-authored-by: Xingran Wang <wangxingran123456@outlook.com> Co-authored-by: david guan <102001211+Chenhao-Guan@users.noreply.github.com> Co-authored-by: Lukas Geiger <lukas.geiger94@gmail.com> Co-authored-by: Louie Tsai <louie.tsai@intel.com> Co-authored-by: Maryam Tahhan <mtahhan@redhat.com> Co-authored-by: Joshua Deng <91448271+joshuadeng@users.noreply.github.com> Co-authored-by: Nick Hill <nickhill123@gmail.com> Co-authored-by: TJian <tunjian.tan@embeddedllm.com> Co-authored-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com> Co-authored-by: JJJYmmm <92386084+JJJYmmm@users.noreply.github.com> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -62,6 +62,7 @@ def _fused_moe_lora_kernel(
|
||||
num_experts,
|
||||
lora_ids,
|
||||
adapter_enabled,
|
||||
max_loras, # upper bound for the lora_id guard; passed explicitly because grid axis 2 has max_loras + 1 programs
|
||||
# The stride variables represent how much to increase the ptr by when
|
||||
# moving by 1 element in a particular dimension. E.g. `stride_am` is
|
||||
# how much to increase `a_ptr` by to get the element one row down
|
||||
@@ -83,6 +84,7 @@ def _fused_moe_lora_kernel(
|
||||
num_slice_c: tl.constexpr,
|
||||
top_k: tl.constexpr,
|
||||
MUL_ROUTED_WEIGHT: tl.constexpr,
|
||||
USE_B_L2_CACHE: tl.constexpr, # when set, B is loaded with cache_modifier=".ca"
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
@@ -104,10 +106,13 @@ def _fused_moe_lora_kernel(
|
||||
if moe_enabled == 0:
|
||||
# Early exit for the no moe lora case.
|
||||
return
|
||||
# The grid size on axis 2 is (max_loras + 1) to handle the no-lora case
|
||||
# (lora_id == -1), but sorted_token_ids and expert_ids are allocated with
|
||||
# shape (max_loras, ...). Use (num_programs - 1) for correct bounds checking.
|
||||
max_loras = tl.num_programs(axis=2) - 1
|
||||
# The grid's axis-2 dimension is max_loras + 1 to accommodate the -1 sentinel.
|
||||
# This guard ensures we don't access sorted_token_ids / expert_ids /
|
||||
# num_tokens_post_padded beyond their allocated bounds if an invalid
|
||||
# lora_id somehow appears. The caller is expected to pass the correct
|
||||
# max_loras; this check is a defensive guard against out-of-bounds access.
|
||||
if lora_id >= max_loras:
|
||||
return
|
||||
grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
|
||||
|
||||
# calculate pid_m,pid_n
|
||||
@@ -136,10 +141,11 @@ def _fused_moe_lora_kernel(
|
||||
cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
|
||||
cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size
|
||||
|
||||
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
|
||||
# No modulo wrap-around here: columns with offs_bn >= N are masked out when B is loaded.
|
||||
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int32)
|
||||
offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
|
||||
|
||||
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
|
||||
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int32)
|
||||
token_ind = stride_tl * lora_id + offs_token_id
|
||||
offs_token = tl.load(
|
||||
sorted_token_ids_ptr + token_ind,
|
||||
@@ -176,7 +182,13 @@ def _fused_moe_lora_kernel(
|
||||
# GDC wait waits for ALL programs in the prior kernel to complete
|
||||
# before continuing.
|
||||
# pre-fetch lora weight
|
||||
b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
|
||||
# Mask out columns with offs_bn >= N; load B with the ".ca" cache modifier when USE_B_L2_CACHE is set.
|
||||
b_mask = (offs_k[:, None] < k_remaining) & (offs_bn[None, :] < N)
|
||||
if USE_B_L2_CACHE:
|
||||
b = tl.load(b_ptrs, mask=b_mask, other=0.0, cache_modifier=".ca")
|
||||
else:
|
||||
b = tl.load(b_ptrs, mask=b_mask, other=0.0)
|
||||
|
||||
if USE_GDC and not IS_PRIMARY:
|
||||
tl.extra.cuda.gdc_wait()
|
||||
a = tl.load(
|
||||
@@ -276,6 +288,7 @@ def _fused_moe_lora_shrink(
|
||||
num_experts,
|
||||
lora_ids,
|
||||
adapter_enabled,
|
||||
lora_a_stacked[0].shape[0],
|
||||
qcurr_hidden_states.stride(0),
|
||||
qcurr_hidden_states.stride(1),
|
||||
w1_lora_a_stacked.stride(0),
|
||||
@@ -292,6 +305,7 @@ def _fused_moe_lora_shrink(
|
||||
num_slice_c=num_slices,
|
||||
top_k=1 if mul_routed_weight else top_k_num,
|
||||
MUL_ROUTED_WEIGHT=False,
|
||||
USE_B_L2_CACHE=True, # load B with cache_modifier=".ca"
|
||||
IS_PRIMARY=True,
|
||||
**shrink_config,
|
||||
)
|
||||
@@ -377,6 +391,7 @@ def _fused_moe_lora_expand(
|
||||
num_experts,
|
||||
lora_ids,
|
||||
adapter_enabled,
|
||||
lora_b_stacked[0].shape[0],
|
||||
a_intermediate_cache1.stride(0),
|
||||
a_intermediate_cache1.stride(1),
|
||||
w1_lora_b_stacked.stride(0),
|
||||
@@ -393,6 +408,7 @@ def _fused_moe_lora_expand(
|
||||
num_slice_c=num_slices,
|
||||
top_k=1,
|
||||
MUL_ROUTED_WEIGHT=mul_routed_weight,
|
||||
USE_B_L2_CACHE=True, # load B with cache_modifier=".ca"
|
||||
IS_PRIMARY=False,
|
||||
**expand_config,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user