[Model] Ring 2.5 (#35102)
Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
This commit is contained in:
@@ -84,6 +84,7 @@ def layer_norm_fwd_kernel(
|
||||
HAS_Z: tl.constexpr,
|
||||
NORM_BEFORE_GATE: tl.constexpr,
|
||||
IS_RMS_NORM: tl.constexpr,
|
||||
ACTIVATION: tl.constexpr,
|
||||
):
|
||||
# Map the program id to the starting row of X and Y it should compute.
|
||||
row_start = tl.program_id(0) * ROWS_PER_BLOCK
|
||||
@@ -112,7 +113,10 @@ def layer_norm_fwd_kernel(
|
||||
if HAS_Z and not NORM_BEFORE_GATE:
|
||||
Z_base = Z + rows[:, None] * stride_z_row + col_offsets
|
||||
z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)
|
||||
x *= z * tl.sigmoid(z)
|
||||
if ACTIVATION == "swish" or ACTIVATION == "silu":
|
||||
x *= z * tl.sigmoid(z)
|
||||
elif ACTIVATION == "sigmoid":
|
||||
x *= tl.sigmoid(z)
|
||||
|
||||
# Compute mean and variance per row (reduce along axis 1)
|
||||
if not IS_RMS_NORM:
|
||||
@@ -155,7 +159,10 @@ def layer_norm_fwd_kernel(
|
||||
if HAS_Z and NORM_BEFORE_GATE:
|
||||
Z_base = Z + rows[:, None] * stride_z_row + col_offsets
|
||||
z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)
|
||||
y *= z * tl.sigmoid(z)
|
||||
if ACTIVATION == "swish" or ACTIVATION == "silu":
|
||||
y *= z * tl.sigmoid(z)
|
||||
elif ACTIVATION == "sigmoid":
|
||||
y *= tl.sigmoid(z)
|
||||
|
||||
# Write output
|
||||
tl.store(Y_base, y, mask=mask)
|
||||
@@ -178,6 +185,7 @@ def layer_norm_fwd(
|
||||
group_size: int = None,
|
||||
norm_before_gate: bool = True,
|
||||
is_rms_norm: bool = False,
|
||||
activation: str = "swish",
|
||||
):
|
||||
M, N = x.shape
|
||||
if group_size is None:
|
||||
@@ -232,9 +240,12 @@ def layer_norm_fwd(
|
||||
eps,
|
||||
BLOCK_N=BLOCK_N,
|
||||
ROWS_PER_BLOCK=rows_per_block,
|
||||
HAS_BIAS=bias is not None,
|
||||
HAS_Z=z is not None,
|
||||
NORM_BEFORE_GATE=norm_before_gate,
|
||||
IS_RMS_NORM=is_rms_norm,
|
||||
num_warps=num_warps,
|
||||
ACTIVATION=activation,
|
||||
)
|
||||
return out, mean, rstd
|
||||
|
||||
@@ -252,6 +263,7 @@ class LayerNormFn(torch.autograd.Function):
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False,
|
||||
activation: str = "swish",
|
||||
):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
|
||||
|
||||
@@ -277,6 +289,7 @@ class LayerNormFn(torch.autograd.Function):
|
||||
group_size=group_size,
|
||||
norm_before_gate=norm_before_gate,
|
||||
is_rms_norm=is_rms_norm,
|
||||
activation=activation,
|
||||
)
|
||||
ctx.save_for_backward(x, weight, bias, mean, rstd, z)
|
||||
ctx.x_shape_og = x_shape_og
|
||||
@@ -284,6 +297,7 @@ class LayerNormFn(torch.autograd.Function):
|
||||
ctx.group_size = group_size
|
||||
ctx.norm_before_gate = norm_before_gate
|
||||
ctx.is_rms_norm = is_rms_norm
|
||||
ctx.activation = activation
|
||||
return y.reshape(x_shape_og)
|
||||
|
||||
|
||||
@@ -296,17 +310,25 @@ def layernorm_fn(
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False,
|
||||
activation: str = "swish",
|
||||
):
|
||||
return LayerNormFn.apply(
|
||||
x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm
|
||||
x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm, activation
|
||||
)
|
||||
|
||||
|
||||
def rmsnorm_fn(
|
||||
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
activation: str = "swish",
|
||||
):
|
||||
return LayerNormFn.apply(
|
||||
x, weight, bias, z, eps, group_size, norm_before_gate, True
|
||||
x, weight, bias, z, eps, group_size, norm_before_gate, True, activation
|
||||
)
|
||||
|
||||
|
||||
@@ -359,6 +381,7 @@ class RMSNormGated(nn.Module):
|
||||
norm_before_gate: bool = False,
|
||||
device: torch.device | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
activation: str = "swish",
|
||||
):
|
||||
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
|
||||
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
|
||||
@@ -366,6 +389,7 @@ class RMSNormGated(nn.Module):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.activation = activation
|
||||
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
||||
self.register_parameter("bias", None)
|
||||
self.group_size = group_size
|
||||
@@ -385,4 +409,5 @@ class RMSNormGated(nn.Module):
|
||||
eps=self.eps,
|
||||
group_size=self.group_size,
|
||||
norm_before_gate=self.norm_before_gate,
|
||||
activation=self.activation,
|
||||
)
|
||||
|
||||
@@ -592,6 +592,7 @@ class RMSNormGated(CustomOp):
|
||||
eps=self.eps,
|
||||
group_size=self.group_size,
|
||||
norm_before_gate=self.norm_before_gate,
|
||||
activation=self.activation,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -43,7 +44,6 @@ class MiniMaxText01RMSNormTP(CustomOp):
|
||||
|
||||
self.weight.weight_loader = self.weight_loader
|
||||
self.variance_epsilon = eps
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def weight_loader(
|
||||
@@ -56,7 +56,6 @@ class MiniMaxText01RMSNormTP(CustomOp):
|
||||
shard_size = loaded_weight.shape[0] // tp_world
|
||||
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
|
||||
param.data.copy_(loaded_weight[shard])
|
||||
return
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
@@ -102,6 +101,101 @@ class MiniMaxText01RMSNormTP(CustomOp):
|
||||
return q, k
|
||||
|
||||
|
||||
def clear_linear_attention_cache_for_new_sequences(
|
||||
kv_cache: torch.Tensor,
|
||||
state_indices_tensor: torch.Tensor,
|
||||
attn_metadata: LinearAttentionMetadata,
|
||||
) -> None:
|
||||
num_prefills = getattr(attn_metadata, "num_prefills", 0)
|
||||
if num_prefills <= 0:
|
||||
return
|
||||
|
||||
num_decode_tokens = getattr(attn_metadata, "num_decode_tokens", 0)
|
||||
for prefill_idx in range(num_prefills):
|
||||
q_start = attn_metadata.query_start_loc[num_decode_tokens + prefill_idx]
|
||||
q_end = attn_metadata.query_start_loc[num_decode_tokens + prefill_idx + 1]
|
||||
query_len = q_end - q_start
|
||||
context_len = (
|
||||
attn_metadata.seq_lens[num_decode_tokens + prefill_idx] - query_len
|
||||
)
|
||||
if context_len == 0:
|
||||
block_to_clear = state_indices_tensor[num_decode_tokens + prefill_idx]
|
||||
kv_cache[block_to_clear, ...] = 0
|
||||
|
||||
|
||||
def linear_attention_decode(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
slope_rate: torch.Tensor,
|
||||
state_indices_tensor: torch.Tensor,
|
||||
q_start: int = 0,
|
||||
q_end: int | None = None,
|
||||
slot_start: int = 0,
|
||||
slot_end: int | None = None,
|
||||
block_size: int = 32,
|
||||
) -> torch.Tensor:
|
||||
q = q[q_start:q_end].unsqueeze(2).contiguous()
|
||||
k = k[q_start:q_end].unsqueeze(2).contiguous()
|
||||
v = v[q_start:q_end].unsqueeze(2).contiguous()
|
||||
slot_id = state_indices_tensor[slot_start:slot_end]
|
||||
return linear_decode_forward_triton(
|
||||
q, k, v, kv_cache, slope_rate, slot_id, block_size
|
||||
)
|
||||
|
||||
|
||||
def linear_attention_prefill_and_mix(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
state_indices_tensor: torch.Tensor,
|
||||
attn_metadata: LinearAttentionMetadata,
|
||||
slope_rate: torch.Tensor,
|
||||
block_size: int,
|
||||
decode_fn: Callable[..., torch.Tensor],
|
||||
prefix_fn: Callable[..., torch.Tensor],
|
||||
layer_idx: int | None = None,
|
||||
) -> torch.Tensor:
|
||||
hidden = []
|
||||
for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
|
||||
if _prefill_idx >= len(attn_metadata.query_start_loc):
|
||||
break
|
||||
if _prefill_idx >= len(state_indices_tensor):
|
||||
break
|
||||
offset = attn_metadata.num_decode_tokens
|
||||
_start = attn_metadata.query_start_loc[offset + _prefill_idx]
|
||||
_end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
|
||||
slot_id = state_indices_tensor[offset + _prefill_idx]
|
||||
qs = q[_start:_end].transpose(0, 1).contiguous()
|
||||
ks = k[_start:_end].transpose(0, 1).contiguous()
|
||||
vs = v[_start:_end].transpose(0, 1).contiguous()
|
||||
slice_layer_cache = kv_cache[slot_id, ...]
|
||||
out_slice = prefix_fn(
|
||||
qs,
|
||||
ks,
|
||||
vs,
|
||||
slice_layer_cache,
|
||||
slope_rate,
|
||||
block_size,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
hidden.append(out_slice.contiguous())
|
||||
|
||||
if attn_metadata.num_decode_tokens > 0:
|
||||
hidden_decode = decode_fn(
|
||||
q, k, v, kv_cache, state_indices_tensor, attn_metadata
|
||||
)
|
||||
hidden.insert(0, hidden_decode)
|
||||
|
||||
if not hidden:
|
||||
return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
|
||||
|
||||
hidden = torch.concat(hidden, dim=0).contiguous()
|
||||
return hidden
|
||||
|
||||
|
||||
class MiniMaxText01LinearKernel:
|
||||
@staticmethod
|
||||
def jit_linear_forward_prefix(
|
||||
@@ -258,50 +352,33 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
|
||||
def _prefill_and_mix_infer(
|
||||
self, q, k, v, kv_cache, state_indices_tensor, attn_metadata
|
||||
):
|
||||
hidden = []
|
||||
for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
|
||||
if _prefill_idx >= len(attn_metadata.query_start_loc):
|
||||
break
|
||||
if _prefill_idx >= len(state_indices_tensor):
|
||||
break
|
||||
offset = attn_metadata.num_decode_tokens
|
||||
_start = attn_metadata.query_start_loc[offset + _prefill_idx]
|
||||
_end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
|
||||
slot_id = state_indices_tensor[offset + _prefill_idx]
|
||||
qs = q[_start:_end].transpose(0, 1).contiguous()
|
||||
ks = k[_start:_end].transpose(0, 1).contiguous()
|
||||
vs = v[_start:_end].transpose(0, 1).contiguous()
|
||||
slice_layer_cache = kv_cache[slot_id, ...]
|
||||
|
||||
out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
|
||||
qs,
|
||||
ks,
|
||||
vs,
|
||||
slice_layer_cache,
|
||||
self.tp_slope,
|
||||
self.BLOCK,
|
||||
layer_idx=self.layer_idx,
|
||||
)
|
||||
hidden.append(out_slice.contiguous())
|
||||
if attn_metadata.num_decode_tokens > 0:
|
||||
hidden_decode = self._decode_infer(
|
||||
q, k, v, kv_cache, state_indices_tensor, attn_metadata
|
||||
)
|
||||
hidden.insert(0, hidden_decode)
|
||||
|
||||
if not hidden:
|
||||
return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
|
||||
|
||||
hidden = torch.concat(hidden, dim=0).contiguous()
|
||||
return hidden
|
||||
return linear_attention_prefill_and_mix(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
kv_cache=kv_cache,
|
||||
state_indices_tensor=state_indices_tensor,
|
||||
attn_metadata=attn_metadata,
|
||||
slope_rate=self.tp_slope,
|
||||
block_size=self.BLOCK,
|
||||
decode_fn=self._decode_infer,
|
||||
prefix_fn=MiniMaxText01LinearKernel.jit_linear_forward_prefix,
|
||||
layer_idx=self.layer_idx,
|
||||
)
|
||||
|
||||
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata):
|
||||
q = q[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
||||
k = k[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
||||
v = v[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
||||
slot_id = state_indices_tensor[: attn_metadata.num_decodes]
|
||||
hidden = linear_decode_forward_triton(
|
||||
q, k, v, kv_cache, self.tp_slope, slot_id, 32
|
||||
hidden = linear_attention_decode(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
kv_cache,
|
||||
self.tp_slope,
|
||||
state_indices_tensor,
|
||||
q_start=0,
|
||||
q_end=attn_metadata.num_decode_tokens,
|
||||
slot_start=0,
|
||||
slot_end=attn_metadata.num_decodes,
|
||||
block_size=32,
|
||||
)
|
||||
return hidden
|
||||
|
||||
@@ -338,27 +415,9 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
|
||||
if attn_metadata is not None:
|
||||
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
|
||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||
|
||||
num_prefills = getattr(attn_metadata, "num_prefills", 0)
|
||||
if num_prefills > 0:
|
||||
num_decode_tokens = getattr(attn_metadata, "num_decode_tokens", 0)
|
||||
for prefill_idx in range(num_prefills):
|
||||
q_start = attn_metadata.query_start_loc[
|
||||
num_decode_tokens + prefill_idx
|
||||
]
|
||||
q_end = attn_metadata.query_start_loc[
|
||||
num_decode_tokens + prefill_idx + 1
|
||||
]
|
||||
query_len = q_end - q_start
|
||||
context_len = (
|
||||
attn_metadata.seq_lens[num_decode_tokens + prefill_idx]
|
||||
- query_len
|
||||
)
|
||||
if context_len == 0:
|
||||
block_to_clear = state_indices_tensor[
|
||||
num_decode_tokens + prefill_idx
|
||||
]
|
||||
kv_cache[block_to_clear, ...] = 0
|
||||
clear_linear_attention_cache_for_new_sequences(
|
||||
kv_cache, state_indices_tensor, attn_metadata
|
||||
)
|
||||
|
||||
decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
|
||||
if attn_metadata is None:
|
||||
|
||||
1246
vllm/model_executor/models/bailing_moe_linear.py
Normal file
1246
vllm/model_executor/models/bailing_moe_linear.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -81,6 +81,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),
|
||||
"BailingMoeForCausalLM": ("bailing_moe", "BailingMoeForCausalLM"),
|
||||
"BailingMoeV2ForCausalLM": ("bailing_moe", "BailingMoeV2ForCausalLM"),
|
||||
"BailingMoeV2_5ForCausalLM": ("bailing_moe_linear", "BailingMoeV25ForCausalLM"),
|
||||
"BambaForCausalLM": ("bamba", "BambaForCausalLM"),
|
||||
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
|
||||
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
|
||||
|
||||
Reference in New Issue
Block a user