[Kernels] Overlap shared experts with send/recv (#23273)
Signed-off-by: Bill Nell <bnell@redhat.com>
This commit is contained in:
@@ -8,6 +8,7 @@ import vllm._custom_ops as ops
|
||||
from tests.kernels.quant_utils import per_block_cast_to_int8
|
||||
from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
|
||||
FLOAT8_E4M3_MAX)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedPrepareAndFinalize, BatchedTritonExperts, NaiveBatchedExperts)
|
||||
@@ -282,3 +283,151 @@ def per_token_cast_to_fp8(
|
||||
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
|
||||
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
|
||||
|
||||
|
||||
# CustomOp?
|
||||
class BaselineMM(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
b: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
):
|
||||
super().__init__()
|
||||
self.b = b.to(dtype=torch.float32)
|
||||
self.out_dtype = out_dtype
|
||||
|
||||
def forward(
|
||||
self,
|
||||
a: torch.Tensor) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
return torch.mm(a.to(dtype=torch.float32),
|
||||
self.b).to(self.out_dtype), None
|
||||
|
||||
|
||||
class TestMLP(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_up_proj = BaselineMM(w1, out_dtype)
|
||||
self.down_proj = BaselineMM(w2, out_dtype)
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
x, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(x)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
def make_naive_shared_experts(
|
||||
N: int,
|
||||
K: int,
|
||||
in_dtype: torch.dtype = torch.bfloat16,
|
||||
) -> torch.nn.Module:
|
||||
w1 = torch.randn((K, N * 2), device="cuda", dtype=in_dtype) / 15
|
||||
w2 = torch.randn((N, K), device="cuda", dtype=in_dtype) / 15
|
||||
return TestMLP(w1, w2, out_dtype=in_dtype)
|
||||
|
||||
|
||||
class RealMLP(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
hidden_act: str = "silu",
|
||||
quant_config=None,
|
||||
reduce_results: bool = True,
|
||||
prefix: str = "",
|
||||
w1_s: Optional[torch.Tensor] = None,
|
||||
w2_s: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear, RowParallelLinear)
|
||||
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj")
|
||||
self.gate_up_proj.register_parameter(
|
||||
"weight", torch.nn.Parameter(w1, requires_grad=False))
|
||||
self.gate_up_proj.register_parameter(
|
||||
"weight_scale", torch.nn.Parameter(w1_s, requires_grad=False))
|
||||
self.gate_up_proj.register_parameter(
|
||||
"input_scale",
|
||||
None) #torch.nn.Parameter(None, requires_grad=False))
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
prefix=f"{prefix}.down_proj")
|
||||
self.down_proj.register_parameter(
|
||||
"weight", torch.nn.Parameter(w2, requires_grad=False))
|
||||
self.down_proj.register_parameter(
|
||||
"weight_scale", torch.nn.Parameter(w2_s, requires_grad=False))
|
||||
self.down_proj.register_parameter(
|
||||
"input_scale",
|
||||
None) #torch.nn.Parameter(None, requires_grad=False))
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
def make_shared_experts(
|
||||
N: int,
|
||||
K: int,
|
||||
in_dtype: torch.dtype = torch.bfloat16,
|
||||
quant_dtype: Union[torch.dtype, str, None] = None,
|
||||
) -> torch.nn.Module:
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
||||
|
||||
(_, w1, w1_s, _), (_, w2, w2_s, _) = make_test_weights(
|
||||
1,
|
||||
N,
|
||||
K,
|
||||
in_dtype=in_dtype,
|
||||
quant_dtype=quant_dtype,
|
||||
)
|
||||
old_dtype = torch.get_default_dtype()
|
||||
try:
|
||||
torch.set_default_dtype(in_dtype)
|
||||
if quant_dtype == torch.float8_e4m3fn:
|
||||
w1 = w1[0].transpose(0, 1)
|
||||
w2 = w2[0].transpose(0, 1)
|
||||
w1_s = w1_s[0].transpose(0, 1) if w1_s is not None else None
|
||||
w2_s = w2_s[0].transpose(0, 1) if w2_s is not None else None
|
||||
quant_config = Fp8Config(True)
|
||||
else:
|
||||
w1 = w1[0]
|
||||
w2 = w2[0]
|
||||
w1_s = None
|
||||
w2_s = None
|
||||
quant_config = None
|
||||
|
||||
return RealMLP(K,
|
||||
N,
|
||||
w1,
|
||||
w2,
|
||||
"silu",
|
||||
quant_config,
|
||||
w1_s=w1_s,
|
||||
w2_s=w2_s)
|
||||
finally:
|
||||
torch.set_default_dtype(old_dtype)
|
||||
|
||||
Reference in New Issue
Block a user