[Model] Add support for OLMo Hybrid (#32550)

Author: Yanhong Li
Date: 2026-03-05 11:51:06 -08:00
Committed by: GitHub
Parent: 5395471d29
Commit: a911f4dd20

10 changed files with 1520 additions and 53 deletions

View File

@@ -76,16 +76,20 @@ def l2norm_fwd_kernel(
 @triton.jit
-def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
+def l2norm_fwd_kernel2(
+    X, Y, eps, M, N: tl.constexpr, BD: tl.constexpr, MBLOCK: tl.constexpr
+):
     xoffset = tl.program_id(0) * MBLOCK
     row_idx = xoffset + tl.arange(0, MBLOCK)[:, None]
     xmask = row_idx < M
-    rindex = tl.arange(0, N)[None, :]
-    xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32)
-    square = tl.broadcast_to(xs * xs, [MBLOCK, N])
+    rindex = tl.arange(0, BD)[None, :]
+    cmask = rindex < N
+    mask = xmask & cmask
+    xs = tl.load(X + (rindex + N * row_idx), mask, other=0.0).to(tl.float32)
+    square = tl.broadcast_to(xs * xs, [MBLOCK, BD])
     square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None]
     rsqrt = tl.rsqrt(square_sum + eps)
-    tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
+    tl.store(Y + (rindex + N * row_idx), xs * rsqrt, mask)


 def l2norm_fwd(
@@ -116,6 +120,7 @@ def l2norm_fwd(
             eps,
             T,
             D,
+            BD,
             MBLOCK,
         )
     else:
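
Note on the new BD argument: tl.arange requires a power-of-two compile-time
bound, while the actual feature size D (passed to the kernel as N) need not
be one, so the caller pads D up to BD and the kernel masks the padded
columns with cmask = rindex < N. A minimal caller-side sketch of that launch
logic, assuming the kernel above is in scope; launch_l2norm and its defaults
are illustrative, not the PR's actual helper:

import triton

def launch_l2norm(X, Y, eps, M, D, MBLOCK=32):
    # Pad the feature dimension to a power of two for tl.arange;
    # lanes past D are masked in the kernel and loaded as 0.0.
    BD = triton.next_power_of_2(D)  # e.g. D=80 -> BD=128
    grid = (triton.cdiv(M, MBLOCK),)
    l2norm_fwd_kernel2[grid](X, Y, eps, M, D, BD, MBLOCK)

Because padded lanes load as 0.0 (other=0.0), the row sum of squares stays
exact, so only the load and store need the combined row/column mask.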

View File

@@ -250,57 +250,55 @@ def layer_norm_fwd(
     return out, mean, rstd


-class LayerNormFn(torch.autograd.Function):
-    @input_guard
-    @staticmethod
-    def forward(
-        ctx,
-        x,
-        weight,
-        bias,
-        z=None,
-        eps=1e-6,
-        group_size=None,
-        norm_before_gate=True,
-        is_rms_norm=False,
-        activation: str = "swish",
-    ):
-        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
-        x_shape_og = x.shape
-        # reshape input data into 2D tensor
-        x = x.reshape(-1, x.shape[-1])
-        if x.stride(-1) != 1:
-            x = x.contiguous()
-        if z is not None:
-            assert z.shape == x_shape_og
-            z = z.reshape(-1, z.shape[-1])
-            if z.stride(-1) != 1:
-                z = z.contiguous()
-        weight = weight.contiguous()
-        if bias is not None:
-            bias = bias.contiguous()
-        y, mean, rstd = layer_norm_fwd(
-            x,
-            weight,
-            bias,
-            eps,
-            z=z,
-            group_size=group_size,
-            norm_before_gate=norm_before_gate,
-            is_rms_norm=is_rms_norm,
-            activation=activation,
-        )
-        ctx.save_for_backward(x, weight, bias, mean, rstd, z)
-        ctx.x_shape_og = x_shape_og
-        ctx.eps = eps
-        ctx.group_size = group_size
-        ctx.norm_before_gate = norm_before_gate
-        ctx.is_rms_norm = is_rms_norm
-        ctx.activation = activation
-        return y.reshape(x_shape_og)
+def _layer_norm_fn_impl(
+    x,
+    weight,
+    bias,
+    z=None,
+    eps=1e-6,
+    group_size=None,
+    norm_before_gate=True,
+    is_rms_norm=False,
+    activation: str = "swish",
+):
+    """Triton layer/RMS norm with optional gating.
+
+    If z is not None, computes norm(x) * silu(z) when norm_before_gate,
+    else norm(x * silu(z)).
+
+    This calls the triton kernel directly. The original code wrapped this
+    in a torch.autograd.Function (LayerNormFn) to save tensors for a
+    backward pass, but vLLM is inference-only so there is no backward pass.
+    The autograd wrapper also prevented torch.compile/dynamo from tracing
+    through the function due to its @staticmethod forward.
+    """
+    x_shape_og = x.shape
+    x = x.reshape(-1, x.shape[-1])
+    if x.stride(-1) != 1:
+        x = x.contiguous()
+    if z is not None:
+        assert z.shape == x_shape_og
+        z = z.reshape(-1, z.shape[-1])
+        if z.stride(-1) != 1:
+            z = z.contiguous()
+    weight = weight.contiguous()
+    if bias is not None:
+        bias = bias.contiguous()
+    y, _, _ = layer_norm_fwd(
+        x,
+        weight,
+        bias,
+        eps,
+        z=z,
+        group_size=group_size,
+        norm_before_gate=norm_before_gate,
+        is_rms_norm=is_rms_norm,
+        activation=activation,
+    )
+    return y.reshape(x_shape_og)


+@input_guard
 def layernorm_fn(
     x,
     weight,
@@ -312,11 +310,12 @@ def layernorm_fn(
     is_rms_norm=False,
     activation: str = "swish",
 ):
-    return LayerNormFn.apply(
+    return _layer_norm_fn_impl(
         x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm, activation
     )


+@input_guard
 def rmsnorm_fn(
     x,
     weight,
@@ -327,7 +326,7 @@ def rmsnorm_fn(
     norm_before_gate=True,
     activation: str = "swish",
 ):
-    return LayerNormFn.apply(
+    return _layer_norm_fn_impl(
         x, weight, bias, z, eps, group_size, norm_before_gate, True, activation
     )
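
For reference, the gating semantics described in the new docstring can be
written in plain PyTorch. This is a sketch of the math only, not the Triton
path: gated_rmsnorm_ref is a hypothetical name, and the weight/eps placement
follows the usual RMSNorm convention rather than mirroring the kernel verbatim.

import torch
import torch.nn.functional as F

def gated_rmsnorm_ref(x, weight, z=None, eps=1e-6, norm_before_gate=True):
    def rms_norm(t):
        # Scale by the reciprocal RMS over the last dim, then by the weight.
        return t * torch.rsqrt(t.pow(2).mean(-1, keepdim=True) + eps) * weight
    if z is None:
        return rms_norm(x)
    if norm_before_gate:
        return rms_norm(x) * F.silu(z)  # norm(x) * silu(z)
    return rms_norm(x * F.silu(z))      # norm(x * silu(z))

Dropping the autograd.Function wrapper is safe because inference never calls
backward, and it lets torch.compile trace straight into _layer_norm_fn_impl.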

File diff suppressed because it is too large

View File

@@ -171,6 +171,7 @@ _TEXT_GENERATION_MODELS = {
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"Olmo2ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
"Olmo3ForCausalLM": ("olmo2", "Olmo2ForCausalLM"),
"OlmoHybridForCausalLM": ("olmo_hybrid", "OlmoHybridForCausalLM"),
"OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"),