[Kernels] Overlap shared experts with send/recv (#23273)

Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
bnellnm
2025-09-03 12:35:18 -04:00
committed by GitHub
parent fa4311d85f
commit e9b92dcd89
32 changed files with 885 additions and 227 deletions


@@ -8,6 +8,7 @@ import vllm._custom_ops as ops
from tests.kernels.quant_utils import per_block_cast_to_int8
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
                                                    FLOAT8_E4M3_MAX)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
@@ -282,3 +283,151 @@ def per_token_cast_to_fp8(
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
    return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
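

# Illustrative only, not part of this change: a rough sanity-check sketch for
# the per-token fp8 cast shown above. The (M, N) shape is an assumption, N is
# assumed to be a multiple of the quantization group size (so no padding is
# applied), and 448.0 is the max magnitude representable in float8_e4m3fn.
def _example_per_token_fp8_roundtrip(M: int = 4, N: int = 256) -> None:
    x = torch.randn((M, N), device="cuda", dtype=torch.bfloat16)
    x_fp8, x_scale = per_token_cast_to_fp8(x)
    group = N // x_scale.shape[1]
    # Dequantize each group by its per-group scale and compare to the input.
    x_deq = x_fp8.to(torch.float32).view(M, -1, group) * x_scale.unsqueeze(2)
    torch.testing.assert_close(x_deq.view(M, N),
                               x.to(torch.float32),
                               atol=0.125,
                               rtol=0.125)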


# CustomOp?
class BaselineMM(torch.nn.Module):

    def __init__(
        self,
        b: torch.Tensor,
        out_dtype: torch.dtype,
    ):
        super().__init__()
        self.b = b.to(dtype=torch.float32)
        self.out_dtype = out_dtype

    def forward(
            self,
            a: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        return torch.mm(a.to(dtype=torch.float32),
                        self.b).to(self.out_dtype), None


class TestMLP(torch.nn.Module):

    def __init__(
        self,
        w1: torch.Tensor,
        w2: torch.Tensor,
        out_dtype: torch.dtype,
    ):
        super().__init__()
        self.gate_up_proj = BaselineMM(w1, out_dtype)
        self.down_proj = BaselineMM(w2, out_dtype)
        self.act_fn = SiluAndMul()

    def forward(self, x):
        x, _ = self.gate_up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x


def make_naive_shared_experts(
    N: int,
    K: int,
    in_dtype: torch.dtype = torch.bfloat16,
) -> torch.nn.Module:
    w1 = torch.randn((K, N * 2), device="cuda", dtype=in_dtype) / 15
    w2 = torch.randn((N, K), device="cuda", dtype=in_dtype) / 15
    return TestMLP(w1, w2, out_dtype=in_dtype)
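

# Illustrative only, not part of this change: a minimal smoke test sketching
# how the naive baseline above might be exercised; the shapes are assumptions.
def _example_naive_shared_experts(M: int = 16,
                                  N: int = 512,
                                  K: int = 1024) -> None:
    mlp = make_naive_shared_experts(N, K)
    x = torch.randn((M, K), device="cuda", dtype=torch.bfloat16)
    out = mlp(x)
    assert out.shape == (M, K)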


class RealMLP(torch.nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        w1: torch.Tensor,
        w2: torch.Tensor,
        hidden_act: str = "silu",
        quant_config=None,
        reduce_results: bool = True,
        prefix: str = "",
        w1_s: Optional[torch.Tensor] = None,
        w2_s: Optional[torch.Tensor] = None,
    ) -> None:
        from vllm.model_executor.layers.linear import (
            MergedColumnParallelLinear, RowParallelLinear)
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        self.gate_up_proj.register_parameter(
            "weight", torch.nn.Parameter(w1, requires_grad=False))
        self.gate_up_proj.register_parameter(
            "weight_scale", torch.nn.Parameter(w1_s, requires_grad=False))
        self.gate_up_proj.register_parameter(
            "input_scale",
            None)  # torch.nn.Parameter(None, requires_grad=False))
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           reduce_results=reduce_results,
                                           prefix=f"{prefix}.down_proj")
        self.down_proj.register_parameter(
            "weight", torch.nn.Parameter(w2, requires_grad=False))
        self.down_proj.register_parameter(
            "weight_scale", torch.nn.Parameter(w2_s, requires_grad=False))
        self.down_proj.register_parameter(
            "input_scale",
            None)  # torch.nn.Parameter(None, requires_grad=False))
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


def make_shared_experts(
    N: int,
    K: int,
    in_dtype: torch.dtype = torch.bfloat16,
    quant_dtype: Union[torch.dtype, str, None] = None,
) -> torch.nn.Module:
    from vllm.model_executor.layers.quantization.fp8 import Fp8Config
    (_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
        1,
        N,
        K,
        in_dtype=in_dtype,
        quant_dtype=quant_dtype,
    )
    old_dtype = torch.get_default_dtype()
    try:
        torch.set_default_dtype(in_dtype)
        if quant_dtype == torch.float8_e4m3fn:
            w1 = w1[0].transpose(0, 1)
            w2 = w2[0].transpose(0, 1)
            w1_s = w1_s[0].transpose(0, 1) if w1_s is not None else None
            w2_s = w2_s[0].transpose(0, 1) if w2_s is not None else None
            quant_config = Fp8Config(True)
        else:
            w1 = w1[0]
            w2 = w2[0]
            w1_s = None
            w2_s = None
            quant_config = None
        return RealMLP(K,
                       N,
                       w1,
                       w2,
                       "silu",
                       quant_config,
                       w1_s=w1_s,
                       w2_s=w2_s)
    finally:
        torch.set_default_dtype(old_dtype)
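

# Illustrative only, not part of this change: a sketch of exercising the
# helper above in its unquantized form (pass quant_dtype=torch.float8_e4m3fn
# for the fp8 path). The shapes are assumptions, and an initialized
# (single-GPU) tensor-parallel group is assumed to be available, as in the
# surrounding MoE tests, since RealMLP builds parallel linear layers.
def _example_shared_experts(M: int = 16, N: int = 512, K: int = 1024) -> None:
    shared = make_shared_experts(N, K)
    x = torch.randn((M, K), device="cuda", dtype=torch.bfloat16)
    out = shared(x)
    assert out.shape == (M, K)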