[Model] Add support for OLMo Hybrid (#32550)
@@ -76,16 +76,20 @@ def l2norm_fwd_kernel(
 @triton.jit
-def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
+def l2norm_fwd_kernel2(
+    X, Y, eps, M, N: tl.constexpr, BD: tl.constexpr, MBLOCK: tl.constexpr
+):
     xoffset = tl.program_id(0) * MBLOCK
     row_idx = xoffset + tl.arange(0, MBLOCK)[:, None]
     xmask = row_idx < M
-    rindex = tl.arange(0, N)[None, :]
-    xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32)
-    square = tl.broadcast_to(xs * xs, [MBLOCK, N])
+    rindex = tl.arange(0, BD)[None, :]
+    cmask = rindex < N
+    mask = xmask & cmask
+    xs = tl.load(X + (rindex + N * row_idx), mask, other=0.0).to(tl.float32)
+    square = tl.broadcast_to(xs * xs, [MBLOCK, BD])
     square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None]
     rsqrt = tl.rsqrt(square_sum + eps)
-    tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
+    tl.store(Y + (rindex + N * row_idx), xs * rsqrt, mask)


 def l2norm_fwd(
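The kernel change pads the column index out to a separate constexpr block size BD rather than iterating exactly N columns, since tl.arange needs a power-of-two extent. Out-of-range columns load as 0.0 and are masked on the store, so the row sum is unchanged. A minimal PyTorch reference for what the kernel computes per row (the function name and shapes are illustrative, not part of the diff):

import torch

def l2norm_ref(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Row-wise y = x * rsqrt(sum(x^2) + eps), matching l2norm_fwd_kernel2.
    # Padding columns with zeros (the kernel's other=0.0) leaves the sum
    # of squares unchanged, which is why the masked load is safe.
    x = x.float()
    return x * torch.rsqrt(x.pow(2).sum(dim=-1, keepdim=True) + eps)

# N=100 is not a power of two, so the kernel would run with BD=128 and
# mask off the last 28 columns.
y = l2norm_ref(torch.randn(4, 100))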
@@ -116,6 +120,7 @@ def l2norm_fwd(
             eps,
             T,
             D,
+            BD,
             MBLOCK,
         )
     else:
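The hunk above only shows the extra BD argument threaded through the launch; the surrounding call site is not in the diff. A hedged sketch of how BD would typically be computed on the host (T, D, and MBLOCK are names from the hunk; the grid expression is assumed):

import triton

BD = triton.next_power_of_2(D)      # tl.arange requires a power-of-two extent
grid = (triton.cdiv(T, MBLOCK),)    # one program per MBLOCK rows (assumed)
l2norm_fwd_kernel2[grid](x, y, eps, T, D, BD, MBLOCK)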
@@ -250,57 +250,55 @@ def layer_norm_fwd(
     return out, mean, rstd


-class LayerNormFn(torch.autograd.Function):
-    @input_guard
-    @staticmethod
-    def forward(
-        ctx,
-        x,
-        weight,
-        bias,
-        z=None,
-        eps=1e-6,
-        group_size=None,
-        norm_before_gate=True,
-        is_rms_norm=False,
-        activation: str = "swish",
-    ):
-        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
-
-        x_shape_og = x.shape
-        # reshape input data into 2D tensor
-        x = x.reshape(-1, x.shape[-1])
-        if x.stride(-1) != 1:
-            x = x.contiguous()
-        if z is not None:
-            assert z.shape == x_shape_og
-            z = z.reshape(-1, z.shape[-1])
-            if z.stride(-1) != 1:
-                z = z.contiguous()
-        weight = weight.contiguous()
-        if bias is not None:
-            bias = bias.contiguous()
-        y, mean, rstd = layer_norm_fwd(
-            x,
-            weight,
-            bias,
-            eps,
-            z=z,
-            group_size=group_size,
-            norm_before_gate=norm_before_gate,
-            is_rms_norm=is_rms_norm,
-            activation=activation,
-        )
-        ctx.save_for_backward(x, weight, bias, mean, rstd, z)
-        ctx.x_shape_og = x_shape_og
-        ctx.eps = eps
-        ctx.group_size = group_size
-        ctx.norm_before_gate = norm_before_gate
-        ctx.is_rms_norm = is_rms_norm
-        ctx.activation = activation
-        return y.reshape(x_shape_og)
+def _layer_norm_fn_impl(
+    x,
+    weight,
+    bias,
+    z=None,
+    eps=1e-6,
+    group_size=None,
+    norm_before_gate=True,
+    is_rms_norm=False,
+    activation: str = "swish",
+):
+    """Triton layer/RMS norm with optional gating.
+
+    If z is not None, computes norm(x) * silu(z) when norm_before_gate,
+    else norm(x * silu(z)).
+
+    This calls the triton kernel directly. The original code wrapped this
+    in a torch.autograd.Function (LayerNormFn) to save tensors for a
+    backward pass, but vLLM is inference-only so there is no backward pass.
+    The autograd wrapper also prevented torch.compile/dynamo from tracing
+    through the function due to its @staticmethod forward.
+    """
+    x_shape_og = x.shape
+    x = x.reshape(-1, x.shape[-1])
+    if x.stride(-1) != 1:
+        x = x.contiguous()
+    if z is not None:
+        assert z.shape == x_shape_og
+        z = z.reshape(-1, z.shape[-1])
+        if z.stride(-1) != 1:
+            z = z.contiguous()
+    weight = weight.contiguous()
+    if bias is not None:
+        bias = bias.contiguous()
+    y, _, _ = layer_norm_fwd(
+        x,
+        weight,
+        bias,
+        eps,
+        z=z,
+        group_size=group_size,
+        norm_before_gate=norm_before_gate,
+        is_rms_norm=is_rms_norm,
+        activation=activation,
+    )
+    return y.reshape(x_shape_og)


+@input_guard
 def layernorm_fn(
     x,
     weight,
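The two gating orders in the docstring are easy to conflate, so here is a plain-PyTorch reference of the semantics it describes for the RMS-norm case ("swish" and SiLU are the same activation; the helper name is illustrative):

import torch
import torch.nn.functional as F

def gated_rmsnorm_ref(x, weight, z=None, eps=1e-6, norm_before_gate=True):
    def rmsnorm(t):
        # RMS norm: t * rsqrt(mean(t^2) + eps), scaled by weight.
        return t * torch.rsqrt(t.pow(2).mean(-1, keepdim=True) + eps) * weight

    if z is None:
        return rmsnorm(x)
    if norm_before_gate:
        return rmsnorm(x) * F.silu(z)   # norm(x) * silu(z)
    return rmsnorm(x * F.silu(z))       # norm(x * silu(z))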
@@ -312,11 +310,12 @@ def layernorm_fn(
     is_rms_norm=False,
     activation: str = "swish",
 ):
-    return LayerNormFn.apply(
+    return _layer_norm_fn_impl(
         x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm, activation
     )


+@input_guard
 def rmsnorm_fn(
     x,
     weight,
@@ -327,7 +326,7 @@ def rmsnorm_fn(
     norm_before_gate=True,
     activation: str = "swish",
 ):
-    return LayerNormFn.apply(
+    return _layer_norm_fn_impl(
         x, weight, bias, z, eps, group_size, norm_before_gate, True, activation
     )

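After the refactor both wrappers forward to the same plain function, with rmsnorm_fn pinning is_rms_norm=True, so a call like the following traces through torch.compile as ordinary Python instead of stopping at an autograd.Function boundary (shapes illustrative; argument order inferred from the forwarded call):

import torch

x = torch.randn(2, 16, 512, device="cuda")
w = torch.ones(512, device="cuda")
z = torch.randn_like(x)              # optional gate input
y = rmsnorm_fn(x, w, None, z=z)      # same as _layer_norm_fn_impl(..., is_rms_norm=True)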
vllm/model_executor/models/olmo_hybrid.py: new file, 1172 lines
(file diff suppressed because it is too large)
@@ -171,6 +171,7 @@ _TEXT_GENERATION_MODELS = {
     "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
     "Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
     "Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
+    "OlmoHybridForCausalLM": ("olmo_hybrid", "OlmoHybridForCausalLM"),
     "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),
     "OrionForCausalLM": ("orion", "OrionForCausalLM"),
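The registry maps an architecture string from the checkpoint config to a (module, class) pair under vllm/model_executor/models, imported lazily on first use. A sketch of the lookup pattern (the resolver function here is illustrative, not vLLM's actual loader code):

import importlib

def resolve_model_cls(arch: str):
    # e.g. arch = "OlmoHybridForCausalLM" -> ("olmo_hybrid", "OlmoHybridForCausalLM")
    mod_name, cls_name = _TEXT_GENERATION_MODELS[arch]
    module = importlib.import_module(f"vllm.model_executor.models.{mod_name}")
    return getattr(module, cls_name)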