[Kernel] Support MulAndSilu (#11624)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -23,7 +23,8 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
InputContext, token_inputs)
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
-from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
|
||||
+from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU,
|
||||
+SiluAndMul)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
@@ -462,15 +463,6 @@ class MolmoAttention(nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
-class SwiGLU(nn.Module):
|
||||
|
||||
-def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
-x, gate = x.chunk(2, dim=-1)
|
||||
-# Note that the order is reversed compared to
|
||||
-# SiluAndMul.
|
||||
-return x * F.silu(gate)
|
||||
|
||||
|
||||
class LanuageModelMLP(nn.Module):
|
||||
"""Molmo's LLM mlp."""
|
||||
|
||||
@@ -489,7 +481,7 @@ class LanuageModelMLP(nn.Module):
|
||||
quant_config=quant_config,
|
||||
)
|
||||
# Activation function.
|
||||
-self.act_fn = SwiGLU()
|
||||
+self.act_fn = MulAndSilu()
|
||||
# Feed-forward output projection.
|
||||
self.down_proj = RowParallelLinear(
|
||||
self.intermediate_size,
|
||||
|
||||
Reference in New Issue
Block a user