[Kernel] Support MulAndSilu (#11624)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li
2025-01-15 10:29:53 +08:00
committed by GitHub
parent a3a3ee4e6f
commit 42f5e7c52a
7 changed files with 83 additions and 36 deletions

View File

@@ -23,7 +23,8 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU,
SiluAndMul)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
@@ -462,15 +463,6 @@ class MolmoAttention(nn.Module):
return output
class SwiGLU(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, gate = x.chunk(2, dim=-1)
# Note that the order is reversed compared to
# SiluAndMul.
return x * F.silu(gate)
class LanuageModelMLP(nn.Module):
"""Molmo's LLM mlp."""
@@ -489,7 +481,7 @@ class LanuageModelMLP(nn.Module):
quant_config=quant_config,
)
# Activation function.
self.act_fn = SwiGLU()
self.act_fn = MulAndSilu()
# Feed-forward output projection.
self.down_proj = RowParallelLinear(
self.intermediate_size,