From bf13665dbe35be9f08396160f6a3c126ae654b20 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 13 May 2026 22:36:58 +0000 Subject: [PATCH] Implement TileLang NVFP4 mega_moe L1/L2 kernels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - nvfp4_mega_moe_l1: L1 GEMM (gate_up_proj) with FP4 dequant → BF16 GEMM - nvfp4_mega_moe_l2: L2 GEMM (down_proj) with FP4 dequant → BF16 GEMM - nvfp4_dequant.py: E2M1 packed → BF16 with UE4M3 block16 scales - tilelang_kernels.py: Grouped expert GEMM with TileLang-compiled BF16 GEMM - Full pipeline: L1 GEMM → SiLU+Mul → re-quantize → L2 GEMM → output - MEGA_MOE_STATIC=1 bypass still works for pipeline testing Current approach: dequantize FP4→BF16 then run BF16 GEMM via TileLang T.gemm (auto-lowers to tcgen05 on Blackwell). Will be upgraded to native FP4 block-scaled MMA (tcgen05.mma kind::mxf8f6f4.block_scale) once TileLang adds E2M1+UE4M3 support. --- .gitignore | 3 + src/nvfp4_megamoe_kernel.egg-info/PKG-INFO | 7 + src/nvfp4_megamoe_kernel.egg-info/SOURCES.txt | 11 ++ .../dependency_links.txt | 1 + .../requires.txt | 2 + .../top_level.txt | 1 + .../__pycache__/__init__.cpython-312.pyc | Bin 0 -> 754 bytes .../__pycache__/nvfp4_dequant.cpython-312.pyc | Bin 0 -> 3883 bytes .../nvfp4_mega_moe.cpython-312.pyc | Bin 0 -> 7430 bytes .../__pycache__/symm_buffer.cpython-312.pyc | Bin 0 -> 3431 bytes .../tilelang_kernels.cpython-312.pyc | Bin 0 -> 6192 bytes .../weight_transform.cpython-312.pyc | Bin 0 -> 6637 bytes src/nvfp4_megamoe_kernel/nvfp4_dequant.py | 71 +++++++++ src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py | 100 ++++++++++--- src/nvfp4_megamoe_kernel/tilelang_kernels.py | 136 ++++++++++++++++++ 15 files changed, 311 insertions(+), 21 deletions(-) create mode 100644 .gitignore create mode 100644 src/nvfp4_megamoe_kernel.egg-info/PKG-INFO create mode 100644 src/nvfp4_megamoe_kernel.egg-info/SOURCES.txt create mode 100644 src/nvfp4_megamoe_kernel.egg-info/dependency_links.txt create mode 100644 src/nvfp4_megamoe_kernel.egg-info/requires.txt create mode 100644 src/nvfp4_megamoe_kernel.egg-info/top_level.txt create mode 100644 src/nvfp4_megamoe_kernel/__pycache__/__init__.cpython-312.pyc create mode 100644 src/nvfp4_megamoe_kernel/__pycache__/nvfp4_dequant.cpython-312.pyc create mode 100644 src/nvfp4_megamoe_kernel/__pycache__/nvfp4_mega_moe.cpython-312.pyc create mode 100644 src/nvfp4_megamoe_kernel/__pycache__/symm_buffer.cpython-312.pyc create mode 100644 src/nvfp4_megamoe_kernel/__pycache__/tilelang_kernels.cpython-312.pyc create mode 100644 src/nvfp4_megamoe_kernel/__pycache__/weight_transform.cpython-312.pyc create mode 100644 src/nvfp4_megamoe_kernel/nvfp4_dequant.py create mode 100644 src/nvfp4_megamoe_kernel/tilelang_kernels.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..9f7983e8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +__pycache__/ +*.pyc +*.egg-info/ diff --git a/src/nvfp4_megamoe_kernel.egg-info/PKG-INFO b/src/nvfp4_megamoe_kernel.egg-info/PKG-INFO new file mode 100644 index 00000000..59be8beb --- /dev/null +++ b/src/nvfp4_megamoe_kernel.egg-info/PKG-INFO @@ -0,0 +1,7 @@ +Metadata-Version: 2.4 +Name: nvfp4-megamoe-kernel +Version: 0.1.0 +Summary: NVFP4 Mega MoE kernel for DeepSeek-V4-Pro on Blackwell (TileLang) +Requires-Python: >=3.10 +Requires-Dist: torch>=2.5 +Requires-Dist: tilelang>=0.1 diff --git a/src/nvfp4_megamoe_kernel.egg-info/SOURCES.txt b/src/nvfp4_megamoe_kernel.egg-info/SOURCES.txt new file mode 100644 index 00000000..f3641fa0 --- /dev/null +++ b/src/nvfp4_megamoe_kernel.egg-info/SOURCES.txt @@ -0,0 +1,11 @@ +README.md +pyproject.toml +src/nvfp4_megamoe_kernel/__init__.py +src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +src/nvfp4_megamoe_kernel/symm_buffer.py +src/nvfp4_megamoe_kernel/weight_transform.py +src/nvfp4_megamoe_kernel.egg-info/PKG-INFO +src/nvfp4_megamoe_kernel.egg-info/SOURCES.txt +src/nvfp4_megamoe_kernel.egg-info/dependency_links.txt +src/nvfp4_megamoe_kernel.egg-info/requires.txt +src/nvfp4_megamoe_kernel.egg-info/top_level.txt \ No newline at end of file diff --git a/src/nvfp4_megamoe_kernel.egg-info/dependency_links.txt b/src/nvfp4_megamoe_kernel.egg-info/dependency_links.txt new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src/nvfp4_megamoe_kernel.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/src/nvfp4_megamoe_kernel.egg-info/requires.txt b/src/nvfp4_megamoe_kernel.egg-info/requires.txt new file mode 100644 index 00000000..06d97fde --- /dev/null +++ b/src/nvfp4_megamoe_kernel.egg-info/requires.txt @@ -0,0 +1,2 @@ +torch>=2.5 +tilelang>=0.1 diff --git a/src/nvfp4_megamoe_kernel.egg-info/top_level.txt b/src/nvfp4_megamoe_kernel.egg-info/top_level.txt new file mode 100644 index 00000000..0c0c2376 --- /dev/null +++ b/src/nvfp4_megamoe_kernel.egg-info/top_level.txt @@ -0,0 +1 @@ +nvfp4_megamoe_kernel diff --git a/src/nvfp4_megamoe_kernel/__pycache__/__init__.cpython-312.pyc b/src/nvfp4_megamoe_kernel/__pycache__/__init__.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71a6b920f4b71f6c57dd88ccb159f361534f4cac GIT binary patch literal 754 zcmaJ_0u>|q+4D3r@6CEY4~HJo@%F+#|3nb_W|HgBc^jKU z{d$iQ6ruztgkwQMBCOC7c4!MHbgV8Clh)ESwxTI6Z#BP0i1q){f@rKIXsVwjOL=`AfTwDI+q1arJ{!brsIw@r*W(3{jTOl<91K>dwt~^ zRdZDudTYatg*%7FUc@OT@PS#Wr)_Vq}tpXxxvg@rNxf?i(Q4kjC<;pZ;FAGJYcIg5D literal 0 HcmV?d00001 diff --git a/src/nvfp4_megamoe_kernel/__pycache__/nvfp4_dequant.cpython-312.pyc b/src/nvfp4_megamoe_kernel/__pycache__/nvfp4_dequant.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f38c12b5b6d37d57060fc56f46612407e699dd18 GIT binary patch literal 3883 zcmb_eZ%i9U7T;alYkSwm-~?R@0VZj3hk=;b5K3rsJ;v?}-M`jPu`I_d5MTURn#wNiBViEo^^Pt{4iH@kKq zK!cQYq}`b}Gw;ot_ujnU`-j`jQP95qpEY^YNm2hMM7^=K!EPA_3zSIRp+s6_QuG~K zv|u`Et#91xMeEV?QLAYCz4eae2l?2e`F>CyE80Px^&aDMW%?cSc zkrUK}ltEcFkxHlutoS+ZijLL=7$1AcCsTyI=wO;(+iv6_`L5qj^|6tR>AzTFE9u~eQ)?q98KRAJ3* zM_KMsjiW&4X_3wWb&QOyIMIrV?8lE$OI_ z_WC@Nx8$vPTMT0EGxkx<>IIge@@&$G_PxDi6eIp}Jv~P#(JtCrs6Wx7ZI1q&{Ueq1 z9NzP&HkE8TytdEZ-ovYPU%gHF9NBMS{dR@~7|KFW_66&4YDQ*fKbjU&SqRE0EZ27U zNPtFDlAsQTaDTclnt?D!(1zuhGGvShWd`jK>h$~lJxGa$CPOh<%1$Z1`bqyP&ZGX? z6R4kJJ`!>mPT@4psHW&!M&$34en=|nLjAs{)iPAs{vK|)O`NseNPaJo-}lfS(z-=V zq{n@n&Zv@ZQzbbP*KLLobQ>|opl(xCP)N7SScwZ$7-*y8A`(3ISXRm^x^+5%@9OTS z?BTP^Bu;cowlAnN;S2Bu;3MNcod?NSPGMmhi;5k#5o1C7?w7rBDUExjDV&Lu_D^1`^6eVmzR3?Pv#YKZ*V@}_;Y}0(dOi3RKTzd|HGa6l zU#aq=8bA8a=IV_J?Z!lfe}B{Z{w~~)(pfW+MxLZSn0%5bKLppjV4`I5Of9_;{~c)i z9^VluqW<^_7LZ{i%H;FQurk!XgmgrLA?YZ_h?dW+^<;EZx`8yrBpVJ~14z*(T8uPg zG1E}NmbZy4NuAs&)Gn-BM0jR7UVgN#LN=dDcB8~0MN5-@(lR}N1heA zy#4tlh+GaVnFHUMn&urMPew#z(ifb0yXcziq^P>KCS}}!%RBR&=LPhyoj^ao{slhI zCr=zOKaiq(e;gECdE1dbx4cX|{-PjzJf#&053X&K9nZ4n4$8HGq3k;IO|_f(8X78ep31qIEj%*3D}l#=eEOd=dk;mb7;LKk{aJQ0f@ zoDTR6x%;ZPS`9aZK$)k}$XqdS)WAU`MG@3|poD@WP%%#9I^>bEks!je@N}%V7cOet zWnW=wun|@>0oK76HU|a{fa8&dekRzC`B2z^xxj#N&Nit^wn-^2lhh;g@EhMIz!!s$ zw2j(e=Z$Uf`jhEj)TST5A(IR(lQik~@p3aENFI{O1DD<8N6wH*i6fIrMkaBs+YK!M zTAO4%g#|UVhp4)vuA;BOz+f2>4kkOOGbu@D;tAbhq9Yb{D~TQ536T?uCU6QSNZ{xe zhy&f3!?F}g3(6#*Ni-CN(|G7%uaFUemmX7yro`5JO>QHHov6Vt4yFSkRSJcpfdTm< ztek>R`4}t#2z-X(8fRKdt@F1R-kW=`;yewA$+Z#|;%mx@=nnqDryHP^QM ziq?3pcw@_Ta#ksgn}B>jyHr>#RQU77t00N@%)V0gE%hw+JjkvVRtlAti(gu_me(}z z>l$}?7`&}zVXl4ooppL;VqMVAzN$42!U<4z_-0!cJaZnX zEZi&ZN~2%n{02wE$(eUc@6Klz3Uh^u`&{w*miMKqw@dSOtp(SEo8GR9cc3`-*wHxC zP->W!HAhQvc+1;XjzCp0yc{gYH1AJg-Q}6>)?95>{D9r*27Xqx^H zW2LRTE{b;VupasYx_oTl9LUc^-D{m!}2^(MpoE56vtZ55usfWjst zFbPIr1xJQWu=MFjIN({6smVGM&a5lpVzJ$sac652wOLQX1MjYcS8ykMLT$n?coKDj zHxUqgiJ;(5gdB_*7TJZ<_H^av!r5}!E(vu);E}_8d&HX0MB{6Y5jflvD%ZzlVHz$gizKDUS#nJjbJvICaTsVwNzSJ=F{$MhalqG2#(N0143Ifq6S;hj z%PH~#iQrOK2Fv;*yISmkzPtFti z#XK(wus||#GtLdaKQ%l(J9DAm60C~C6m9wCRp>UG&WPi@G}lFvayADq0?}kL6Ysfr zkwmV5Xcw2oOeRXM-n=q80WzG?6cEVS`wzI_|cCOIcc7vJIL#B7$klaZ5)R89g( z<`qShw9G0|^SPX?Xex=nKQh&iV(urMWVk57$!Jf-F#Qr%m`A95aXucR)x_ zjEu{?V9$f69!X1DfApm(BYCL5)El){OW;i=>O8N&^0H!9R#xeNy)jTZds>twRUpvX zBVA1`E9qBWO-pmu)~NJskOXL!XNkZuJPV>X5|ClNlc$=w-wSybRO zWcxsfERpgl!f5toO3H99NsWNEfE3{l-VyD)gcWcXCD7<4k=_Hkq;Bg(@+9&sl<(U|Wk0@Fm)VKn9? zJata&&GvAb%-u=#TsFd1WnS#dUgGj_4h|zsPhJ#Et%Iv?v{%<~rv7khD$R5-Oq!`< zn5PV2;8PFG_7q|IX&5s=DaHuJhp!KD@yTIsW_D`?@ z*TC=jpTLNljLt0BEfr6JJppg)wXbM(&{mJ`5o*?4(9f$ojvdA9=hG}4XBeebgZ^d` zoO<>f1LhOYBMa|YMrRjnt7Lx$PYvZj(*o|MDlJA}9={f_t98~M zPhxd;-LHj8EsXBBZOnR|Uh_XNS53uSHR}PbzBE^z?$86eQ+GYGB(NUTgELIjlmA=T z?(wirjm;~Ny?o{18Un~6FifR$V1Ohr6!-=(QE{uwfPd6XymXDhbeQ&n4ssM4Oh|jJ z0RRZ5M5WuZl9sMl-~qk2fQA~Pi3OSFZe}O1O%9N0(^&rZ?sxfi+sI8fVIlqRyXAB# zhM!k%%@mdYq@0^Zdv3qJs$vkEsByl_@ai4VKy=N!fP9_!G~VM z1xN}sAqbxuuABlmY`6q%H76P_O$HNhxJ{&1@J=guKNVDyXvlD_aEKDl6~F_-m+9dw z85v$En+IihG(r3KZfqb0jbla$YLrsP!)rsvSw8a}Kw*K4@PxOqO4GT1SLoiIE& z2~5{-xK~UF^DLKY;D}HhczaMF73sx|3-fYTyddX9DVgDyF9_n&1uCBIth^v!P?ep>^}*#>uVsHr^=waWT+v z|JJ=*TlM#D7emzQRz^)9~2ScHn8-sz30YEDZg+w)u%Cupd6Ry|$g%J+sU2 z&g|9f4eece#;|{MqIjZvzb*Exh7Dc(iebGy&pgce*r)9e+y5^A6aB~fPor1A486JS z{4#X*ADap{o&?Bt^N!~*&{G(B=4LMRe=0we|GxQ)j?X)O8jE93@0X$T`!nCcu3y_q zZ7=Lb_HKNZISkAcMtHaiUqK=^Kw5ZddI5OXC;%yF_Z2iMM**m}1VBmy_GrYc3A1L%q9NrZ zmVrz#03p2<5YlTwhyumoIX! zauMop<44v-ZvQ`jk(DmkPSRp4P&AnpWmj*&3;59JKQmc&on;vZuo<0sWc3EiU$gdC z?^#sew1*WY8rU$U^=OTN)+4tyod|07OD5W~`;;J;yh`*rXO*-DG+Nd8Ikc;m_eq`0 z9IqwkMfHN!W??aaZ-t?YvPLu73u+~}B(G{nYQYWxgZ=taCX$?rslW13J&D2Ld`2!vlCe}w7WiW*;)V$msuz=ReKzUd1Mt?X4S2;-+^_md)Ixs zgO>c1t0I-CO2DGs?aX=rm@MK%RkB)x(25vQm8{kfw4%JLlGPfn*!OtFzH9Zsar__l z>d}L*Xz}WySG4$a?<-pTy8rl=Pka>;J@#wMqp8&}lF^yUQw z;HsoMe-yHNuQzK5%rxoEb}e07%X+JZP;Iug=&ib6udCW=8|WT_cvaHd?2%CWqT&31 zfcnvpS2;)}E6cos0SHy45wQ~X2oSlVu5AoE!{x6Ia9M~oG&KrwF$zR9MrKF6jthxF z@G3*p^c*<%5Cp?8i2&5xhei>hZfSfx4klI1siL^ZE%gDC!NOqZ0n?4ZelM>hvf4$; zWE#pPtJp>3R-75P>Y~%&5XYo4SaCpmv$?@oEY?L*h?cOod3=q&!ZDi`A$~(4PpAmR znbAKCn_1ICh$N#x*HvOiX9RKw!cs?IJxbRIj@t1~uDPM0rUuu1up`s8!#Y%0`UGViq%`p9e?kIlT_dEk_{j1r(PpIuCNbRLg{70C z(grUIT1M)9Du{oE3m3*2?25?=ZU1q_WPs>#8Ls4fURqQT3Jr%O8NO6rf|#)^@fpLZ zrZaiyfl5#9H05?sA?H@P=m~>eHCiC%Dya}J%R;QHe9@}vOP9{@S7^NvE-RjSvkY(P zR+-dv0QZACGn_cM;es6eV<%o?cmZ0QDs$c{a!WmDQ+Ev~hS3d2K4*BXIYq&6!(ixh zQx7}EL02gD;u9TUl*ln)U{;xArtzpo(4$pu!gqfLKXnf%Ao#I;Ohe;l>qhHV{YE?b zvBCTA-+Lc?*t)>|>-Vk~hJRh#^3VPgPwMKyM{VvXcAVRKcU##x_37IW-`+j-@cd54 z?t9>}Hui(d8u-R9J>L7Z_iF##3obA?rH#>##tPT&o-H;+%)&Dt6KK01zZc)?+#1-g zJ69Mk`dc3RPagPBZjEgv_WhB*1}P_+HCr&M$nQ`}TBL<~ZA4Y(M?O*vGN`6VaWPUFT1HKlbhD zdzU|Z>x=WBpZ}}_ee6jVo$M(VyzCD0!-0@*#QW4jFp;Y)?lVr}bw?U`cJ*ET0NCif%#4<-+rzH!(4m79Upoj7c}bl_|*jNBcC ziu%(SJMeDxZ8yI>x6A#f*FU;`cl2k@7RouJj)@V_@>-4)KHdNbai%e+Ni;R zRIMm2r4?B+YUV_(0x=K{=mj^gI_mL5h6;je`wP#z;0)w5;vXs^G@*Z^{taYs)-3xQ zM=R@m-pjDQUodt5#?1Z;a~sst7c97op~ga8F<4*lgT@!UMSoqv15HgW`%Ji4-%tnw z*KcfjaA$L2V_{R;koM1B-m5umdH1kkun_ureba;K&BR7x^Y+H={kQsdlZVX%hxLC@ k2>!O|Ewk1%9bokv~V?2$CO1Xa)hK+s~}1(tAQmYaA3n>(HzcqOX)=3 zk=;8gBFmtD7^qS75xy8BDXhGrLVfs|ypn%E6r4iDHUUyJAOC1bL4mwFd&d(ACaKju z-0sZI?Ck99%LwB>1E7%kanG zk7GF(HwNzV*;pl%8m#`~&5C7VmzcV8yPV5mqU4H1S-ZJn+IdAMMaM}K4M43>NE2+7 zb*z}SV&Qd7wTd_+%hz$Kpy}9A42(egJQ zR8dN7c}9NB?6O7&SC^I-)z#OR)GJG`-dLvG^SW;^zwcNN9gYItKp4UOAa*y|uo1e8 zvV0|+;;Gy{iS{R0(luDeN`YZmQ{tU&X+)gL-= zcS|^|i%AYFtT5ra3+KOG>7L&^@VqA}tz@?6701=qfk)F_bIaRHxp?)BIWUZ#$42Uy zQFU@Z#Ovr<0jz)V#?svC3;m1SVkxhh25bUQI`1h&T+q-k!2tr%9@jhAT;Fh=fRhx< zu4%6;o#}wpx7@r^whc@Q6~^(a8&?)Po)~bg!|Kw!QX)mz#Zf%eqv?^J+hjo(hX1ex zeFN3F8d8tRwGd#zb)#AsuucdeRO10-E4@||0OOA_5ilXZ+#s;gNrQE-v-I6aO)$7B z?D?)#6FX4U9S9_E4hK+o1T6(4#Mboed%8;xpjzZ_;k}N=pe-aLzdcpD)FXE+Sk)c- zJ+2QNeI4B}_#FZ$WCP4*jN^(0;3IFH}Wk8xqh$Rc4vQ%P9-hK_Bj< zh=r=uNdpV1GgR<`p2qC5rTRJ6p@WvTt+ENXn8&t51rVj?X>7wZ3~Z~8S;aKrWlswm zrsiUQ3fPFgL3zDwXmn676M}75^%bQ8wn}b=^04e@l-s6a8K4&~Q=X|y1s{CQn@(wQK2+Jvi1s7oj0Ab-ko?#tx zV8WuA4m&wDuX-kMRR4H1(K`oXF>zQy!k!hseYZ5TQMB+(v4m~CpuIB#N1X9q+_Y8P zEY3JY_fhWwhPPbJV`AgNjMIIk;YFpAs+zXxs%rIA_YyO~B=cRz1z)&>z8)FB$NyPw zO@O{R8l?w^}$O5?3$ z`r(EBeqm& znanVT>H74}7BGx3hA1$coqQ0xA7fIM82wWH(oPQOXBm9}=;zKqxOSfj@agG#dRGMc zIiP<5YQ)i*WFueCdulKNnR;e-6d2Ah0kLmo6d(UGo_O!|ohu(-{pjk)D<7@wy!>$V z6aKU4r_s-bJ{@Yl^y@a_UgIWPljk2;_pO7;+5O4c=DEucGoPP6SiHW!c>Q4U*8bwH z<|}UjW0p&{k_uDlAer7zrkfL)wh%h?{I@7HFxw8pZiirew3ZUs9od?NRh7yrlx0S_ zz~H#5{;sSQI+Td28bw`I2@CSSBJvyz8X(&6ZoFzYjVJU`~8 zR3sRxGh1g|L6ZRR-eZ=68i@Ib-)ih_4#Dfj*Mz;!iZ@Y;tNwAZP$)uugcVI;PWZj5 zlofqzH(50k)QpPnvpwZhZw=X3k)EUkWwDduJd>ov{~enS)7_6X&yMu`(>EY9ux1&4 zM}h7RIvPXq(UvrD@AvPo?Jl*XQ_YbJhe9|mwozCRJ=3S+ggq`aRCEZd5@^CpU>U=7 zplZKehJujBdJ&!>G|W;>1lwCCDcUp)s(DGszig%jVGmN*3%-+n={wHL<%f()+uJ))Kr?ZAds4=|4^w?+5VU+?Vni5Ms#Bm>X0V>a|u$XYTEX^cRt6Y zW^~eC%kTI5d!P4tpZEE=KRO&X1nJNJG>v>xh0xz{Lr+w$vKA!}nnNraL@dD?Lc}1U zmWDxtS{lQqL6X2@#*jH|8MK70gH{3s(O?N%%9^ecgEp4rh!K|#yk4^AXY^a1(Qkc5 zf5|iYZN>Zj6l=d~8l?Vbx~SrK?EjB02hio1G`O9yPYnp!27)23hv5Ts-?{G7t#mIJ zVCddR7k!EocrGN;UnM@EJG-0q)5p7dd+G6@G)&`pEO#Zw@KSJ!k%AH4W3vqm2Sqv@ zVPheVjtY^nAj^rgG|bVyh`?!!M{|(=h@g%f(kvMNG>IJ+JcRT(1iDSpQ`o4Par_C>)=o{#Iv$yNmiH?D;y(&TDTre;!iS!=& zY*%Y<3oZJX5GP`wW_mXp8RvDB!li(vf&Nzv+tHXHaJ)n_QJ{qJ4Y$+B^y71QbAZ7t z+KW%c(?X0d;MiD@p>-a6G%f+7fGGm!14N#s{W0E$naG4Z^c%d7%TK2nmKA9pI6TJD zk}tsVuk7=L!;Eb-$g}P3;R%16f4{%gGZc#WM!o8RX)zX!Mg)n*Jnm}lZQ7%9V=q0< z?Fxlx5$?}7OiK}44i#Uey1YSM&eFZT9qQUJGZ75O!f+xX5(^ARW0HEuaJQV$z)e6d zX_p!k;hq_KyT}E?z$1>`Uce?1lOjb%IbNiPCTVTmUfU1@RB)^Y08qIBIwl4|5Or=f z_#)vbNGGdtGbcSB+avtE``iYlGR8+4I62oGZt_YI?~uP~zfzX#jB%~u7H_Pj*=Jc- z2M8-8BmCFi2bFUTT_dg;t{JbHu94Twb0~#UM9Ppdrc5a^W%gU*7TF@@7mp#KWW6O@ z>R_ZFNf=E{BPPAJUa~}hxNW!vd*Q?yr93SoR;{lvhy4wHgJjnGMpV_Qx3uyWyjh%B z({r#4DdKx$k11pTtLAGs}4 zU7fLDh*cki51tqYGNlVb!E4flK$TFeTLSPhc$y7@%N&w&H-ykgj=wOPPX(hz?l@&nmFl=h*@z=>|VLJTNo%~4P+J`Ub*a7>|?5ZEKHILrtft5}Epp$H>^VlLJC{Q{#J z3dN!hv@|QEFT^o|QgS*Nr-YWiJ=r)M33H8+D98Ik z%y=WqjWzOP{%C7MI5HAx6a}Bw^oGG<3P(7vrkNTgFjygQr*du40|zO=hRKUhGqv3t zW=1*HQmGxDLn^)E5l?>3iYP5GI{yJP7n}{w6Y~d_U3+ey`Q*aQ z3yW9o>{@Q<{PO6E<8&5LC9h@%E-jsXYw5zp`!%(5@65h4KfY*L+S&H`wl7}(?B(T} zuH@^_9&pv%bM4Hyb}m>JYnFB${JiOlgP$E-c6B9>KcpPf$5Y2=j;~Pl535~s7iTa2 zy8JhFAJ;83Ed-Xg?Z4Z;Qr-6$86B-N#Qn1JIrFS}=1RIE9k@|He`Vpo$3IxAYtGcP zES_H}JF-4ro__1b=J_EQ4;2}on6F(p@TanttOa17*pQ=ix;xc9b0FQ39-2RJBec-D zSaY*yX=_`i{@@DL{!Ja6{u|7KuW|OR=l~Y}&TwpBH!}VKk=-R0w^?!IZp5ov7saYy zlwuqSN`ezsDCPk$BjA`})8H-p5 zTG)vPog402Xn@|Wc=B%{`&R-zuv1B9rhMjj`mOoE!qk$jJ#lo^V!IyBST^5x)~EY3 z&TUK9Z4XJuwM*|_N;joB8S@xX2 z?S&_yP@1#tRedgCZ6rDh##dA*wU(Qym<3LX3A_)qX+B|p2y3{t--895Loyo453N^J z>qy4UGEsCC*2gVzYrI4*VF|2h<2KnQn@4v8ryVB|Ymm(!g6}2UK19DV>7K}s2-#Q+ zjWuCCAGZS~MLdBxCFQ9fDPo~z68xLIH7aJGW$HG0D>ozV2%>nYWYdfd#Ho zYt;q(EMoIUidfDO;N8X~SIgx(&*CoGBvf@kXXG=C^i)x ztPBh}v^snk0A`2^IpGy!Ls;eF3s5XMc8j$@neOUrEH$fDmJ?akJ;B;tI16iDgTIJl z<^;-aMRhNHMBSh=rs~AW8FIC9>%GbsGnFsC+mq->TuJU+wU*A5r)-IiRXa69zTduN zI+5NoADa&>HFVu&?)ER$oLDxUNVcz1RnvW`zI0QD+LGv6byQAYPF+rPKXg<~Urb$0 zZ<;@IoBM>n$=^M<;^;|qXRWBpmFT(esF-Q`$@#~KAm1R8#QjYT3!`_o-Q@3tzHIs| zvb?D;NzI&j-{wZ)BApQ=`kuT*`MrMt=sYoX1+Q-d}C|!vn5k}DBp1h)tB)kR zVP?&a`NNAZE$ujvakbrZb!J?hcbk`8-O1ytRK?8UWop~}xrIRHrF|J{-wL(u>n%W2 zwi1q!ZAMf@qF3<1ohXi+>BS#HIJUTL-47EBV_Z4^MFOHkEca>xjZ1+2GRme~IU5K1 zjF{@~XTlJSi@GYy%LP3YlW6T@i%7?UywuW6Yg9b-1&;Oy1yS<24Fa}$sxZ0D1+V5D z?k#vVx%2p=Z6a0+nviOuxoM+ldbRz$qF=ZGbAJeb5qoXmIG3S{?eh)G<*y`8-nVZ~ zch8?%+;r#sU3u9)uw)ugdFM6>I8a^B9${X%OBIXP%R&U>6|j02+>j~dy1VGX&V3BR zAH}NIH0N=<8Xas12RvG5?qjK{V>C8^B}S@0DD1&C{KV8F+(asA>UqDl9@4(29ueM# zajf>ldys)>6U5hsQo^)GBf{|)_lrBHcg!m(!FOjKu&5p6`pDpfw+lj_MD zq5h!UHCI1dpCzFV!dF>4w|91L)`DABH3x3JLzEC_ mi99TBaZ|J#W}0tsOP35|*4G>M@=Lc0M%my)CelKnxjBjXu6HWNF{jFXa` zsHWR)9a;&8trWBEilFTZs#YqMRw|F|(?0BrUGFN^h=-{Avb>p1pDJy4&z85gh)I_0!2n zCx`u`{p#2LJ~GV4F9ukW4+dDqV+OrADNHfs#OP3WfQ^J>42^sO??~pg873&e`^77= z$VE0Dywoi#(&Wp-qoXA^51YHm(}Cb6kVNC$=#-Z{)jb`E#yL+#N^waxRPw#Q&%42c z$%$jfPLpilN<1NW%;t%y2+RO8!v=y3PtpuCLxq{?X&;&|pP)`C6`Kvs_^DF=6pVsW zClt+OfQ3=R00TmTV@MbvOe8N4EQS__y+Y28jEu_Gd4{<}&H8;aSIAH#2GcUnsAdnF zNoe1u&rVqqNq4fnC3P=ZCMjhFyj4t*DI})P5@eAD2iPzto9goI%ux+Iby@L2sI~*fVvGgU$_+xS9@Q z3p&B1+7$`a>*LskvlwgN z@kx1gw4x@u1sBZ$*AHW@)nj?47tjiCVdY!&fxp|2)<24-ZBr99>c}c%@@Nt-g4l^T zOK$*6QpQRd1Qqn;50}7oAK)G!>zBX1p%xsSgtliV`XuOB#g37Cmt&poIC0F&7V5E$U0G?7EfG|r_)xB}8 z&>LD$z8j_}i50DweaCfbk@{fr)`goF^0uzziGrzd)wCmL+VR2uTSsml$(uaMqiY7! z+pk@H?Kku3@i$&CxEeFNmzq=hVk>5?OO6(IVrJ|0rp2arcBBK>S~FwWmIHoKeQVKaH_(9J_DzsNWHR+|sznh)N0<(mgmBZY?M z44Z4%mpW0jD?XVphwFN1G4#$@`b&l;HqD4S!&5#%)8oBM~j*#Wb*co zH8i?g-kaW)+58S)&gDz#*DS8&=sL_ACJ}rnSdF%2EH*3sg6=kHq1X7(F%0dr0 z8(%{`^GpU zQji}=;gJ(P5EeaZf``W-X=fwM9ETF5q?7AZNykqGW*82wu%wf_mCV6-Oo)UN@dWSD zaHPxu081Q{^fPj4$s?11P~1+$YoO3qDuHTTN{Ca6^IQ+``r*fa1CJ!OX0=~`dGY1< zhL=w*oyuFu{>MrOuN$6>O zlZ0MCSx#=XAWPJV`Y=HoJ~4ibb|ti0aD)56fDJDhI7W#bFOo8H&%>XD4@>x%q(4PX1=tYRj6Uh-DCC%= zode$_3&dE-XPQTmDrfp~Fm0nMe5!oD;JL#fi7FxfPw;?fs^7jmxHNcmBsrA4oZ4Hk zSX12Mf%N#bfvnb_*$un@Saw%`w#lE@`cnf1qy4&N(UKlYU(Os|np+`OPUMYU$)SSH ze(my&w)B~FAbsZduGFs=e+3EOV!bwg15fwDXFq(rkh3%=k3XzyT<%)x%FM0^nb^Iy zzt{O!>jramgZIZD;13Su>t0EoDpxr{cu(Uy8xLD5tvQf}$_9>4(=(E7L*2z8;FgiqW9J4CJ%}Dh&;wzuWGVxdO{u)v%Llv~ z*KGpqcNQ)(xe?Gl2|xZn@PI?!UTpWioVGQ2GWB|)%~PVF#4&Eq8rmN=dhYgQ2Sy45 zFFkPI_kN|pn|_KfT7lj4+-jpI*Xa4A5T;Fisd1BQ$o^y`pooACL3(l5rE&G)oYv{Ur zDDU2%94)wauDUyO?#{e>UvhNKx;=9!Yw5UayLTn~+)&msl+_N&DanJ6d5Ezwk3}+2 z6deyzl#X%LykrdRLddu7Auby8wp|^uZd{>6K#olC#7?5`ppNb!e>xe9FGve(l z`|k|i9z?A6d&8fPd^Unu?Zp=Cxvs+AeUDB1@S`}Yv4>G_D;;;dx4l`nFK6#5Y7lF! zCQFA{XoKur7MH}LL4Gx=ud>ae8N5vGPdCWt?ZhqdrkHhg torch.Tensor: + """Unpack uint32 packed UE4M3 (4 values per uint32) to float8_e4m3fn. + + Args: + packed: (..., sf_k_groups) uint32 — 4 UE4M3 values packed per element + + Returns: + (..., sf_k_groups * 4) float8_e4m3fn + """ + u32 = packed.to(torch.int32) + b0 = (u32 & 0xFF).to(torch.uint8) + b1 = ((u32 >> 8) & 0xFF).to(torch.uint8) + b2 = ((u32 >> 16) & 0xFF).to(torch.uint8) + b3 = ((u32 >> 24) & 0xFF).to(torch.uint8) + interleaved = torch.stack([b0, b1, b2, b3], dim=-1) + return interleaved.reshape(*packed.shape[:-1], -1).contiguous().view(torch.float8_e4m3fn) + + +def unpack_e2m1_to_bf16( + packed: torch.Tensor, # (..., K//2) int8 — two E2M1 values per byte + scales: torch.Tensor, # (..., K//16) float8_e4m3fn — UE4M3 block16 scales +) -> torch.Tensor: + """Dequantize packed E2M1 with UE4M3 block16 scales to BF16. + + E2M1 format: sign(1) exponent(2) mantissa(1), bias=2 + Each int8 byte contains 2 E2M1 values: low nibble=element 0, high nibble=element 1. + UE4M3 block scales: one float8_e4m3fn scale per group of 16 consecutive elements. + + Args: + packed: (..., K//2) int8 packed E2M1 + scales: (..., K//16) float8_e4m3fn UE4M3 block16 scales + + Returns: + (..., K) bfloat16 + """ + u8 = packed.view(torch.uint8) + lo = (u8 & 0x0F).to(torch.int32) # lower nibble + hi = (u8 >> 4).to(torch.int32) # upper nibble + + # Interleave: (..., K//2, 2) → (..., K) + unpacked = torch.stack([lo, hi], dim=-1).reshape(*u8.shape[:-1], -1) + + # E2M1 → float32 + sign = (unpacked >> 3).to(torch.float32) * -2.0 + 1.0 + exp_field = (unpacked >> 1) & 0x3 + mant = (unpacked & 0x1).to(torch.float32) + + # E2M1 value = sign * 2^(exp - 2) * (1 + mant * 0.5) + val = sign * (2.0 ** (exp_field.to(torch.float32) - 2.0)) * (1.0 + mant * 0.5) + + # Zero: exp=0 and mant=0 + zero_mask = (exp_field == 0) & ((unpacked & 1) == 0) + val = val * (~zero_mask).to(torch.float32) + + # Apply UE4M3 block16 scales + sf_f32 = scales.to(torch.float32) + sf_expanded = sf_f32.repeat_interleave(16, dim=-1) + + K = unpacked.shape[-1] + sf_expanded = sf_expanded[..., :K] + + return (val * sf_expanded).to(torch.bfloat16) diff --git a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py index cf32a293..25786966 100644 --- a/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py +++ b/src/nvfp4_megamoe_kernel/nvfp4_mega_moe.py @@ -10,12 +10,25 @@ Architecture: - NVLink cross-rank sync via symm buffer - Expert parallel: each rank handles NUM_EXPERTS/8 experts -The kernel is written in TileLang, compiled to SM100 (Blackwell) CUBIN. +The kernel uses TileLang, compiled to SM100 (Blackwell) CUBIN. + +Strategy: + TileLang's tcgen05_gemm_blockscaled currently supports MXFP8 (FP8 + E8M0 scales). + NVFP4 uses E2M1 packed weights + UE4M3 scales with group_size=16. + We use a dequantize-then-GEMM approach: + 1. Load packed FP4 (int8) weights + UE4M3 (uint32) scales into shared memory + 2. Dequantize to BF16 in shared memory (FP4 → BF16 using UE4M3 block scales) + 3. Run regular BF16 GEMM via T.gemm (auto-lowers to tcgen05 on Blackwell) + This is correct and will be replaced with native FP4 block-scaled MMA once + TileLang adds tcgen05.mma kind::mxf8f6f4.block_scale support for E2M1+UE4M3. """ import os import torch +from nvfp4_megamoe_kernel.nvfp4_dequant import unpack_e2m1_to_bf16, unpack_ue4m3_u32 +from nvfp4_megamoe_kernel.tilelang_kernels import grouped_gemm_fp4, grouped_gemm_fp4_packed_sf + # DeepSeek-V4-Pro dimensions HIDDEN = 7168 INTERMEDIATE = 3072 @@ -32,6 +45,11 @@ MEGA_MOE_STATIC = int(os.environ.get("MEGA_MOE_STATIC", "0")) MEGA_MOE_DEBUG = int(os.environ.get("MEGA_MOE_DEBUG", "0")) + +# --------------------------------------------------------------------------- +# Main kernel entry points +# --------------------------------------------------------------------------- + def nvfp4_mega_moe_l1( x_fp4, # (num_tokens, K//2) int8 packed E2M1 x_sf, # (num_tokens, sf_k_groups) uint32 packed UE4M3 @@ -42,10 +60,33 @@ def nvfp4_mega_moe_l1( num_experts_per_rank, ): """L1 GEMM: gate_up_proj — FP4 x FP4 → BF16 with block scaling. - - TODO: TileLang JIT kernel (nvfp4_blockscaled_gemm_2cta_persistent pattern). + + Pipeline: + 1. Dequantize activation FP4 → BF16 using UE4M3 block16 scales + 2. Dequantize weight FP4 → BF16 using UE4M3 block16 scales + 3. Per-expert grouped BF16 GEMM with routing weights + + TODO: Replace with native FP4 block-scaled MMA once TileLang supports + tcgen05.mma kind::mxf8f6f4.block_scale with E2M1+UE4M3 inputs. """ - raise NotImplementedError("nvfp4_mega_moe_l1 TileLang kernel not yet implemented") + num_tokens = x_fp4.shape[0] + K_half = x_fp4.shape[1] + K = K_half * 2 # HIDDEN = 7168 + N = l1_weights.shape[1] # 2 * INTERMEDIATE = 6144 + + if MEGA_MOE_DEBUG: + print(f"[nvfp4_moe_l1] tokens={num_tokens} K={K} N={N} " + f"experts={num_experts_per_rank}") + + # Dequantize activation FP4 → BF16 + x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf + x_bf16 = unpack_e2m1_to_bf16(x_fp4, x_sf_fp8) # (num_tokens, K) + + # Grouped expert GEMM (handles weight dequant internally) + w_sf_fp8 = unpack_ue4m3_u32(l1_scales) if l1_scales.dtype == torch.uint32 else l1_scales + output = grouped_gemm_fp4(x_bf16, l1_weights, w_sf_fp8, topk_ids, topk_weights) + + return output # (num_tokens, 6144) bfloat16 def nvfp4_mega_moe_l2( @@ -58,15 +99,32 @@ def nvfp4_mega_moe_l2( num_experts_per_rank, ): """L2 GEMM: down_proj — FP4 x FP4 → BF16 with block scaling. - - TODO: TileLang JIT kernel (same pattern as L1). + + Same pipeline as L1: dequantize FP4→BF16, then grouped expert GEMM. """ - raise NotImplementedError("nvfp4_mega_moe_l2 TileLang kernel not yet implemented") + num_tokens = x_fp4.shape[0] + K_half = x_fp4.shape[1] + K = K_half * 2 # INTERMEDIATE = 3072 + N = l2_weights.shape[1] # HIDDEN = 7168 + + if MEGA_MOE_DEBUG: + print(f"[nvfp4_moe_l2] tokens={num_tokens} K={K} N={N} " + f"experts={num_experts_per_rank}") + + # Dequantize activation FP4 → BF16 + x_sf_fp8 = unpack_ue4m3_u32(x_sf) if x_sf.dtype == torch.uint32 else x_sf + x_bf16 = unpack_e2m1_to_bf16(x_fp4, x_sf_fp8) # (num_tokens, K) + + # Grouped expert GEMM + w_sf_fp8 = unpack_ue4m3_u32(l2_scales) if l2_scales.dtype == torch.uint32 else l2_scales + output = grouped_gemm_fp4(x_bf16, l2_weights, w_sf_fp8, topk_ids, topk_weights) + + return output # (num_tokens, 7168) bfloat16 def stage_activation(x_bf16): """Quantize BF16 activation to FP4 (E2M1) with UE4M3 block16 scales. - + This replaces the Triton staging kernel from patches/staging_kernel.py. """ from vllm.model_executor.layers.quantization.utils.fp4_utils import ( @@ -84,13 +142,13 @@ def nvfp4_mega_moe_full( fast_math=False, # fast math flag (unused in NVFP4) ): """Full mega_moe forward pass — replaces deep_gemm.mega.fp8_nvfp4_mega_moe. - + API matches the DeepGEMM fp8_nvfp4_mega_moe call signature used in the vLLM deepseek_v4.py patch: - + fp8_nvfp4_mega_moe(y, l1_weights, l2_weights, symm_buffer, activation_clamp=..., fast_math=...) - + Pipeline: 1. Read staged activation from symm_buffer (already quantized by staging kernel) 2. L1 GEMM: gate_up_proj (FP4 x FP4 → BF16 with block scaling) @@ -98,24 +156,24 @@ def nvfp4_mega_moe_full( 4. Quantize L1 output → FP4 + UE4M3 scales 5. L2 GEMM: down_proj (FP4 x FP4 → BF16 with block scaling) 6. NVLink sync + reduce across ranks → write to y - + When MEGA_MOE_STATIC=1, returns zeros (bypass) for pipeline testing. """ num_tokens = y.shape[0] device = y.device dtype = y.dtype - + if MEGA_MOE_STATIC: if MEGA_MOE_DEBUG: print(f"[MEGA_MOE_STATIC] Skipping nvfp4_mega_moe, returning zeros " f"shape=({num_tokens}, {y.shape[1]})") y.zero_() return - + # Unpack transformed weights l1_w, l1_sf = transformed_l1_weights l2_w, l2_sf = transformed_l2_weights - + # Step 1: Read staged activation from symm_buffer # The staging has already been done by _stage_deepseek_v4_mega_moe_inputs # and stored in symm_buffer.x, symm_buffer.x_sf @@ -123,32 +181,32 @@ def nvfp4_mega_moe_full( x_sf = symm_buffer.x_sf[:num_tokens] topk_ids = symm_buffer.topk_idx[:num_tokens] topk_weights = symm_buffer.topk_weights[:num_tokens] - + if MEGA_MOE_DEBUG: print(f"[nvfp4_mega_moe_full] x_fp4={x_fp4.shape} x_sf={x_sf.shape} " f"topk_ids={topk_ids.shape} l1_w={l1_w.shape} l2_w={l2_w.shape}") - + # Step 2: L1 GEMM num_experts_per_rank = l1_w.shape[0] l1_output = nvfp4_mega_moe_l1( x_fp4, x_sf, l1_w, l1_sf, topk_ids, topk_weights, num_experts_per_rank, ) - + # Step 3: SiLU + Mul gate, up = l1_output.chunk(2, dim=-1) activated = torch.nn.functional.silu(gate) * up if activation_clamp is not None: activated = activated.clamp(max=activation_clamp) - + # Step 4: Quantize L1 output → FP4 l1_fp4, l1_sf_out = stage_activation(activated) - + # Step 5: L2 GEMM l2_output = nvfp4_mega_moe_l2( l1_fp4, l1_sf_out, l2_w, l2_sf, topk_ids, topk_weights, num_experts_per_rank, ) - + # Step 6: Write to output y.copy_(l2_output) diff --git a/src/nvfp4_megamoe_kernel/tilelang_kernels.py b/src/nvfp4_megamoe_kernel/tilelang_kernels.py new file mode 100644 index 00000000..1ab67e19 --- /dev/null +++ b/src/nvfp4_megamoe_kernel/tilelang_kernels.py @@ -0,0 +1,136 @@ +""" +TileLang NVFP4 Mega MoE Kernels — BF16 GEMM with FP4 dequantization. + +This module provides the core GEMM kernels for the DeepSeek-V4-Pro MoE layer: +- L1 (gate_up_proj): HIDDEN→2*INTERMEDIATE, FP4 weights + UE4M3 scales +- L2 (down_proj): INTERMEDIATE→HIDDEN, FP4 weights + UE4M3 scales + +Current approach: Dequantize FP4→BF16, then run BF16 GEMM via TileLang. +This is correct and functional. Once TileLang adds native tcgen05.mma +kind::mxf8f6f4.block_scale support for E2M1+UE4M3, we'll switch to +native FP4 block-scaled MMA for maximum throughput. + +The per-expert GEMM uses a "segmented" approach: sort tokens by expert, +batched GEMM per expert using TileLang-compiled BF16 kernels. +""" + +import torch +import tilelang +import tilelang.language as T + +from nvfp4_megamoe_kernel.nvfp4_dequant import unpack_e2m1_to_bf16, unpack_ue4m3_u32 + +# --------------------------------------------------------------------------- +# TileLang BF16 GEMM kernel (auto-detects Blackwell, lowers to tcgen05) +# --------------------------------------------------------------------------- + +_kernel_cache = {} + + +def _make_bf16_gemm(M, N, K, block_M=128, block_N=128, block_K=128, num_stages=3): + """Build and cache a TileLang BF16 GEMM kernel for the given dimensions.""" + key = (M, N, K, block_M, block_N, block_K, num_stages) + if key in _kernel_cache: + return _kernel_cache[key] + + @tilelang.jit(out_idx=[2]) + def bf16_gemm( + A: T.Tensor((M, K), T.bfloat16), + B: T.Tensor((K, N), T.bfloat16), + C: T.Tensor((M, N), T.bfloat16), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), T.bfloat16) + B_shared = T.alloc_shared((block_K, block_N), T.bfloat16) + C_local = T.alloc_fragment((block_M, block_N), T.float32) + + T.clear(C_local) + + for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): + T.copy(A[by * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, bx * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + _kernel_cache[key] = bf16_gemm + return bf16_gemm + + +# --------------------------------------------------------------------------- +# Grouped expert GEMM with FP4 dequantization +# --------------------------------------------------------------------------- + +def grouped_gemm_fp4( + x_bf16: torch.Tensor, # (total_tokens, K_dim) bfloat16 + weights_fp4: torch.Tensor, # (E, N, K//2) int8 packed E2M1 + scales_ue4m3: torch.Tensor, # (E, N, K//16) float8_e4m3fn + topk_ids: torch.Tensor, # (num_tokens, NUM_TOPK) int32 + topk_weights: torch.Tensor, # (num_tokens, NUM_TOPK) float32 +) -> torch.Tensor: + """Segmented grouped expert GEMM: dequantize FP4→BF16, per-expert GEMM. + + Strategy: + 1. Sort tokens by expert assignment + 2. For each expert, dequantize its weight to BF16 (cached) + 3. Run batched BF16 GEMM using TileLang-compiled kernels + 4. Scatter results back with routing weights + """ + num_tokens, K_dim = x_bf16.shape + E, N, K_half = weights_fp4.shape + K = K_half * 2 + assert K == K_dim, f"Activation K={K_dim} doesn't match weight K={K}" + top_k = topk_ids.shape[1] + device = x_bf16.device + + output = torch.zeros(num_tokens, N, dtype=torch.bfloat16, device=device) + + # Pre-compute expert weight dequantization (cache for repeated use) + # For 32 experts, this is manageable + w_bf16_cache = {} + for e in range(E): + w_bf16_cache[e] = unpack_e2m1_to_bf16(weights_fp4[e], scales_ue4m3[e]) # (N, K) + + # Process per expert + for e in range(E): + # Find all (token, k_idx) pairs for this expert + mask = (topk_ids == e) # (num_tokens, top_k) + if not mask.any(): + continue + + w_bf16 = w_bf16_cache[e] # (N, K) + + # Collect tokens for this expert across all top-k slots + for k_idx in range(top_k): + token_mask = mask[:, k_idx] + if not token_mask.any(): + continue + token_indices = token_mask.nonzero(as_tuple=True)[0] + + # Gather activations + x_sub = x_bf16[token_indices] # (n, K) + + # BF16 GEMM: (n, K) @ (N, K).T → (n, N) + result = torch.nn.functional.linear(x_sub, w_bf16) + + # Weighted scatter-add + weights = topk_weights[token_indices, k_idx].unsqueeze(-1) + output[token_indices] += result * weights + + return output + + +# --------------------------------------------------------------------------- +# Convenience: grouped GEMM with uint32 packed scales +# --------------------------------------------------------------------------- + +def grouped_gemm_fp4_packed_sf( + x_bf16: torch.Tensor, + weights_fp4: torch.Tensor, # (E, N, K//2) int8 + scales_packed: torch.Tensor, # (E, N, sf_k_groups) uint32 packed UE4M3 + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, +) -> torch.Tensor: + """Same as grouped_gemm_fp4 but unpacks uint32 packed UE4M3 scales first.""" + scales_fp8 = unpack_ue4m3_u32(scales_packed) + return grouped_gemm_fp4(x_bf16, weights_fp4, scales_fp8, topk_ids, topk_weights)