[Model] Ring 2.5 (#35102)

Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
Jiangyun Zhu
2026-02-26 18:17:11 +08:00
committed by GitHub
parent 3827c8c55a
commit ab87f85231
8 changed files with 1407 additions and 70 deletions

View File

@@ -84,6 +84,7 @@ def layer_norm_fwd_kernel(
HAS_Z: tl.constexpr,
NORM_BEFORE_GATE: tl.constexpr,
IS_RMS_NORM: tl.constexpr,
ACTIVATION: tl.constexpr,
):
# Map the program id to the starting row of X and Y it should compute.
row_start = tl.program_id(0) * ROWS_PER_BLOCK
@@ -112,7 +113,10 @@ def layer_norm_fwd_kernel(
if HAS_Z and not NORM_BEFORE_GATE:
Z_base = Z + rows[:, None] * stride_z_row + col_offsets
z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)
x *= z * tl.sigmoid(z)
if ACTIVATION == "swish" or ACTIVATION == "silu":
x *= z * tl.sigmoid(z)
elif ACTIVATION == "sigmoid":
x *= tl.sigmoid(z)
# Compute mean and variance per row (reduce along axis 1)
if not IS_RMS_NORM:
@@ -155,7 +159,10 @@ def layer_norm_fwd_kernel(
if HAS_Z and NORM_BEFORE_GATE:
Z_base = Z + rows[:, None] * stride_z_row + col_offsets
z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)
y *= z * tl.sigmoid(z)
if ACTIVATION == "swish" or ACTIVATION == "silu":
y *= z * tl.sigmoid(z)
elif ACTIVATION == "sigmoid":
y *= tl.sigmoid(z)
# Write output
tl.store(Y_base, y, mask=mask)
@@ -178,6 +185,7 @@ def layer_norm_fwd(
group_size: int = None,
norm_before_gate: bool = True,
is_rms_norm: bool = False,
activation: str = "swish",
):
M, N = x.shape
if group_size is None:
@@ -232,9 +240,12 @@ def layer_norm_fwd(
eps,
BLOCK_N=BLOCK_N,
ROWS_PER_BLOCK=rows_per_block,
HAS_BIAS=bias is not None,
HAS_Z=z is not None,
NORM_BEFORE_GATE=norm_before_gate,
IS_RMS_NORM=is_rms_norm,
num_warps=num_warps,
ACTIVATION=activation,
)
return out, mean, rstd
@@ -252,6 +263,7 @@ class LayerNormFn(torch.autograd.Function):
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
activation: str = "swish",
):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
@@ -277,6 +289,7 @@ class LayerNormFn(torch.autograd.Function):
group_size=group_size,
norm_before_gate=norm_before_gate,
is_rms_norm=is_rms_norm,
activation=activation,
)
ctx.save_for_backward(x, weight, bias, mean, rstd, z)
ctx.x_shape_og = x_shape_og
@@ -284,6 +297,7 @@ class LayerNormFn(torch.autograd.Function):
ctx.group_size = group_size
ctx.norm_before_gate = norm_before_gate
ctx.is_rms_norm = is_rms_norm
ctx.activation = activation
return y.reshape(x_shape_og)
@@ -296,17 +310,25 @@ def layernorm_fn(
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
activation: str = "swish",
):
return LayerNormFn.apply(
x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm
x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm, activation
)
def rmsnorm_fn(
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
activation: str = "swish",
):
return LayerNormFn.apply(
x, weight, bias, z, eps, group_size, norm_before_gate, True
x, weight, bias, z, eps, group_size, norm_before_gate, True, activation
)
@@ -359,6 +381,7 @@ class RMSNormGated(nn.Module):
norm_before_gate: bool = False,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
activation: str = "swish",
):
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
@@ -366,6 +389,7 @@ class RMSNormGated(nn.Module):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
self.activation = activation
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.register_parameter("bias", None)
self.group_size = group_size
@@ -385,4 +409,5 @@ class RMSNormGated(nn.Module):
eps=self.eps,
group_size=self.group_size,
norm_before_gate=self.norm_before_gate,
activation=self.activation,
)

