fix: import NVFP4 SymmBuffer from deep_gemm.mega

This commit is contained in:
2026-05-11 08:05:50 +00:00
parent ff579c9767
commit 8cb23bdb78

View File

@@ -761,7 +761,8 @@ class DeepseekV4MegaMoEExperts(nn.Module):
return w_packed, scale_exp
def get_symm_buffer(self):
from deep_gemm.mega import nvfp4 as nvfp4_mega
import deep_gemm
from deep_gemm.mega import SymmBuffer, get_symm_buffer_for_nvfp4_mega_moe
group = get_ep_group().device_group
device = torch.accelerator.current_device_index()
@@ -776,8 +777,8 @@ class DeepseekV4MegaMoEExperts(nn.Module):
)
symm_buffer = self._symm_buffer_cache.get(key)
if symm_buffer is None:
# NVFP4 SymmBuffer: 2x SF size due to group_size=16 (vs MXFP4's 32)
symm_buffer = nvfp4_mega.get_symm_buffer_for_nvfp4_mega_moe(
# NVFP4 SymmBuffer: 2x SF size due to group_size=16
symm_buffer = get_symm_buffer_for_nvfp4_mega_moe(
group,
self.num_experts,
self.max_num_tokens,