[Models] Step-3.5-Flash (#33523)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Co-authored-by: i-zhangmingming <i-zhangmingming@stepfun.com> Co-authored-by: xiewuxun <xiewuxun@stepfun.com> Co-authored-by: zetaohong <i-hongzetao@stepfun.com> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@@ -17,11 +17,63 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.collection_utils import LazyDict
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@triton.jit
def _swiglustep_and_mul_kernel(
    o_ptr,
    o_stride,
    x_ptr,
    x_stride,
    limit: tl.constexpr,
    d: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
) -> None:
    """Fused SwiGLU-with-clamp kernel.

    Each row of x is laid out as [gate | up], each half of width d.
    For every element j of row i this computes:

        out[i, j] = min(silu(gate[i, j]), limit)
                    * clamp(up[i, j], -limit, limit)

    Launch grid: (num_rows, cdiv(d, BLOCK_SIZE)); axis 0 picks the row,
    axis 1 picks the column tile.
    """
    # Row index is widened to int64 so the byte offset cannot overflow
    # for large tensors.
    row = tl.program_id(axis=0).to(tl.int64)
    tile = tl.program_id(axis=1)

    cols = tile * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < d

    x_base = x_ptr + x_stride * row
    # First half of the row holds the gate, second half the up projection.
    gate = tl.load(x_base + cols, mask=in_bounds).to(tl.float32)
    up = tl.load(x_base + cols + d, mask=in_bounds).to(tl.float32)

    # silu(gate) clamped from above; up clamped symmetrically.
    act = tl.minimum(gate * tl.sigmoid(gate), limit)
    up = tl.minimum(tl.maximum(up, -limit), limit)

    out = (act * up).to(x_ptr.dtype.element_ty)
    tl.store(o_ptr + o_stride * row + cols, out, mask=in_bounds)
|
||||
|
||||
|
||||
def swiglustep_and_mul_triton(
    output: torch.Tensor, input: torch.Tensor, limit: float = 7.0
):
    """Launch the fused SwiGLU-with-clamp Triton kernel.

    Args:
        output: Destination tensor of shape (num_tokens, d); written in place.
        input: Source tensor of shape (num_tokens, 2 * d), laid out as
            [gate | up] along the last dimension.
        limit: Clamp bound; silu(gate) is clamped to at most ``limit`` and
            up is clamped to ``[-limit, limit]``.
    """
    # Check the rank BEFORE unpacking the shape: previously the unpack ran
    # first, so a non-2D input raised an opaque tuple-unpack ValueError and
    # the ndim assertion was unreachable for the case it guards.
    assert input.ndim == 2
    b, n = input.shape
    assert n % 2 == 0
    d = n // 2

    def grid(meta):
        # One program per (row, column tile of size BLOCK_SIZE).
        return (b, triton.cdiv(d, meta["BLOCK_SIZE"]))

    _swiglustep_and_mul_kernel[grid](
        output,
        output.stride(0),
        input,
        input.stride(0),
        limit=limit,
        d=d,
        BLOCK_SIZE=1024,
    )
|
||||
|
||||
|
||||
# --8<-- [start:fatrelu_and_mul]
|
||||
@CustomOp.register("fatrelu_and_mul")
|
||||
class FatreluAndMul(CustomOp):
|
||||
@@ -304,6 +356,44 @@ class SwigluOAIAndMul(CustomOp):
|
||||
return f"alpha={repr(self.alpha)}, limit={repr(self.limit)}"
|
||||
|
||||
|
||||
# --8<-- [start:swiglustep_and_mul]
|
||||
@CustomOp.register("swiglustep_and_mul")
class SwigluStepAndMul(CustomOp):
    """SwiGLU activation with clamping.

    Computes x -> silu(x[:d]).clamp(max=limit) * x[d:].clamp(-limit, limit)
    where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, limit: float = 7.0):
        super().__init__()
        # An explicit None is rejected: the clamp bound must always exist.
        if limit is None:
            raise ValueError("SwigluStepAndMul requires limit to be set.")
        self.limit = limit

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        gate, up = torch.chunk(x, 2, dim=-1)
        clamped_gate = torch.clamp(F.silu(gate), max=self.limit)
        clamped_up = torch.clamp(up, min=-self.limit, max=self.limit)
        return clamped_gate * clamped_up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # Halve the last dimension: the kernel fuses the gate/up split,
        # activation, clamping, and elementwise product.
        half = x.shape[-1] // 2
        out = torch.empty(
            x.shape[:-1] + (half,), dtype=x.dtype, device=x.device
        )
        swiglustep_and_mul_triton(out, x, self.limit)
        return out

    def extra_repr(self) -> str:
        return f"limit={repr(self.limit)}"
|
||||
|
||||
|
||||
# --8<-- [start:gelu_new]
|
||||
@CustomOp.register("gelu_new")
|
||||
class NewGELU(CustomOp):
|
||||
|
||||
Reference in New Issue
Block a user