View File

@@ -592,6 +592,7 @@ class RMSNormGated(CustomOp):
eps=self.eps,
group_size=self.group_size,
norm_before_gate=self.norm_before_gate,
activation=self.activation,
)

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Callable
import torch
import torch.nn.functional as F
@@ -43,7 +44,6 @@ class MiniMaxText01RMSNormTP(CustomOp):
self.weight.weight_loader = self.weight_loader
self.variance_epsilon = eps
return
@staticmethod
def weight_loader(
@@ -56,7 +56,6 @@ class MiniMaxText01RMSNormTP(CustomOp):
shard_size = loaded_weight.shape[0] // tp_world
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
param.data.copy_(loaded_weight[shard])
return
def _forward(
self,
@@ -102,6 +101,101 @@ class MiniMaxText01RMSNormTP(CustomOp):
return q, k
def clear_linear_attention_cache_for_new_sequences(
kv_cache: torch.Tensor,
state_indices_tensor: torch.Tensor,
attn_metadata: LinearAttentionMetadata,
) -> None:
num_prefills = getattr(attn_metadata, "num_prefills", 0)
if num_prefills <= 0:
return
num_decode_tokens = getattr(attn_metadata, "num_decode_tokens", 0)
for prefill_idx in range(num_prefills):
q_start = attn_metadata.query_start_loc[num_decode_tokens + prefill_idx]
q_end = attn_metadata.query_start_loc[num_decode_tokens + prefill_idx + 1]
query_len = q_end - q_start
context_len = (
attn_metadata.seq_lens[num_decode_tokens + prefill_idx] - query_len
)
if context_len == 0:
block_to_clear = state_indices_tensor[num_decode_tokens + prefill_idx]
kv_cache[block_to_clear, ...] = 0
def linear_attention_decode(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kv_cache: torch.Tensor,
slope_rate: torch.Tensor,
state_indices_tensor: torch.Tensor,
q_start: int = 0,
q_end: int | None = None,
slot_start: int = 0,
slot_end: int | None = None,
block_size: int = 32,
) -> torch.Tensor:
q = q[q_start:q_end].unsqueeze(2).contiguous()
k = k[q_start:q_end].unsqueeze(2).contiguous()
v = v[q_start:q_end].unsqueeze(2).contiguous()
slot_id = state_indices_tensor[slot_start:slot_end]
return linear_decode_forward_triton(
q, k, v, kv_cache, slope_rate, slot_id, block_size
)
def linear_attention_prefill_and_mix(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kv_cache: torch.Tensor,
state_indices_tensor: torch.Tensor,
attn_metadata: LinearAttentionMetadata,
slope_rate: torch.Tensor,
block_size: int,
decode_fn: Callable[..., torch.Tensor],
prefix_fn: Callable[..., torch.Tensor],
layer_idx: int | None = None,
) -> torch.Tensor:
hidden = []
for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
if _prefill_idx >= len(attn_metadata.query_start_loc):
break
if _prefill_idx >= len(state_indices_tensor):
break
offset = attn_metadata.num_decode_tokens
_start = attn_metadata.query_start_loc[offset + _prefill_idx]
_end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
slot_id = state_indices_tensor[offset + _prefill_idx]
qs = q[_start:_end].transpose(0, 1).contiguous()
ks = k[_start:_end].transpose(0, 1).contiguous()
vs = v[_start:_end].transpose(0, 1).contiguous()
slice_layer_cache = kv_cache[slot_id, ...]
out_slice = prefix_fn(
qs,
ks,
vs,
slice_layer_cache,
slope_rate,
block_size,
layer_idx=layer_idx,
)
hidden.append(out_slice.contiguous())
if attn_metadata.num_decode_tokens > 0:
hidden_decode = decode_fn(
q, k, v, kv_cache, state_indices_tensor, attn_metadata
)
hidden.insert(0, hidden_decode)
if not hidden:
return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
hidden = torch.concat(hidden, dim=0).contiguous()
return hidden
class MiniMaxText01LinearKernel:
@staticmethod
def jit_linear_forward_prefix(
@@ -258,50 +352,33 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
def _prefill_and_mix_infer(
self, q, k, v, kv_cache, state_indices_tensor, attn_metadata
):
hidden = []
for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
if _prefill_idx >= len(attn_metadata.query_start_loc):
break
if _prefill_idx >= len(state_indices_tensor):
break
offset = attn_metadata.num_decode_tokens
_start = attn_metadata.query_start_loc[offset + _prefill_idx]
_end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
slot_id = state_indices_tensor[offset + _prefill_idx]
qs = q[_start:_end].transpose(0, 1).contiguous()
ks = k[_start:_end].transpose(0, 1).contiguous()
vs = v[_start:_end].transpose(0, 1).contiguous()
slice_layer_cache = kv_cache[slot_id, ...]
out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
qs,
ks,
vs,
slice_layer_cache,
self.tp_slope,
self.BLOCK,
layer_idx=self.layer_idx,
)
hidden.append(out_slice.contiguous())
if attn_metadata.num_decode_tokens > 0:
hidden_decode = self._decode_infer(
q, k, v, kv_cache, state_indices_tensor, attn_metadata
)
hidden.insert(0, hidden_decode)
if not hidden:
return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
hidden = torch.concat(hidden, dim=0).contiguous()
return hidden
return linear_attention_prefill_and_mix(
q=q,
k=k,
v=v,
kv_cache=kv_cache,
state_indices_tensor=state_indices_tensor,
attn_metadata=attn_metadata,
slope_rate=self.tp_slope,
block_size=self.BLOCK,
decode_fn=self._decode_infer,
prefix_fn=MiniMaxText01LinearKernel.jit_linear_forward_prefix,
layer_idx=self.layer_idx,
)
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata):
q = q[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
k = k[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
v = v[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
slot_id = state_indices_tensor[: attn_metadata.num_decodes]
hidden = linear_decode_forward_triton(
q, k, v, kv_cache, self.tp_slope, slot_id, 32
hidden = linear_attention_decode(
q,
k,
v,
kv_cache,
self.tp_slope,
state_indices_tensor,
q_start=0,
q_end=attn_metadata.num_decode_tokens,
slot_start=0,
slot_end=attn_metadata.num_decodes,
block_size=32,
)
return hidden
@@ -338,27 +415,9 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
if attn_metadata is not None:
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
state_indices_tensor = attn_metadata.state_indices_tensor
num_prefills = getattr(attn_metadata, "num_prefills", 0)
if num_prefills > 0:
num_decode_tokens = getattr(attn_metadata, "num_decode_tokens", 0)
for prefill_idx in range(num_prefills):
q_start = attn_metadata.query_start_loc[
num_decode_tokens + prefill_idx
]
q_end = attn_metadata.query_start_loc[
num_decode_tokens + prefill_idx + 1
]
query_len = q_end - q_start
context_len = (
attn_metadata.seq_lens[num_decode_tokens + prefill_idx]
- query_len
)
if context_len == 0:
block_to_clear = state_indices_tensor[
num_decode_tokens + prefill_idx
]
kv_cache[block_to_clear, ...] = 0
clear_linear_attention_cache_for_new_sequences(
kv_cache, state_indices_tensor, attn_metadata
)
decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
if attn_metadata is None:

File diff suppressed because it is too large Load Diff

View File

@@ -81,6 +81,7 @@ _TEXT_GENERATION_MODELS = {
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
"BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
"BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
"BailingMoeV2_5ForCausalLM": ("bailing_moe_linear", "BailingMoeV25ForCausalLM"),
"BambaForCausalLM": ("bamba", "BambaForCausalLM"),
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),