add NVFP4SymmBuffer + get_symm_buffer_for_nvfp4_mega_moe Python wrapper
The C++ binding was registered but there was no Python wrapper. vLLM patch imports get_symm_buffer_for_nvfp4_mega_moe from deep_gemm.mega.
This commit is contained in:
@@ -72,6 +72,70 @@ def get_symm_buffer_for_mega_moe(group: dist.ProcessGroup,
|
||||
)
|
||||
|
||||
|
||||
class NVFP4SymmBuffer:
|
||||
"""Symmetric buffer for NVFP4 mega_moe kernel.
|
||||
|
||||
Same structure as SymmBuffer but with NVFP4-specific SF layout:
|
||||
group_size=16 means 2x the SF data compared to MXFP4 (group_size=32).
|
||||
x_sf is packed int32 (4 UE4M3 per int32) with shape (M, K//64).
|
||||
"""
|
||||
def __init__(self, group: dist.ProcessGroup,
|
||||
num_experts: int,
|
||||
num_max_tokens_per_rank: int, num_topk: int,
|
||||
hidden: int, intermediate_hidden: int,
|
||||
use_fp8_dispatch: bool = True,
|
||||
activation: str = 'swiglu'):
|
||||
self.group = group
|
||||
self.num_experts = num_experts
|
||||
self.num_max_tokens_per_rank = num_max_tokens_per_rank
|
||||
self.num_topk = num_topk
|
||||
self.hidden = hidden
|
||||
self.intermediate_hidden = intermediate_hidden
|
||||
|
||||
# Allocate a symmetric buffer (NVFP4 variant)
|
||||
num_bytes, slice_input_buffers = _C.get_symm_buffer_size_for_nvfp4_mega_moe(
|
||||
group.size(), num_experts,
|
||||
num_max_tokens_per_rank, num_topk,
|
||||
hidden, intermediate_hidden,
|
||||
use_fp8_dispatch, activation
|
||||
)
|
||||
self.buffer = symm_mem.empty(num_bytes, dtype=torch.int8, device='cuda')
|
||||
self.handle = symm_mem.rendezvous(self.buffer, group=group)
|
||||
self.buffer.zero_()
|
||||
self.group.barrier()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Create input buffer views
|
||||
(self.x, self.x_sf,
|
||||
self.topk_idx, self.topk_weights,
|
||||
self.l1_acts, self.l1_acts_sf,
|
||||
self.l2_acts, self.l2_acts_sf) = slice_input_buffers(self.buffer)
|
||||
|
||||
def destroy(self):
|
||||
self.handle = None
|
||||
self.buffer = None
|
||||
self.group = None
|
||||
self.x = None
|
||||
self.x_sf = None
|
||||
|
||||
|
||||
def get_symm_buffer_for_nvfp4_mega_moe(group: dist.ProcessGroup,
|
||||
num_experts: int,
|
||||
num_max_tokens_per_rank: int, num_topk: int,
|
||||
hidden: int, intermediate_hidden: int,
|
||||
use_fp8_dispatch: bool = True,
|
||||
activation: str = 'swiglu') -> NVFP4SymmBuffer:
|
||||
# Token count must be aligned to block sizes
|
||||
num_max_tokens_per_rank = align(num_max_tokens_per_rank, _C.get_token_alignment_for_nvfp4_mega_moe())
|
||||
|
||||
return NVFP4SymmBuffer(
|
||||
group, num_experts,
|
||||
num_max_tokens_per_rank, num_topk,
|
||||
hidden, intermediate_hidden,
|
||||
use_fp8_dispatch, activation
|
||||
)
|
||||
|
||||
|
||||
def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# [gate: 0..7, up: 0..7, gate: 8..15, up: 8..15, ...] instead of [gate | up]
|
||||
def interleave(t, gran: int = 8) -> torch.Tensor:
|
||||
|
||||
Reference in New Issue
Block a user