Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
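The change below is mechanical: ruff's formatter (Black-compatible) replaces yapf's argument-aligned wrapping and isort's compact parenthesized imports with one-item-per-line, trailing-comma wrapping. A minimal sketch of the two styles, using an import and a signature that appear in the diff (illustrative only, not part of the commit):

import torch

# yapf + isort style (before):
from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)

def jit_linear_forward_prefix(q: torch.Tensor,
                              k: torch.Tensor,
                              **kwargs) -> torch.Tensor: ...

# ruff format style (after):
from vllm.distributed.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)

def jit_linear_forward_prefix(
    q: torch.Tensor,
    k: torch.Tensor,
    **kwargs,
) -> torch.Tensor: ...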
@@ -19,16 +19,21 @@ from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.lightning_attn import (
lightning_attention, linear_decode_forward_triton)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
lightning_attention,
linear_decode_forward_triton,
)
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
@@ -47,8 +52,7 @@ class MiniMaxText01RMSNormTP(CustomOp):
super().__init__()
self.tp_world = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.weight = nn.Parameter(torch.ones(int(hidden_size /
self.tp_world)))
self.weight = nn.Parameter(torch.ones(int(hidden_size / self.tp_world)))

self.weight.weight_loader = self.weight_loader
self.variance_epsilon = eps
@@ -75,8 +79,7 @@ class MiniMaxText01RMSNormTP(CustomOp):
x = x.to(torch.float32)
variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
if self.tp_world > 1:
variance = tensor_model_parallel_all_reduce(
variance) / self.tp_world
variance = tensor_model_parallel_all_reduce(variance) / self.tp_world
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype) * self.weight
return x
@@ -91,17 +94,17 @@ class MiniMaxText01RMSNormTP(CustomOp):


class MiniMaxText01LinearKernel:

@staticmethod
def jit_linear_forward_prefix(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kv_caches: torch.Tensor,
slope_rate: torch.Tensor,
block_size: int,
layer_idx: Optional[int] = None,
**kwargs) -> torch.Tensor:

def jit_linear_forward_prefix(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kv_caches: torch.Tensor,
slope_rate: torch.Tensor,
block_size: int,
layer_idx: Optional[int] = None,
**kwargs,
) -> torch.Tensor:
slope_rate = slope_rate.to(torch.float32)
should_pad_dim = q.dim() == 3
if should_pad_dim:
@@ -111,26 +114,22 @@ class MiniMaxText01LinearKernel:
b, h, n, d = q.shape
e = d
kv_history = kv_caches.reshape(1, h, d, e).contiguous()
output, kv_history = lightning_attention(q,
k,
v,
slope_rate,
block_size=block_size,
kv_history=kv_history)
output, kv_history = lightning_attention(
q, k, v, slope_rate, block_size=block_size, kv_history=kv_history
)
kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e))
assert output.shape[0] == 1, "batch size must be 1"
return rearrange(output.squeeze(0), "h n d -> n (h d)")


class MiniMaxText01LinearAttention(nn.Module, MambaBase):

@property
def mamba_type(self) -> str:
return "linear_attention"

def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.linear_attn import (
LinearAttentionBackend)
from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend

return LinearAttentionBackend

def get_state_dtype(self) -> tuple[torch.dtype]:
@@ -143,9 +142,8 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):

def get_state_shape(self) -> tuple[tuple[int, int, int], ...]:
return MambaStateShapeCalculator.linear_attention_state_shape(
num_heads=self.num_heads,
tp_size=self.tp_size,
head_dim=self.head_dim)
num_heads=self.num_heads, tp_size=self.tp_size, head_dim=self.head_dim
)

def __init__(
self,
@@ -209,16 +207,16 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
eps=1e-5,
)

slope_rate = MiniMaxText01LinearAttention._build_slope_tensor(
self.num_heads)
slope_rate = MiniMaxText01LinearAttention._build_slope_tensor(self.num_heads)
if num_hidden_layer <= 1:
self.slope_rate = slope_rate * (1 + 1e-5)
else:
self.slope_rate = slope_rate * (1 - layer_idx /
(num_hidden_layer - 1) + 1e-5)
self.tp_slope = self.slope_rate[self.tp_rank *
self.tp_heads:(self.tp_rank + 1) *
self.tp_heads].contiguous()
self.slope_rate = slope_rate * (
1 - layer_idx / (num_hidden_layer - 1) + 1e-5
)
self.tp_slope = self.slope_rate[
self.tp_rank * self.tp_heads : (self.tp_rank + 1) * self.tp_heads
].contiguous()

compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
@@ -226,36 +224,36 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
compilation_config.static_forward_context[prefix] = self

@staticmethod
def weight_direct_load(param: torch.Tensor,
loaded_weight: torch.Tensor) -> None:
def weight_direct_load(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight)
return

@staticmethod
def _build_slope_tensor(n_attention_heads: int):

def get_slopes(n):

def get_slopes_power_of_2(n):
start = 2**(-(2**-(math.log2(n) - 3)))
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]

if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2**math.floor(math.log2(n))
return (get_slopes_power_of_2(closest_power_of_2) + get_slopes(
2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return (
get_slopes_power_of_2(closest_power_of_2)
+ get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)

slopes = torch.tensor(get_slopes(n_attention_heads),
dtype=torch.float32).reshape(
n_attention_heads, 1, 1)
slopes = torch.tensor(
get_slopes(n_attention_heads), dtype=torch.float32
).reshape(n_attention_heads, 1, 1)
return slopes

def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
attn_metadata):
def _prefill_and_mix_infer(
self, q, k, v, kv_cache, state_indices_tensor, attn_metadata
):
hidden = []
for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
if _prefill_idx >= len(attn_metadata.query_start_loc):
@@ -278,12 +276,13 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
slice_layer_cache,
self.tp_slope,
self.BLOCK,
layer_idx=self.layer_idx)
layer_idx=self.layer_idx,
)
hidden.append(out_slice.contiguous())
if attn_metadata.num_decode_tokens > 0:
hidden_decode = self._decode_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
hidden_decode = self._decode_infer(
q, k, v, kv_cache, state_indices_tensor, attn_metadata
)
hidden.insert(0, hidden_decode)

if not hidden:
@@ -292,18 +291,19 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
hidden = torch.concat(hidden, dim=0).contiguous()
return hidden

def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
attn_metadata):
q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
slot_id = state_indices_tensor[:attn_metadata.num_decodes]
hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
slot_id, 32)
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata):
q = q[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
k = k[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
v = v[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
slot_id = state_indices_tensor[: attn_metadata.num_decodes]
hidden = linear_decode_forward_triton(
q, k, v, kv_cache, self.tp_slope, slot_id, 32
)
return hidden

def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
positions: torch.Tensor) -> None:
def forward(
self, hidden_states: torch.Tensor, output: torch.Tensor, positions: torch.Tensor
) -> None:
torch.ops.vllm.linear_attention(
hidden_states,
output,
@@ -311,16 +311,18 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
self.prefix,
)

def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
positions: torch.Tensor) -> None:
def _forward(
self, hidden_states: torch.Tensor, output: torch.Tensor, positions: torch.Tensor
) -> None:
forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, LinearAttentionMetadata)
num_actual_tokens = attn_metadata.num_prefill_tokens + \
attn_metadata.num_decode_tokens
num_actual_tokens = (
attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
)
else:
num_actual_tokens = hidden_states.shape[0]

@@ -335,35 +337,39 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):

num_prefills = getattr(attn_metadata, "num_prefills", 0)
if num_prefills > 0:
num_decode_tokens = getattr(attn_metadata, "num_decode_tokens",
0)
num_decode_tokens = getattr(attn_metadata, "num_decode_tokens", 0)
for prefill_idx in range(num_prefills):
q_start = attn_metadata.query_start_loc[num_decode_tokens +
prefill_idx]
q_end = attn_metadata.query_start_loc[num_decode_tokens +
prefill_idx + 1]
q_start = attn_metadata.query_start_loc[
num_decode_tokens + prefill_idx
]
q_end = attn_metadata.query_start_loc[
num_decode_tokens + prefill_idx + 1
]
query_len = q_end - q_start
context_len = attn_metadata.seq_lens[
num_decode_tokens + prefill_idx] - query_len
context_len = (
attn_metadata.seq_lens[num_decode_tokens + prefill_idx]
- query_len
)
if context_len == 0:
block_to_clear = state_indices_tensor[num_decode_tokens
+ prefill_idx]
block_to_clear = state_indices_tensor[
num_decode_tokens + prefill_idx
]
kv_cache[block_to_clear, ...] = 0

decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
if attn_metadata is None:
hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]),
device=q.device,
dtype=q.dtype)
hidden = torch.empty(
(q.shape[0], q.shape[1] * q.shape[2]), device=q.device, dtype=q.dtype
)
else:
if not decode_only:
hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
hidden = self._prefill_and_mix_infer(
q, k, v, kv_cache, state_indices_tensor, attn_metadata
)
else:
hidden = self._decode_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
hidden = self._decode_infer(
q, k, v, kv_cache, state_indices_tensor, attn_metadata
)
hidden = self.norm._forward(hidden)
gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
hidden = F.sigmoid(gate) * hidden
@@ -380,9 +386,7 @@ def linear_attention(
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self._forward(hidden_states=hidden_states,
output=output,
positions=positions)
self._forward(hidden_states=hidden_states, output=output, positions=positions)


def linear_attention_fake(

@@ -12,20 +12,30 @@ from torch.nn.parameter import Parameter

from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
causal_conv1d_fn,
causal_conv1d_update,
)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
selective_scan_fn,
selective_state_update,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
@@ -44,22 +54,24 @@ class MambaMixer(MambaBase, CustomOp):
**selective** state spaces)
"""

def __init__(self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
time_step_rank: int,
use_conv_bias: bool,
use_bias: bool,
use_rms_norm: bool,
rms_norm_has_weight: bool = True,
rms_norm_eps: float = 1e-5,
activation="silu",
is_lora_enabled: bool = False,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
prefix: str = ""):
def __init__(
self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
time_step_rank: int,
use_conv_bias: bool,
use_bias: bool,
use_rms_norm: bool,
rms_norm_has_weight: bool = True,
rms_norm_eps: float = 1e-5,
activation="silu",
is_lora_enabled: bool = False,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
):
super().__init__()
self.time_step_rank = time_step_rank
self.ssm_state_size = ssm_state_size
@@ -80,9 +92,9 @@ class MambaMixer(MambaBase, CustomOp):
# doesn't allow to override it
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

self.in_proj = MergedColumnParallelLinear(hidden_size,
[intermediate_size] * 2,
bias=use_bias)
self.in_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, bias=use_bias
)

# selective projection used to make dt, B and C input dependent
self.x_proj = RowParallelLinear(
@@ -93,17 +105,18 @@ class MambaMixer(MambaBase, CustomOp):
# time step projection (discretization) -
# In the forward we need to apply dt_proj without the bias,
# as the bias is added in the selective scan kernel.
self.dt_proj = ColumnParallelLinear(time_step_rank,
intermediate_size,
bias=True,
skip_bias_add=True)
self.dt_proj = ColumnParallelLinear(
time_step_rank, intermediate_size, bias=True, skip_bias_add=True
)

def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
param.data.copy_(
loaded_weight.data.split(loaded_weight.shape[0] // tp_size,
dim=0)[tp_rank])
loaded_weight.data.split(loaded_weight.shape[0] // tp_size, dim=0)[
tp_rank
]
)

def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
weight_loader(param, -torch.exp(loaded_weight.float()))
@@ -114,7 +127,8 @@ class MambaMixer(MambaBase, CustomOp):
intermediate_size // tp_size,
ssm_state_size,
dtype=torch.float32,
))
)
)
self.D = nn.Parameter(torch.ones(intermediate_size // tp_size))

set_weight_attrs(self.D, {"weight_loader": weight_loader})
@@ -127,23 +141,35 @@ class MambaMixer(MambaBase, CustomOp):
input_is_parallel=True,
)

self.dt_layernorm = RMSNorm(
time_step_rank,
eps=rms_norm_eps,
has_weight=rms_norm_has_weight,
) if use_rms_norm else None
self.dt_layernorm = (
RMSNorm(
time_step_rank,
eps=rms_norm_eps,
has_weight=rms_norm_has_weight,
)
if use_rms_norm
else None
)

self.b_layernorm = RMSNorm(
ssm_state_size,
eps=rms_norm_eps,
has_weight=rms_norm_has_weight,
) if use_rms_norm else None
self.b_layernorm = (
RMSNorm(
ssm_state_size,
eps=rms_norm_eps,
has_weight=rms_norm_has_weight,
)
if use_rms_norm
else None
)

self.c_layernorm = RMSNorm(
ssm_state_size,
eps=rms_norm_eps,
has_weight=rms_norm_has_weight,
) if use_rms_norm else None
self.c_layernorm = (
RMSNorm(
ssm_state_size,
eps=rms_norm_eps,
has_weight=rms_norm_has_weight,
)
if use_rms_norm
else None
)

compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
@@ -157,7 +183,7 @@ class MambaMixer(MambaBase, CustomOp):
self.prefix = prefix

def _ssm_transform(
self, x: torch.Tensor
self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if self.is_lora_enabled:
# Lora kernel requires contiguous tensor.
@@ -167,7 +193,8 @@ class MambaMixer(MambaBase, CustomOp):
time_step, B, C = torch.split(
ssm_params,
[self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
dim=-1)
dim=-1,
)
if self.use_rms_norm:
assert self.dt_layernorm is not None
assert self.b_layernorm is not None
@@ -185,8 +212,7 @@ class MambaMixer(MambaBase, CustomOp):
self.prefix,
)

def forward_native(self, hidden_states: torch.Tensor,
output: torch.Tensor):
def forward_native(self, hidden_states: torch.Tensor, output: torch.Tensor):
pass

def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor):
@@ -232,8 +258,9 @@ class MambaMixer(MambaBase, CustomOp):
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
hidden_states_BC, gate = projected_states.chunk(2, dim=-2)

conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
conv_weights = self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
)

if attn_metadata is None:
# V1 profile run
@@ -281,10 +308,12 @@ class MambaMixer(MambaBase, CustomOp):
conv_states=conv_state,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
query_start_loc=query_start_loc_p)
query_start_loc=query_start_loc_p,
)
# 3. State Space Model sequence transformations.
discrete_time_step_p, B_p, C_p = self._ssm_transform(
conv_out_p.transpose(-2, -1))
conv_out_p.transpose(-2, -1)
)
time_proj_bias = self._time_proj_bias()

# 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
@@ -301,7 +330,8 @@ class MambaMixer(MambaBase, CustomOp):
delta_softplus=True,
cache_indices=state_indices_tensor_p,
has_initial_state=has_initial_states_p,
query_start_loc=query_start_loc_p)
query_start_loc=query_start_loc_p,
)
ssm_outputs.append(scan_out_p)

if has_decode:
@@ -312,39 +342,42 @@ class MambaMixer(MambaBase, CustomOp):
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=state_indices_tensor_d).transpose(0, 1)
conv_state_indices=state_indices_tensor_d,
).transpose(0, 1)

# 3. State Space Model sequence transformation.
discrete_time_step_d, B_d, C_d = self._ssm_transform(
conv_out_d.transpose(-2, -1))
conv_out_d.transpose(-2, -1)
)
time_proj_bias = self._time_proj_bias()

# 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
scan_outputs_d = torch.empty_like(
hidden_states_BC_d.transpose(0, 1))
selective_state_update(ssm_state,
conv_out_d.transpose(0, 1),
discrete_time_step_d.transpose(0, 1),
self.A,
B_d,
C_d,
self.D,
gate_d.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
state_batch_indices=state_indices_tensor_d,
out=scan_outputs_d)
scan_outputs_d = torch.empty_like(hidden_states_BC_d.transpose(0, 1))
selective_state_update(
ssm_state,
conv_out_d.transpose(0, 1),
discrete_time_step_d.transpose(0, 1),
self.A,
B_d,
C_d,
self.D,
gate_d.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
state_batch_indices=state_indices_tensor_d,
out=scan_outputs_d,
)
scan_outputs_d = scan_outputs_d.transpose(0, 1)

ssm_outputs.insert(0, scan_outputs_d)

scan_outputs_combined = ssm_outputs[0] if len(
ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1)
scan_outputs_combined = (
ssm_outputs[0] if len(ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1)
)

# 5. Final output projection
if self.is_lora_enabled: # Lora kernel requires contiguous tensor.
scan_outputs_combined = scan_outputs_combined.transpose(
-2, -1).contiguous()
scan_outputs_combined = scan_outputs_combined.transpose(-2, -1).contiguous()
out = self.out_proj(scan_outputs_combined)[0]
else:
out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0]
@@ -373,8 +406,8 @@ class MambaMixer(MambaBase, CustomOp):
return "mamba1"

def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.mamba1_attn import (
Mamba1AttentionBackend)
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend

return Mamba1AttentionBackend

def _time_proj_bias(self) -> Optional[torch.Tensor]:
@@ -406,27 +439,34 @@ def split_batch_to_prefill_and_decode(
num_decodes: int,
num_padded_decodes: int,
) -> PrefillDecodeSplit:

num_actual_tokens = num_prefill_tokens + num_padded_decodes

# In v1, decode tokens come first, then prefill tokens.
hidden_states_BC_d, hidden_states_BC_p = torch.split(
hidden_states_BC[..., :num_actual_tokens],
[num_padded_decodes, num_prefill_tokens],
dim=-1)
gate_d, gate_p = torch.split(gate[..., :num_actual_tokens],
[num_padded_decodes, num_prefill_tokens],
dim=-1)
dim=-1,
)
gate_d, gate_p = torch.split(
gate[..., :num_actual_tokens], [num_padded_decodes, num_prefill_tokens], dim=-1
)

# num_padded_decodes accounts for CUDA graph padding when applicable
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor[:num_padded_decodes + num_prefills],
state_indices_tensor[: num_padded_decodes + num_prefills],
[num_padded_decodes, num_prefills],
dim=0)
query_start_loc_p = (query_start_loc[-num_prefills - 1:] -
num_padded_decodes if num_prefills > 0 else None)
has_initial_states_p = has_initial_states[-num_prefills:] if (
has_initial_states is not None and num_prefills > 0) else None
dim=0,
)
query_start_loc_p = (
query_start_loc[-num_prefills - 1 :] - num_padded_decodes
if num_prefills > 0
else None
)
has_initial_states_p = (
has_initial_states[-num_prefills:]
if (has_initial_states is not None and num_prefills > 0)
else None
)

return PrefillDecodeSplit(
hidden_states_BC_p=hidden_states_BC_p,

@@ -11,28 +11,40 @@ from torch import nn

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.distributed import (
divide,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
causal_conv1d_fn,
causal_conv1d_update,
)
from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_state_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_state_update
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined_varlen)
mamba_chunk_scan_combined_varlen,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import (
LoaderFunction, composed_weight_loader, sharded_weight_loader)
LoaderFunction,
composed_weight_loader,
sharded_weight_loader,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
@@ -43,12 +55,13 @@ from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
# Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated
@CustomOp.register("mixer2_gated_rms_norm")
class Mixer2RMSNormGated(CustomOp):

def __init__(self,
full_hidden_size: int,
full_n_groups: int,
use_rms_norm: bool = True,
eps: float = 1e-6):
def __init__(
self,
full_hidden_size: int,
full_n_groups: int,
use_rms_norm: bool = True,
eps: float = 1e-6,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
@@ -62,13 +75,13 @@ class Mixer2RMSNormGated(CustomOp):
if self.use_rms_norm:
# Register norm weight only if we're actually applying RMSNorm
self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
set_weight_attrs(self.weight,
{"weight_loader": sharded_weight_loader(0)})
set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)})
else:
# Avoid checkpoint mismatch by skipping unused parameter
self.register_parameter("weight", None)
assert (self.full_hidden_size % self.tp_size == 0
), "Tensor parallel world size must divide hidden size."
assert self.full_hidden_size % self.tp_size == 0, (
"Tensor parallel world size must divide hidden size."
)

def forward_native(
self,
@@ -111,8 +124,7 @@ class Mixer2RMSNormGated(CustomOp):
group_count = hidden_dim // self.group_size
x_grouped = x.view(*prefix_dims, group_count, self.group_size)
variance = x_grouped.pow(2).mean(-1, keepdim=True)
x_grouped = x_grouped * torch.rsqrt(variance +
self.variance_epsilon)
x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon)
x = x_grouped.view(*prefix_dims, hidden_dim)

if redundant_tp:
@@ -130,18 +142,19 @@ class Mixer2RMSNormGated(CustomOp):
input_dtype = x.dtype
if not self.use_rms_norm:
# Keep gate in float32 for numerical stability during silu
return x * nn.functional.silu(gate.to(
torch.float32)).to(input_dtype)
return x * nn.functional.silu(gate.to(torch.float32)).to(input_dtype)

if (((self.n_groups % self.tp_size) != 0) or self.n_groups != 1):
if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1:
return self.forward_native(x, gate)

return rms_norm_gated(x,
self.weight.data,
bias=None,
z=gate,
eps=self.variance_epsilon,
norm_before_gate=False)
return rms_norm_gated(
x,
self.weight.data,
bias=None,
z=gate,
eps=self.variance_epsilon,
norm_before_gate=False,
)


def mamba_v2_sharded_weight_loader(
@@ -156,7 +169,6 @@ def mamba_v2_sharded_weight_loader(
"""

def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:

# - track boundary of (sharded) param, and loaded_weight, respectively
boundary, loaded_boundary = 0, 0

@@ -191,11 +203,12 @@ def mamba_v2_sharded_weight_loader(
# seem to handle slices well.
# https://github.com/python/mypy/issues/2410
param.data[
boundary:(boundary + take),
... # type: ignore[misc]
] = loaded_weight[loaded_start_idx:(loaded_start_idx +
take) # type: ignore[misc]
] # type: ignore[misc]
boundary : (boundary + take), ... # type: ignore[misc]
] = loaded_weight[
loaded_start_idx : (
loaded_start_idx + take
) # type: ignore[misc]
] # type: ignore[misc]

# move indexing boundaries
boundary += shard_size
@@ -217,23 +230,25 @@ class MambaMixer2(MambaBase, CustomOp):
**selective** state spaces)
"""

def __init__(self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation: str = "silu",
use_rms_norm: bool = True,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
def __init__(
self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation: str = "silu",
use_rms_norm: bool = True,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()

# For TP, the sharding plan is as follows:
@@ -253,15 +268,18 @@ class MambaMixer2(MambaBase, CustomOp):
self.tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()

assert (num_heads % self.tp_size == 0
), "Tensor parallel world size must divide num heads."
assert num_heads % self.tp_size == 0, (
"Tensor parallel world size must divide num heads."
)

assert (n_groups % self.tp_size) == 0 or n_groups == 1, (
"If tensor parallel world size does not divide num_groups, "
"then num_groups must equal 1.")
"then num_groups must equal 1."
)

assert (n_groups % self.tp_size == 0) or self.tp_size == 1 or \
quant_config is None, (
assert (
(n_groups % self.tp_size == 0) or self.tp_size == 1 or quant_config is None
), (
"Tensor parallel currently supported for quantized models only "
"if tensor parallel world size divides num groups."
)
@@ -280,7 +298,8 @@ class MambaMixer2(MambaBase, CustomOp):
# - but if n_groups cannot divide tp_size, we need to
# extend some extra groups
groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
n_groups, self.tp_size)
n_groups, self.tp_size
)
self.n_groups = n_groups + groups

self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
@@ -340,8 +359,7 @@ class MambaMixer2(MambaBase, CustomOp):
# to the head shards
group_shard_settings = (
self.groups_ssm_state_size, # expected model size
(self.n_groups - n_groups) *
self.ssm_state_size, # extra dims assigned
(self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned
n_groups == 1, # if there was only one group
)
intermediate_settings = (intermediate_size, 0, False)
@@ -355,8 +373,7 @@ class MambaMixer2(MambaBase, CustomOp):
set_weight_attrs(
self.conv1d.bias,
{
"weight_loader":
mamba_v2_sharded_weight_loader(
"weight_loader": mamba_v2_sharded_weight_loader(
[
intermediate_settings,
group_shard_settings,
@@ -372,8 +389,7 @@ class MambaMixer2(MambaBase, CustomOp):
set_weight_attrs(
self.conv1d.weight,
{
"weight_loader":
mamba_v2_sharded_weight_loader(
"weight_loader": mamba_v2_sharded_weight_loader(
[
intermediate_settings,
group_shard_settings,
@@ -391,8 +407,7 @@ class MambaMixer2(MambaBase, CustomOp):
set_weight_attrs(
self.in_proj.weight,
{
"weight_loader":
mamba_v2_sharded_weight_loader(
"weight_loader": mamba_v2_sharded_weight_loader(
[
intermediate_settings, # for gate
intermediate_settings,
@@ -418,17 +433,18 @@ class MambaMixer2(MambaBase, CustomOp):
torch.empty(
divide(num_heads, self.tp_size),
dtype=torch.float32,
))
)
)
self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))
self.use_rms_norm = use_rms_norm

set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
a_weight_loader = composed_weight_loader(
sharded_weight_loader(0), lambda x: -torch.exp(x.float()))
sharded_weight_loader(0), lambda x: -torch.exp(x.float())
)
set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
set_weight_attrs(self.dt_bias,
{"weight_loader": sharded_weight_loader(0)})
set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})

self.out_proj = RowParallelLinear(
intermediate_size,
@@ -439,10 +455,9 @@ class MambaMixer2(MambaBase, CustomOp):
prefix=f"{prefix}.out_proj",
)

self.norm = Mixer2RMSNormGated(intermediate_size,
n_groups,
self.use_rms_norm,
eps=rms_norm_eps)
self.norm = Mixer2RMSNormGated(
intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps
)

compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
@@ -525,8 +540,9 @@ class MambaMixer2(MambaBase, CustomOp):
dim=-1,
)

conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
conv_weights = self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
)

# - get hidden_states, B and C after depthwise convolution.
split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
@@ -541,10 +557,10 @@ class MambaMixer2(MambaBase, CustomOp):

if attn_metadata is None:
# profile run
hidden_states_B_C = (hidden_states_B_C.transpose(
0, 1).clone().transpose(0, 1)).contiguous()
hidden_states, _B, _C = split_hidden_states_B_C_fn(
hidden_states_B_C)
hidden_states_B_C = (
hidden_states_B_C.transpose(0, 1).clone().transpose(0, 1)
).contiguous()
hidden_states, _B, _C = split_hidden_states_B_C_fn(hidden_states_B_C)
hidden_states = self.norm(hidden_states, gate)
out, _ = self.out_proj(hidden_states)
return out
@@ -580,11 +596,11 @@ class MambaMixer2(MambaBase, CustomOp):
# If prefix caching is enabled, retrieve the relevant variables
# for prefill and decode
last_state_idx_d, last_state_idx_p = torch.split(
attn_metadata.last_state_idx, [num_decodes, num_prefills],
dim=0)
attn_metadata.last_state_idx, [num_decodes, num_prefills], dim=0
)
current_last_idx_d, current_last_idx_p = torch.split(
attn_metadata.current_last_idx, [num_decodes, num_prefills],
dim=0)
attn_metadata.current_last_idx, [num_decodes, num_prefills], dim=0
)
# Prefill-only variables:
current_first_idx_p = attn_metadata.current_first_idx_p
context_lens_p = attn_metadata.context_lens_p
@@ -600,7 +616,7 @@ class MambaMixer2(MambaBase, CustomOp):
preallocated_ssm_out = torch.empty(
[
num_prefill_tokens + num_decodes,
(self.num_heads // self.tp_size) * self.head_dim
(self.num_heads // self.tp_size) * self.head_dim,
],
dtype=hidden_states.dtype,
device=hidden_states.device,
@@ -626,7 +642,8 @@ class MambaMixer2(MambaBase, CustomOp):
# "state_indices_tensor_p"), it will write additional cache
# states aligned at "block_size_to_align".
x = hidden_states_B_C_p.transpose(
0, 1) # this is the form that causal-conv see
0, 1
) # this is the form that causal-conv see
hidden_states_B_C_p = causal_conv1d_fn(
x,
conv_weights,
@@ -641,34 +658,34 @@ class MambaMixer2(MambaBase, CustomOp):
context_lens=context_lens_p,
block_size_to_align=mamba_block_size,
metadata=attn_metadata,
query_start_loc=query_start_loc_p).transpose(
0, 1)[:num_prefill_tokens]
query_start_loc=query_start_loc_p,
).transpose(0, 1)[:num_prefill_tokens]

hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(
hidden_states_B_C_p)
hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(hidden_states_B_C_p)

# 3. State Space Model sequence transformation
initial_states = None
if (has_initial_states_p is not None and prep_initial_states):
if has_initial_states_p is not None and prep_initial_states:
kernel_ssm_indices = state_indices_tensor_p
if prefix_caching_enabled:
kernel_ssm_indices = state_indices_tensor_p.gather(
1, last_state_idx_p.unsqueeze(1)).squeeze(1)
1, last_state_idx_p.unsqueeze(1)
).squeeze(1)
initial_states = torch.where(
has_initial_states_p[:, None, None, None],
ssm_state[kernel_ssm_indices], 0)
ssm_state[kernel_ssm_indices],
0,
)

# NOTE: final output is an in-place update of out tensor
varlen_states = mamba_chunk_scan_combined_varlen(
hidden_states_p.view(num_prefill_tokens,
self.num_heads // self.tp_size,
self.head_dim),
hidden_states_p.view(
num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim
),
dt_p,
self.A,
B_p.view(num_prefill_tokens, self.n_groups // self.tp_size,
-1),
C_p.view(num_prefill_tokens, self.n_groups // self.tp_size,
-1),
B_p.view(num_prefill_tokens, self.n_groups // self.tp_size, -1),
C_p.view(num_prefill_tokens, self.n_groups // self.tp_size, -1),
chunk_size=chunk_size,
D=self.D,
z=None,
@@ -681,18 +698,19 @@ class MambaMixer2(MambaBase, CustomOp):
return_intermediate_states=prefix_caching_enabled,
dt_softplus=True,
dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out_p.view(num_prefill_tokens, -1,
self.head_dim),
state_dtype=ssm_state.dtype)
out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, self.head_dim),
state_dtype=ssm_state.dtype,
)

if prefix_caching_enabled:
# Save states for sequences with more than just the final state:
n_blocks_to_fill = current_last_idx_p - current_first_idx_p
for seq_idx in (n_blocks_to_fill > 0).nonzero().squeeze(1):
cache_blocks_to_fill = state_indices_tensor_p[
seq_idx, current_first_idx_p[seq_idx]:
current_first_idx_p[seq_idx] +
n_blocks_to_fill[seq_idx]]
seq_idx,
current_first_idx_p[seq_idx] : current_first_idx_p[seq_idx]
+ n_blocks_to_fill[seq_idx],
]
# chunks = [0 1 2 3 4 5 6 ...]
# First aligned chunk would typically be:
# mamba_block_size = 1024, chunk_size = 256
@@ -704,22 +722,33 @@ class MambaMixer2(MambaBase, CustomOp):
# e.g. 256 // 256 -> 1 completed --> store chunk[2] (skip 2)
# e.g. 10 // 256 -> 0 completed --> store chunk[3] (skip 3)
chunk_stride = mamba_block_size // chunk_size
first_aligned_chunk = \
torch.concat([torch.zeros(1, \
dtype=last_chunk_indices_p.dtype, \
device=last_chunk_indices_p.device), \
last_chunk_indices_p + 1])[seq_idx] \
+ chunk_stride - 1 \
- last_computed_offset_p[seq_idx] // chunk_size
first_aligned_chunk = (
torch.concat(
[
torch.zeros(
1,
dtype=last_chunk_indices_p.dtype,
device=last_chunk_indices_p.device,
),
last_chunk_indices_p + 1,
]
)[seq_idx]
+ chunk_stride
- 1
- last_computed_offset_p[seq_idx] // chunk_size
)
from_where = varlen_states[
first_aligned_chunk:first_aligned_chunk +
n_blocks_to_fill[seq_idx] * chunk_stride:chunk_stride]
first_aligned_chunk : first_aligned_chunk
+ n_blocks_to_fill[seq_idx] * chunk_stride : chunk_stride
]
ssm_state[cache_blocks_to_fill] = from_where

#For all seqs, store the last state (Note: might be partial):
ssm_state[state_indices_tensor_p.gather(1,
current_last_idx_p.unsqueeze(1)).squeeze(1)] = \
varlen_states[last_chunk_indices_p]
# For all seqs, store the last state (Note: might be partial):
ssm_state[
state_indices_tensor_p.gather(
1, current_last_idx_p.unsqueeze(1)
).squeeze(1)
] = varlen_states[last_chunk_indices_p]
else:
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate)
@@ -729,13 +758,13 @@ class MambaMixer2(MambaBase, CustomOp):
# Process decode requests
if has_decode:
if prefix_caching_enabled:
state_indices_tensor_d_input = \
state_indices_tensor_d.gather(1,
last_state_idx_d.unsqueeze(1)).squeeze(1)
state_indices_tensor_d_output = \
state_indices_tensor_d.gather(1,
current_last_idx_d.unsqueeze(1)).squeeze(1)
#Note:
state_indices_tensor_d_input = state_indices_tensor_d.gather(
1, last_state_idx_d.unsqueeze(1)
).squeeze(1)
state_indices_tensor_d_output = state_indices_tensor_d.gather(
1, current_last_idx_d.unsqueeze(1)
).squeeze(1)
# Note:
# for decode always: current_first_idx_d == current_last_idx_d
# at block boundaries: current_first_idx_d > last_state_idx_d
else:
@@ -755,20 +784,23 @@ class MambaMixer2(MambaBase, CustomOp):
initial_state_idx=last_state_idx_d,
)

hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(
hidden_states_B_C_d)
hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d)

# 3. State Space Model sequence transformation
n_groups = self.n_groups // self.tp_size
A_d = self.A[:, None, ...][:, :, None].expand(
-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
A_d = (
self.A[:, None, ...][:, :, None]
.expand(-1, self.head_dim, self.ssm_state_size)
.to(dtype=torch.float32)
)
dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim)
dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
D_d = self.D[:, None, ...].expand(-1, self.head_dim)
B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups)
C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups)
hidden_states_d = hidden_states_d.view(
-1, self.num_heads // self.tp_size, self.head_dim)
-1, self.num_heads // self.tp_size, self.head_dim
)

# - the hidden is reshaped into (bs, num_heads, head_dim)
# - mamba_cache_params.ssm_state's slots will be selected
@@ -787,16 +819,14 @@ class MambaMixer2(MambaBase, CustomOp):
dt_softplus=True,
state_batch_indices=state_indices_tensor_d_input,
dst_state_batch_indices=state_indices_tensor_d_output,
out=preallocated_ssm_out_d.view(num_decodes, -1,
self.head_dim),
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
)

# 4. gated MLP
# GatedRMSNorm internally applying SiLU to the gate
# SiLU is applied internally before normalization, unlike standard
# norm usage
hidden_states = self.norm(preallocated_ssm_out,
gate[:num_actual_tokens])
hidden_states = self.norm(preallocated_ssm_out, gate[:num_actual_tokens])

# 5. Final linear projection
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
@@ -826,8 +856,8 @@ class MambaMixer2(MambaBase, CustomOp):
return "mamba2"

def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.mamba2_attn import (
Mamba2AttentionBackend)
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend

return Mamba2AttentionBackend


@@ -839,9 +869,7 @@ def mamba_mixer2(
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self.forward_cuda(hidden_states=hidden_states,
output=output,
mup_vector=mup_vector)
self.forward_cuda(hidden_states=hidden_states, output=output, mup_vector=mup_vector)


def mamba_mixer2_fake(

@@ -10,7 +10,6 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_kv_cache_torch_dtype


class MambaStateDtypeCalculator:

@classmethod
def linear_attention_state_dtype(
cls,
@@ -21,7 +20,7 @@ class MambaStateDtypeCalculator:
if mamba_cache_dtype == "float32":
raise ValueError("fp32 state for minimax is not yet supported")
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
return (state_dtype, )
return (state_dtype,)

@classmethod
def mamba1_state_dtype(
@@ -30,8 +29,9 @@ class MambaStateDtypeCalculator:
mamba_cache_dtype: MambaDType,
mamba_ssm_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
return cls._mamba_state_dtype(model_dtype, mamba_cache_dtype,
mamba_ssm_cache_dtype)
return cls._mamba_state_dtype(
model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype
)

@classmethod
def mamba2_state_dtype(
@@ -40,8 +40,9 @@ class MambaStateDtypeCalculator:
mamba_cache_dtype: MambaDType,
mamba_ssm_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
return cls._mamba_state_dtype(model_dtype, mamba_cache_dtype,
mamba_ssm_cache_dtype)
return cls._mamba_state_dtype(
model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype
)

@classmethod
def _mamba_state_dtype(
@@ -50,13 +51,11 @@ class MambaStateDtypeCalculator:
mamba_cache_dtype: MambaDType,
mamba_ssm_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype,
model_dtype)
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
if mamba_ssm_cache_dtype == "auto":
temporal_state_dtype = conv_state_dtype
else:
temporal_state_dtype = (
STR_DTYPE_TO_TORCH_DTYPE[mamba_ssm_cache_dtype])
temporal_state_dtype = STR_DTYPE_TO_TORCH_DTYPE[mamba_ssm_cache_dtype]

return (conv_state_dtype, temporal_state_dtype)

@@ -66,9 +65,8 @@ class MambaStateDtypeCalculator:
model_dtype: Union[ModelDType, torch.dtype],
mamba_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype,
model_dtype)
return (conv_state_dtype, )
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
return (conv_state_dtype,)

@classmethod
def gated_delta_net_state_dtype(
@@ -81,7 +79,6 @@ class MambaStateDtypeCalculator:


class MambaStateShapeCalculator:

@classmethod
def linear_attention_state_shape(
cls,
@@ -89,9 +86,8 @@ class MambaStateShapeCalculator:
tp_size: int,
head_dim: int,
) -> tuple[tuple[int, int, int], ...]:

state_shape = (num_heads // tp_size, head_dim, head_dim)
return (state_shape, )
return (state_shape,)

@classmethod
def mamba1_state_shape(
@@ -101,11 +97,9 @@ class MambaStateShapeCalculator:
state_size: int,
conv_kernel: int,
) -> tuple[tuple[int, int], tuple[int, int]]:
conv_state_shape = (divide(intermediate_size,
tp_world_size), conv_kernel - 1)
conv_state_shape = (divide(intermediate_size, tp_world_size), conv_kernel - 1)

temporal_state_shape = (divide(intermediate_size,
tp_world_size), state_size)
temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size)

conv_state_shape = conv_state_shape[1], conv_state_shape[0]

@@ -124,8 +118,7 @@ class MambaStateShapeCalculator:
) -> tuple[tuple[int, int], tuple[int, int, int]]:
# if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head is sharded along with it
n_groups = n_groups + cls.extra_groups_for_head_shards(
n_groups, tp_world_size)
n_groups = n_groups + cls.extra_groups_for_head_shards(n_groups, tp_world_size)
# heads and n_groups are TP-ed
conv_dim = intermediate_size + 2 * n_groups * state_size

@@ -135,8 +128,7 @@ class MambaStateShapeCalculator:
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
# e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
temporal_state_shape = (divide(num_heads,
tp_world_size), head_dim, state_size)
temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size)
return conv_state_shape, temporal_state_shape

@classmethod
@@ -148,7 +140,7 @@ class MambaStateShapeCalculator:
) -> tuple[tuple[int, int]]:
conv_dim = divide(intermediate_size, tp_world_size)
conv_state_shape = (conv_kernel - 1, conv_dim)
return (conv_state_shape, )
return (conv_state_shape,)

@classmethod
def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
@@ -173,7 +165,7 @@ class MambaStateShapeCalculator:
conv_kernel_size: int,
num_spec: int = 0,
):
conv_dim = (head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads)
conv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads
conv_state_shape = (
divide(conv_dim, tp_world_size),
conv_kernel_size - 1 + num_spec,
@@ -181,6 +173,9 @@ class MambaStateShapeCalculator:

conv_state_shape = conv_state_shape[1], conv_state_shape[0]

temporal_state_shape = (divide(num_v_heads,
tp_world_size), head_k_dim, head_v_dim)
temporal_state_shape = (
divide(num_v_heads, tp_world_size),
head_k_dim,
head_v_dim,
)
return conv_state_shape, temporal_state_shape

@@ -38,8 +38,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching
num_cache_lines: tl.constexpr,  # added to support vLLM larger cache lines
# Strides
stride_x_dim: tl.constexpr,  # stride to get to next feature-value,
stride_x_token: tl.
constexpr,  # stride to get to next token (same feature-index, same sequence-index)
stride_x_token: tl.constexpr,  # stride to get to next token (same feature-index, same sequence-index)
stride_w_dim: tl.constexpr,  # stride to get to next dim-axis value
stride_w_width: tl.constexpr,  # stride to get to next width-axis value
stride_istate_seq: tl.constexpr,
@@ -66,7 +65,9 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
stride_conv_state_seq = stride_istate_seq
|
||||
stride_conv_state_dim = stride_istate_dim
|
||||
stride_conv_state_tok = stride_istate_token
|
||||
state_len = KERNEL_WIDTH - 1 # can be passed via argument if it's not the same as this value
|
||||
state_len = (
|
||||
KERNEL_WIDTH - 1
|
||||
) # can be passed via argument if it's not the same as this value
|
||||
|
||||
# one program handles one chunk in a single sequence
|
||||
# rather than mixing sequences - to make updating initial_states across sequences efficiently
|
||||
@@ -86,7 +87,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
# find the actual sequence length
|
||||
seqlen = sequence_end_index - sequence_start_index
|
||||
|
||||
B_size: tl.constexpr = (stride_block_m * BLOCK_M)
|
||||
B_size: tl.constexpr = stride_block_m * BLOCK_M
|
||||
|
||||
if IS_APC_ENABLED:
|
||||
# Handle the case if prefix caching is enabled.
|
||||
@@ -124,20 +125,24 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
segment_len = min(BLOCK_M, seqlen - token_offset)
|
||||
|
||||
# base of the sequence
|
||||
x_base = x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim # [BLOCK_N,]
|
||||
x_base = (
|
||||
x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim
|
||||
) # [BLOCK_N,]
|
||||
|
||||
# cache_idx
|
||||
conv_states_input_coord = tl.load(conv_state_indices_ptr +
|
||||
idx_seq * stride_cache_indices +
|
||||
conv_state_init_index).to(tl.int64)
|
||||
conv_states_input_coord = tl.load(
|
||||
conv_state_indices_ptr + idx_seq * stride_cache_indices + conv_state_init_index
|
||||
).to(tl.int64)
|
||||
|
||||
if USE_PAD_SLOT: # noqa
|
||||
if conv_states_input_coord == pad_slot_id:
|
||||
# not processing as this is not the actual sequence
|
||||
return
|
||||
conv_states_base = (conv_states_ptr +
|
||||
(conv_states_input_coord * stride_conv_state_seq) +
|
||||
(idx_feats * stride_conv_state_dim)) # [BLOCK_N,]
|
||||
conv_states_base = (
|
||||
conv_states_ptr
|
||||
+ (conv_states_input_coord * stride_conv_state_seq)
|
||||
+ (idx_feats * stride_conv_state_dim)
|
||||
) # [BLOCK_N,]
|
||||
|
||||
w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,]
|
||||
|
||||
@@ -149,8 +154,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(tl.int1)
|
||||
if load_init_state:
|
||||
# load from conv_states
|
||||
prior_tokens = conv_states_base + (state_len -
|
||||
1) * stride_conv_state_tok
|
||||
prior_tokens = conv_states_base + (state_len - 1) * stride_conv_state_tok
|
||||
mask_w = idx_feats < dim
|
||||
if KERNEL_WIDTH == 2:
|
||||
conv_states_ptrs = prior_tokens # [BLOCK_N]
|
||||
@@ -180,46 +184,54 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
# prior-tokens are zeros
|
||||
if KERNEL_WIDTH >= 2: # STRATEGY1
|
||||
# first chunk and does not have prior-token, so just set to 0
|
||||
col0 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty)
|
||||
col0 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)
|
||||
if KERNEL_WIDTH >= 3: # STRATEGY1
|
||||
col1 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty)
|
||||
col1 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)
|
||||
if KERNEL_WIDTH >= 4: # STRATEGY1
|
||||
col2 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty)
|
||||
col2 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)
|
||||
if KERNEL_WIDTH >= 5: # STRATEGY1
|
||||
col3 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty)
|
||||
col3 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)
|
||||
|
||||
# STEP 2:
|
||||
# here prepare data for updating conv_state
|
||||
if state_len <= seqlen: # SMALL_CACHE=True (only move part of 'x' into conv_state cache)
|
||||
if (
|
||||
state_len <= seqlen
|
||||
): # SMALL_CACHE=True (only move part of 'x' into conv_state cache)
|
||||
# just read from 'x'
|
||||
# copy 'x' data to conv_state
|
||||
# load only 'x' data (and set 0 before 'x' if seqlen < state_len)
|
||||
idx_tokens_last = (seqlen - state_len) + tl.arange(
|
||||
0, NP2_STATELEN) # [BLOCK_M]
|
||||
x_ptrs = x_ptr + (
|
||||
(sequence_start_index + idx_tokens_last) *
|
||||
stride_x_token)[:, None] + (
|
||||
idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,]
|
||||
mask_x = ((idx_tokens_last >= 0)[:, None] &
|
||||
(idx_tokens_last < seqlen)[:, None] &
|
||||
(idx_feats < dim)[None, :]
|
||||
) # token-index # token-index # feature-index
|
||||
0, NP2_STATELEN
|
||||
) # [BLOCK_M]
|
||||
x_ptrs = (
|
||||
x_ptr
|
||||
+ ((sequence_start_index + idx_tokens_last) * stride_x_token)[:, None]
|
||||
+ (idx_feats * stride_x_dim)[None, :]
|
||||
) # [BLOCK_M,BLOCK_N,]
|
||||
mask_x = (
|
||||
(idx_tokens_last >= 0)[:, None]
|
||||
& (idx_tokens_last < seqlen)[:, None]
|
||||
& (idx_feats < dim)[None, :]
|
||||
) # token-index # token-index # feature-index
|
||||
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
|
||||
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
|
||||
|
||||
# Compute the offset where the last block should be written in the conv_states
|
||||
conv_states_output_coord = tl.load(conv_state_indices_ptr +
|
||||
idx_seq * stride_cache_indices +
|
||||
current_last_index).to(tl.int64)
|
||||
conv_states_output_coord = tl.load(
|
||||
conv_state_indices_ptr
|
||||
+ idx_seq * stride_cache_indices
|
||||
+ current_last_index
|
||||
).to(tl.int64)
|
||||
|
||||
conv_states_ptrs_target = (
|
||||
conv_states_ptr + (conv_states_output_coord *
|
||||
stride_conv_state_seq) + # Offset from seq
|
||||
(idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,]
|
||||
idx_tokens_conv * stride_conv_state_tok)[:, None]
|
||||
conv_states_ptr
|
||||
+ (conv_states_output_coord * stride_conv_state_seq) # Offset from seq
|
||||
+ (idx_feats * stride_conv_state_dim)
|
||||
)[None, :] + ( # [BLOCK_N,]
|
||||
idx_tokens_conv * stride_conv_state_tok
|
||||
)[:, None]
|
||||
|
||||
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats
|
||||
< dim)[None, :]
|
||||
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :]
|
||||
tl.debug_barrier() # NOTE: use this due to bug in Triton compiler
|
||||
tl.store(conv_states_ptrs_target, loaded_x, mask)
|
||||
|
||||
@@ -229,39 +241,43 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
|
||||
|
||||
conv_states_ptrs_source = (
|
||||
conv_states_ptr +
|
||||
(conv_states_input_coord * stride_conv_state_seq) +
|
||||
(idx_feats * stride_conv_state_dim)[None, :] +
|
||||
((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:,
|
||||
None]
|
||||
conv_states_ptr
|
||||
+ (conv_states_input_coord * stride_conv_state_seq)
|
||||
+ (idx_feats * stride_conv_state_dim)[None, :]
|
||||
+ ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None]
|
||||
) # [BLOCK_M, BLOCK_N]
|
||||
mask = ((conv_states_input_coord < num_cache_lines)
|
||||
& ((idx_tokens_conv + seqlen) < state_len)[:, None]
|
||||
& (idx_feats < dim)[None, :])
|
||||
mask = (
|
||||
(conv_states_input_coord < num_cache_lines)
|
||||
& ((idx_tokens_conv + seqlen) < state_len)[:, None]
|
||||
& (idx_feats < dim)[None, :]
|
||||
)
|
||||
conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0)
|
||||
|
||||
VAL = state_len - seqlen
|
||||
|
||||
x_ptrs = x_base[None, :] + (
|
||||
(idx_tokens_conv - VAL) *
|
||||
stride_x_token)[:, None] # [BLOCK_M, BLOCK_N]
|
||||
x_ptrs = (
|
||||
x_base[None, :]
|
||||
+ ((idx_tokens_conv - VAL) * stride_x_token)[:, None]
|
||||
) # [BLOCK_M, BLOCK_N]
|
||||
|
||||
mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] &
|
||||
(idx_tokens_conv - VAL < seqlen)[:, None] &
|
||||
(idx_feats < dim)[None, :]
|
||||
) # token-index # token-index # feature-index
|
||||
mask_x = (
|
||||
(idx_tokens_conv - VAL >= 0)[:, None]
|
||||
& (idx_tokens_conv - VAL < seqlen)[:, None]
|
||||
& (idx_feats < dim)[None, :]
|
||||
) # token-index # token-index # feature-index
|
||||
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
|
||||
|
||||
tl.debug_barrier(
|
||||
) # need this due to the bug in tl.where not enforcing this when data is the result of another tl.load
|
||||
tl.debug_barrier() # need this due to the bug in tl.where not enforcing this when data is the result of another tl.load
|
||||
new_conv_state = tl.where(
|
||||
mask, conv_state, loaded_x
|
||||
) # BUG in 'tl.where' which requires a barrier before this
|
||||
conv_states_ptrs_target = conv_states_base + (
|
||||
idx_tokens_conv *
|
||||
stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N]
|
||||
mask = (idx_tokens_conv
|
||||
< state_len)[:, None] & (idx_feats < dim)[None, :]
|
||||
conv_states_ptrs_target = (
|
||||
conv_states_base
|
||||
+ (idx_tokens_conv * stride_conv_state_tok)[:, None]
|
||||
) # [BLOCK_M, BLOCK_N]
|
||||
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[
|
||||
None, :
|
||||
]
|
||||
tl.store(conv_states_ptrs_target, new_conv_state, mask)
|
||||
else: # load_init_state == False
|
||||
# update conv_state by shifting left, BUT
|
||||
@@ -270,21 +286,25 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
|
||||
VAL = state_len - seqlen
|
||||
|
||||
x_ptrs = x_base[None, :] + (
|
||||
(idx_tokens_conv - VAL) *
|
||||
stride_x_token)[:, None] # [BLOCK_M, BLOCK_N]
|
||||
x_ptrs = (
|
||||
x_base[None, :]
|
||||
+ ((idx_tokens_conv - VAL) * stride_x_token)[:, None]
|
||||
) # [BLOCK_M, BLOCK_N]
|
||||
|
||||
mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] &
|
||||
(idx_tokens_conv - VAL < seqlen)[:, None] &
|
||||
(idx_feats < dim)[None, :]
|
||||
) # token-index # token-index # feature-index
|
||||
mask_x = (
|
||||
(idx_tokens_conv - VAL >= 0)[:, None]
|
||||
& (idx_tokens_conv - VAL < seqlen)[:, None]
|
||||
& (idx_feats < dim)[None, :]
|
||||
) # token-index # token-index # feature-index
|
||||
new_conv_state = tl.load(x_ptrs, mask_x, 0.0)
|
||||
|
||||
conv_states_ptrs_target = conv_states_base + (
|
||||
idx_tokens_conv *
|
||||
stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N]
|
||||
mask = (idx_tokens_conv
|
||||
< state_len)[:, None] & (idx_feats < dim)[None, :]
|
||||
conv_states_ptrs_target = (
|
||||
conv_states_base
|
||||
+ (idx_tokens_conv * stride_conv_state_tok)[:, None]
|
||||
) # [BLOCK_M, BLOCK_N]
|
||||
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[
|
||||
None, :
|
||||
]
|
||||
tl.store(conv_states_ptrs_target, new_conv_state, mask)
|
||||
|
||||
else: # chunk_offset > 0
|
||||
@@ -294,29 +314,29 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
mask_w = idx_feats < dim
|
||||
if KERNEL_WIDTH == 2:
|
||||
conv_states_ptrs = prior_tokens # [BLOCK_N]
|
||||
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
|
||||
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
|
||||
if KERNEL_WIDTH == 3:
|
||||
conv_states_ptrs = prior_tokens # [BLOCK_N]
|
||||
col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
|
||||
col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
|
||||
conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N]
|
||||
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
|
||||
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
|
||||
if KERNEL_WIDTH == 4:
|
||||
conv_states_ptrs = prior_tokens # [BLOCK_N]
|
||||
col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
|
||||
col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
|
||||
conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N]
|
||||
col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
|
||||
col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
|
||||
conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N]
|
||||
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
|
||||
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
|
||||
if KERNEL_WIDTH == 5:
|
||||
# ruff: noqa: F841
|
||||
conv_states_ptrs = prior_tokens # [BLOCK_N]
|
||||
col3 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
|
||||
col3 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
|
||||
conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N]
|
||||
col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
|
||||
col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
|
||||
conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N]
|
||||
col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
|
||||
col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
|
||||
conv_states_ptrs = prior_tokens - 3 * stride_x_token # [BLOCK_N]
|
||||
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
|
||||
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
|
||||
|
||||
# Store intermediate states aligned with stride_block_m
|
||||
# The additional states are cached starting from the last stride_block_m.
|
||||
@@ -327,43 +347,51 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
# For example chunk_offset = n_block_to_fill stores the state at last_full_block
|
||||
if (chunk_offset - 1) < n_block_to_fill:
|
||||
# Store the states at the chunk boundaries from the start of the sequence
|
||||
idx_tokens_last = (last_full_block_token_index -
|
||||
(n_block_to_fill - chunk_offset) * B_size -
|
||||
state_len) + tl.arange(
|
||||
0, NP2_STATELEN) # [BLOCK_M]
|
||||
x_ptrs = x_ptr + (idx_tokens_last * stride_x_token)[:, None] + (
|
||||
idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,]
|
||||
idx_tokens_last = (
|
||||
last_full_block_token_index
|
||||
- (n_block_to_fill - chunk_offset) * B_size
|
||||
- state_len
|
||||
) + tl.arange(0, NP2_STATELEN) # [BLOCK_M]
|
||||
x_ptrs = (
|
||||
x_ptr
|
||||
+ (idx_tokens_last * stride_x_token)[:, None]
|
||||
+ (idx_feats * stride_x_dim)[None, :]
|
||||
) # [BLOCK_M,BLOCK_N,]
|
||||
|
||||
mask_x = (
|
||||
(idx_tokens_last >= 0)[:, None] & (idx_feats < dim)[None, :]
|
||||
) # token-index # token-index # feature-index
|
||||
mask_x = (idx_tokens_last >= 0)[:, None] & (idx_feats < dim)[
|
||||
None, :
|
||||
] # token-index # token-index # feature-index
|
||||
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
|
||||
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
|
||||
|
||||
# cache_idx
|
||||
conv_states_output_coord = tl.load(conv_state_indices_ptr +
|
||||
idx_seq * stride_cache_indices +
|
||||
current_first_index +
|
||||
(chunk_offset - 1)).to(tl.int64)
|
||||
conv_states_output_coord = tl.load(
|
||||
conv_state_indices_ptr
|
||||
+ idx_seq * stride_cache_indices
|
||||
+ current_first_index
|
||||
+ (chunk_offset - 1)
|
||||
).to(tl.int64)
|
||||
|
||||
conv_states_ptrs_target = (
|
||||
conv_states_ptr + (conv_states_output_coord *
|
||||
stride_conv_state_seq) + # Offset from seq
|
||||
(idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,]
|
||||
idx_tokens_conv * stride_conv_state_tok)[:, None]
|
||||
conv_states_ptr
|
||||
+ (conv_states_output_coord * stride_conv_state_seq) # Offset from seq
|
||||
+ (idx_feats * stride_conv_state_dim)
|
||||
)[None, :] + ( # [BLOCK_N,]
|
||||
idx_tokens_conv * stride_conv_state_tok
|
||||
)[:, None]
|
||||
|
||||
mask = (idx_tokens_conv < state_len)[:, None] & \
|
||||
(idx_feats < dim)[None, :]
|
||||
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :]
|
||||
tl.debug_barrier() # NOTE: use this due to bug in Triton compiler
|
||||
tl.store(conv_states_ptrs_target, loaded_x, mask)
|
||||
|
||||
if HAS_BIAS:
|
||||
bias = bias_ptr + idx_feats
|
||||
mask_bias = idx_feats < dim
|
||||
acc_preload = tl.load(bias, mask=mask_bias,
|
||||
other=0.0).to(tl.float32) # [BLOCK_N]
|
||||
acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(
|
||||
tl.float32
|
||||
) # [BLOCK_N]
|
||||
else:
|
||||
acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32)
|
||||
acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
||||
|
||||
x_base_1d = x_base + token_offset * stride_x_token # starting of chunk
|
||||
|
||||
@@ -387,7 +415,6 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
matrix_w = w_col0
|
||||
matrix_x = col0
|
||||
for j in tl.static_range(KERNEL_WIDTH):
|
||||
|
||||
if KERNEL_WIDTH == 2:
|
||||
if j == 1: # KERNEL_WIDTH-1:
|
||||
matrix_w = w_col1
|
||||
@@ -428,9 +455,13 @@ def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
if SILU_ACTIVATION:
|
||||
acc = acc / (1 + tl.exp(-acc))
|
||||
mask_1d = (idx_token < segment_len) & (
|
||||
idx_feats < dim) # token-index # feature-index
|
||||
o_ptrs = o_ptr + (sequence_start_index + token_offset + idx_token
|
||||
) * stride_o_token + (idx_feats * stride_o_dim)
|
||||
idx_feats < dim
|
||||
) # token-index # feature-index
|
||||
o_ptrs = (
|
||||
o_ptr
|
||||
+ (sequence_start_index + token_offset + idx_token) * stride_o_token
|
||||
+ (idx_feats * stride_o_dim)
|
||||
)
|
||||
|
||||
tl.store(o_ptrs, acc, mask=mask_1d)
|
||||
|
||||
@@ -518,21 +549,15 @@ def causal_conv1d_fn(
|
||||
batch_ptr = metadata.batch_ptr
|
||||
token_chunk_offset_ptr = metadata.token_chunk_offset_ptr
|
||||
else:
|
||||
seqlens = query_start_loc.diff().to('cpu')
|
||||
seqlens = query_start_loc.diff().to("cpu")
|
||||
args = seqlens
|
||||
MAX_NUM_PROGRAMS = 1024
|
||||
|
||||
batch_ptr = torch.full(
|
||||
(MAX_NUM_PROGRAMS, ),
|
||||
PAD_SLOT_ID,
|
||||
dtype=torch.int32,
|
||||
device=x.device
|
||||
(MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device
|
||||
) # tracking which seq-idx the Triton program is handling
|
||||
token_chunk_offset_ptr = torch.full(
|
||||
(MAX_NUM_PROGRAMS, ),
|
||||
PAD_SLOT_ID,
|
||||
dtype=torch.int32,
|
||||
device=x.device
|
||||
(MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device
|
||||
) # tracking BLOCK_M-based index in the sequence the Triton program is handling
|
||||
|
||||
is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1)
|
||||
@@ -558,9 +583,11 @@ def causal_conv1d_fn(
|
||||
# 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx]
|
||||
# 4. computation can be skipped if cache_indices[idx] == pad_slot_id
|
||||
num_cache_lines = conv_states.size(0)
|
||||
assert (num_cache_lines == conv_states.shape[0]
|
||||
and dim == conv_states.shape[1]
|
||||
and width - 1 <= conv_states.shape[2])
|
||||
assert (
|
||||
num_cache_lines == conv_states.shape[0]
|
||||
and dim == conv_states.shape[1]
|
||||
and width - 1 <= conv_states.shape[2]
|
||||
)
|
||||
stride_istate_seq = conv_states.stride(0)
|
||||
stride_istate_dim = conv_states.stride(1)
|
||||
stride_istate_token = conv_states.stride(2)
|
||||
@@ -571,8 +598,7 @@ def causal_conv1d_fn(
|
||||
else:
|
||||
stride_o_dim = out.stride(1)
|
||||
stride_o_token = out.stride(2)
|
||||
stride_cache_indices = cache_indices.stride(
|
||||
0) if cache_indices is not None else 0
|
||||
stride_cache_indices = cache_indices.stride(0) if cache_indices is not None else 0
|
||||
|
||||
if validate_data:
|
||||
assert x.dim() == 2
|
||||
@@ -586,15 +612,17 @@ def causal_conv1d_fn(
|
||||
assert cache_indices.dim() == 1
|
||||
assert padded_batch == cache_indices.size(0)
|
||||
if has_initial_state is not None:
|
||||
assert has_initial_state.size() == (padded_batch, )
|
||||
assert conv_states is not None, "ERROR: `has_initial_state` is used, which needs also `conv_states`"
|
||||
assert has_initial_state.size() == (padded_batch,)
|
||||
assert conv_states is not None, (
|
||||
"ERROR: `has_initial_state` is used, which needs also `conv_states`"
|
||||
)
|
||||
assert weight.stride(1) == 1
|
||||
assert (dim, width) == weight.shape
|
||||
assert is_channel_last, "Need to run in channel-last layout"
|
||||
if block_size_to_align is not None and block_size_to_align > 0:
|
||||
assert (
|
||||
block_size_to_align % BLOCK_M
|
||||
) == 0, "The mamba block size needs to be divisible by the BLOCK_M"
|
||||
assert (block_size_to_align % BLOCK_M) == 0, (
|
||||
"The mamba block size needs to be divisible by the BLOCK_M"
|
||||
)
|
||||
else:
|
||||
block_size_to_align = BLOCK_M
|
||||
|
||||
@@ -618,44 +646,45 @@ def causal_conv1d_fn(
|
||||
if META["batch_ptr"].nelement() < len(mlist):
|
||||
newlen = len(mlist) + 1
|
||||
META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
|
||||
META["token_chunk_offset_ptr"].resize_(newlen).fill_(
|
||||
PAD_SLOT_ID)
|
||||
META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
|
||||
|
||||
if META["batch_ptr"].nelement() >= len(mlist):
|
||||
META["batch_ptr"][0:len(mlist)].copy_(
|
||||
torch.from_numpy(np.array(mlist)))
|
||||
META["token_chunk_offset_ptr"][0:len(mlist)].copy_(
|
||||
torch.from_numpy(np.array(offsetlist)))
|
||||
META["batch_ptr"][0 : len(mlist)].copy_(
|
||||
torch.from_numpy(np.array(mlist))
|
||||
)
|
||||
META["token_chunk_offset_ptr"][0 : len(mlist)].copy_(
|
||||
torch.from_numpy(np.array(offsetlist))
|
||||
)
|
||||
|
||||
META["batch_ptr"] = META["batch_ptr"].to(META["x_ptr"].device)
|
||||
META["token_chunk_offset_ptr"] = META["token_chunk_offset_ptr"].to(
|
||||
META["x_ptr"].device)
|
||||
META["x_ptr"].device
|
||||
)
|
||||
return tot
|
||||
else:
|
||||
|
||||
def num_program(META, nums_dict):
|
||||
tot = nums_dict[META["BLOCK_M"]]['tot']
|
||||
tot = nums_dict[META["BLOCK_M"]]["tot"]
|
||||
|
||||
mlist = nums_dict[META["BLOCK_M"]]['mlist']
|
||||
mlist_len = nums_dict[META["BLOCK_M"]]['mlist_len']
|
||||
mlist = nums_dict[META["BLOCK_M"]]["mlist"]
|
||||
mlist_len = nums_dict[META["BLOCK_M"]]["mlist_len"]
|
||||
|
||||
offsetlist = nums_dict[META["BLOCK_M"]]['offsetlist']
|
||||
offsetlist = nums_dict[META["BLOCK_M"]]["offsetlist"]
|
||||
|
||||
if nums_dict[META["BLOCK_M"]]["batch_ptr"] is not None:
|
||||
META["batch_ptr"] = nums_dict[META["BLOCK_M"]]["batch_ptr"]
|
||||
META["token_chunk_offset_ptr"] = nums_dict[
|
||||
META["BLOCK_M"]]["token_chunk_offset_ptr"]
|
||||
META["token_chunk_offset_ptr"] = nums_dict[META["BLOCK_M"]][
|
||||
"token_chunk_offset_ptr"
|
||||
]
|
||||
else:
|
||||
if META["batch_ptr"].nelement() < mlist_len:
|
||||
newlen = mlist_len + 1
|
||||
META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
|
||||
META["token_chunk_offset_ptr"].resize_(newlen).fill_(
|
||||
PAD_SLOT_ID)
|
||||
META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
|
||||
|
||||
if META["batch_ptr"].nelement() >= mlist_len:
|
||||
META["batch_ptr"][0:mlist_len].copy_(mlist)
|
||||
META["token_chunk_offset_ptr"][0:mlist_len].copy_(
|
||||
offsetlist)
|
||||
META["token_chunk_offset_ptr"][0:mlist_len].copy_(offsetlist)
|
||||
return tot
|
||||
|
||||
def grid(META):
|
||||
@@ -709,7 +738,7 @@ def causal_conv1d_fn(
|
||||
IS_APC_ENABLED=current_last_idx is not None,
|
||||
USE_PAD_SLOT=pad_slot_id is not None,
|
||||
NP2_STATELEN=np2_statelen,
|
||||
#launch_cooperative_grid=True
|
||||
# launch_cooperative_grid=True
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=256,
|
||||
num_stages=2,
|
||||
@@ -728,7 +757,7 @@ def _causal_conv1d_update_kernel(
|
||||
num_accepted_tokens_ptr,
|
||||
query_start_loc_ptr, # (batch + 1)
|
||||
current_last_idx, # (batch,)
|
||||
initial_state_idx, #(batch,)
|
||||
initial_state_idx, # (batch,)
|
||||
o_ptr, # (batch, dim, seqlen)
|
||||
# Matrix dimensions
|
||||
batch: int,
|
||||
@@ -779,9 +808,9 @@ def _causal_conv1d_update_kernel(
|
||||
current_last_index = 0
|
||||
|
||||
# cache_idx
|
||||
conv_states_input_coord = tl.load(conv_state_indices_ptr +
|
||||
idx_seq * stride_state_indices +
|
||||
conv_state_init).to(tl.int64)
|
||||
conv_states_input_coord = tl.load(
|
||||
conv_state_indices_ptr + idx_seq * stride_state_indices + conv_state_init
|
||||
).to(tl.int64)
|
||||
|
||||
if USE_PAD_SLOT: # noqa
|
||||
if conv_states_input_coord == pad_slot_id:
|
||||
@@ -790,11 +819,9 @@ def _causal_conv1d_update_kernel(
|
||||
|
||||
if IS_VARLEN:
|
||||
query_start_index = tl.load(query_start_loc_ptr + idx_seq).to(tl.int64)
|
||||
query_end_index = tl.load(query_start_loc_ptr + (idx_seq + 1)).to(
|
||||
tl.int64)
|
||||
query_end_index = tl.load(query_start_loc_ptr + (idx_seq + 1)).to(tl.int64)
|
||||
# revise state_len and seqlen
|
||||
state_len = state_len - (seqlen -
|
||||
(query_end_index - query_start_index))
|
||||
state_len = state_len - (seqlen - (query_end_index - query_start_index))
|
||||
seqlen = query_end_index - query_start_index
|
||||
x_offset = query_start_index * stride_x_token
|
||||
o_offset = query_start_index * stride_o_token
|
||||
@@ -822,14 +849,17 @@ def _causal_conv1d_update_kernel(
|
||||
# - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
|
||||
# - and so on.
|
||||
conv_state_token_offset = (
|
||||
tl.load(num_accepted_tokens_ptr + idx_seq).to(tl.int64) - 1)
|
||||
tl.load(num_accepted_tokens_ptr + idx_seq).to(tl.int64) - 1
|
||||
)
|
||||
else:
|
||||
conv_state_token_offset = 0
|
||||
|
||||
# STEP 1: READ init_state data
|
||||
conv_states_base = (conv_state_ptr +
|
||||
(conv_states_input_coord * stride_conv_state_seq) +
|
||||
(idx_feats * stride_conv_state_dim))
|
||||
conv_states_base = (
|
||||
conv_state_ptr
|
||||
+ (conv_states_input_coord * stride_conv_state_seq)
|
||||
+ (idx_feats * stride_conv_state_dim)
|
||||
)
|
||||
mask_w = idx_feats < dim
|
||||
|
||||
prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok
|
||||
@@ -856,25 +886,33 @@ def _causal_conv1d_update_kernel(
|
||||
# window manner, at each forward pass, the tokens are shift by 1, so we
|
||||
# load since idx_tokens + 1.
|
||||
conv_state_ptrs_source = (
|
||||
conv_state_ptr + (conv_states_input_coord * stride_conv_state_seq) +
|
||||
conv_state_token_offset * stride_conv_state_tok +
|
||||
(idx_feats * stride_conv_state_dim)[None, :] +
|
||||
((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) *
|
||||
stride_conv_state_tok)[:, None]) # [BLOCK_M, BLOCK_N]
|
||||
mask = ((conv_states_input_coord < num_cache_lines)
|
||||
& ((idx_tokens + seqlen) < state_len)[:, None]
|
||||
& (idx_feats < dim)[None, :])
|
||||
conv_state_ptr
|
||||
+ (conv_states_input_coord * stride_conv_state_seq)
|
||||
+ conv_state_token_offset * stride_conv_state_tok
|
||||
+ (idx_feats * stride_conv_state_dim)[None, :]
|
||||
+ ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[
|
||||
:, None
|
||||
]
|
||||
) # [BLOCK_M, BLOCK_N]
|
||||
mask = (
|
||||
(conv_states_input_coord < num_cache_lines)
|
||||
& ((idx_tokens + seqlen) < state_len)[:, None]
|
||||
& (idx_feats < dim)[None, :]
|
||||
)
|
||||
conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)
|
||||
|
||||
VAL = state_len - seqlen
|
||||
x_base = x_ptr + x_offset + (idx_feats * stride_x_dim) # [BLOCK_N]
|
||||
|
||||
x_ptrs = x_base[None, :] + (
|
||||
(idx_tokens - VAL) * stride_x_token)[:, None] # [BLOCK_M, BLOCK_N]
|
||||
x_ptrs = (
|
||||
x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None]
|
||||
) # [BLOCK_M, BLOCK_N]
|
||||
|
||||
mask_x = ((idx_tokens - VAL >= 0)[:, None] &
|
||||
(idx_tokens - VAL < seqlen)[:, None] & (idx_feats < dim)[None, :]
|
||||
) # token-index # token-index # feature-index
|
||||
mask_x = (
|
||||
(idx_tokens - VAL >= 0)[:, None]
|
||||
& (idx_tokens - VAL < seqlen)[:, None]
|
||||
& (idx_feats < dim)[None, :]
|
||||
) # token-index # token-index # feature-index
|
||||
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
|
||||
tl.debug_barrier()
|
||||
|
||||
@@ -882,14 +920,16 @@ def _causal_conv1d_update_kernel(
|
||||
|
||||
# Get the state from the initial_state_idx
|
||||
# cache_idx
|
||||
conv_states_offset = tl.load(conv_state_indices_ptr +
|
||||
idx_seq * stride_state_indices +
|
||||
current_last_index).to(tl.int64)
|
||||
conv_states_offset = tl.load(
|
||||
conv_state_indices_ptr + idx_seq * stride_state_indices + current_last_index
|
||||
).to(tl.int64)
|
||||
conv_state_ptrs_target = (
|
||||
conv_state_ptr +
|
||||
(conv_states_offset * stride_conv_state_seq) + # Offset from seq
|
||||
(idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,]
|
||||
idx_tokens * stride_conv_state_tok)[:, None]
|
||||
conv_state_ptr
|
||||
+ (conv_states_offset * stride_conv_state_seq) # Offset from seq
|
||||
+ (idx_feats * stride_conv_state_dim)
|
||||
)[None, :] + ( # [BLOCK_N,]
|
||||
idx_tokens * stride_conv_state_tok
|
||||
)[:, None]
|
||||
mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :]
|
||||
tl.store(conv_state_ptrs_target, new_conv_state, mask)
|
||||
|
||||
@@ -897,10 +937,11 @@ def _causal_conv1d_update_kernel(
|
||||
if HAS_BIAS:
|
||||
bias = bias_ptr + idx_feats
|
||||
mask_bias = idx_feats < dim
|
||||
acc_preload = tl.load(bias, mask=mask_bias,
|
||||
other=0.0).to(tl.float32) # [BLOCK_N]
|
||||
acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(
|
||||
tl.float32
|
||||
) # [BLOCK_N]
|
||||
else:
|
||||
acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32)
|
||||
acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32)
|
||||
|
||||
# STEP 4:
|
||||
# PRE-LOAD WEIGHTS
|
||||
@@ -1016,10 +1057,12 @@ def _causal_conv1d_update_kernel(
|
||||
|
||||
if SILU_ACTIVATION:
|
||||
acc = acc / (1 + tl.exp(-acc))
|
||||
mask_1d = (idx_token < seqlen) & (idx_feats < dim
|
||||
) # token-index # feature-index
|
||||
o_ptrs = o_ptr + o_offset + idx_token * stride_o_token + (idx_feats *
|
||||
stride_o_dim)
|
||||
mask_1d = (idx_token < seqlen) & (
|
||||
idx_feats < dim
|
||||
) # token-index # feature-index
|
||||
o_ptrs = (
|
||||
o_ptr + o_offset + idx_token * stride_o_token + (idx_feats * stride_o_dim)
|
||||
)
|
||||
|
||||
tl.store(o_ptrs, acc, mask=mask_1d)
|
||||
|
||||
@@ -1104,16 +1147,16 @@ def causal_conv1d_update(
|
||||
|
||||
if validate_data:
|
||||
assert dim == weight.size(0)
|
||||
assert conv_state.stride(
|
||||
-2
|
||||
) == 1, f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})"
|
||||
assert conv_state.stride(-2) == 1, (
|
||||
f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})"
|
||||
)
|
||||
assert state_len >= width - 1
|
||||
# when above happens, we don't shift-left to keep any records in conv_state
|
||||
assert dim == conv_state.size(1)
|
||||
if conv_state_indices is None:
|
||||
assert conv_state.size(0) >= batch
|
||||
else:
|
||||
assert (batch, ) == conv_state_indices.shape
|
||||
assert (batch,) == conv_state_indices.shape
|
||||
|
||||
assert num_cache_lines >= batch
|
||||
assert weight.stride(1) == 1 # Need this
|
||||
@@ -1133,10 +1176,10 @@ def causal_conv1d_update(
|
||||
stride_o_token, stride_o_dim = out.stride()
|
||||
stride_o_seq = 0
|
||||
|
||||
stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride(
|
||||
stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride()
|
||||
stride_state_indices = (
|
||||
conv_state_indices.stride(0) if conv_state_indices is not None else 0
|
||||
)
|
||||
stride_state_indices = conv_state_indices.stride(
|
||||
0) if conv_state_indices is not None else 0
|
||||
if num_accepted_tokens is not None:
|
||||
state_len = width - 1 + (seqlen - 1) # effective state_len needed
|
||||
else:
|
||||
|
||||
@@ -46,17 +46,17 @@ def _layer_norm_fwd_1pass_kernel(
|
||||
B += group * N
|
||||
# Compute mean and variance
|
||||
cols = tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
|
||||
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
|
||||
if HAS_Z and not NORM_BEFORE_GATE:
|
||||
z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
|
||||
x *= z * tl.sigmoid(z)
|
||||
if not IS_RMS_NORM:
|
||||
mean = tl.sum(x, axis=0) / N
|
||||
tl.store(Mean + row, mean)
|
||||
xbar = tl.where(cols < N, x - mean, 0.)
|
||||
xbar = tl.where(cols < N, x - mean, 0.0)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
else:
|
||||
xbar = tl.where(cols < N, x, 0.)
|
||||
xbar = tl.where(cols < N, x, 0.0)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
rstd = 1 / tl.sqrt(var + eps)
|
||||
tl.store(Rstd + row, rstd)
|
||||
@@ -74,15 +74,17 @@ def _layer_norm_fwd_1pass_kernel(
|
||||
tl.store(Y + cols, y, mask=mask)
|
||||
|
||||
|
||||
def _layer_norm_fwd(x,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
z=None,
|
||||
out=None,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False):
|
||||
def _layer_norm_fwd(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
z=None,
|
||||
out=None,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
M, N = x.shape
|
||||
if group_size is None:
|
||||
group_size = N
|
||||
@@ -92,57 +94,57 @@ def _layer_norm_fwd(x,
|
||||
if z is not None:
|
||||
assert z.stride(-1) == 1
|
||||
assert z.shape == (M, N)
|
||||
assert weight.shape == (N, )
|
||||
assert weight.shape == (N,)
|
||||
assert weight.stride(-1) == 1
|
||||
if bias is not None:
|
||||
assert bias.stride(-1) == 1
|
||||
assert bias.shape == (N, )
|
||||
assert bias.shape == (N,)
|
||||
# allocate output
|
||||
if out is not None:
|
||||
assert out.shape == x.shape
|
||||
else:
|
||||
out = torch.empty_like(x)
|
||||
assert out.stride(-1) == 1
|
||||
mean = torch.empty((ngroups * M, ), dtype=torch.float32,
|
||||
device=x.device) if not is_rms_norm else None
|
||||
rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
|
||||
mean = (
|
||||
torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
|
||||
if not is_rms_norm
|
||||
else None
|
||||
)
|
||||
rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
|
||||
if group_size > BLOCK_N:
|
||||
raise RuntimeError(
|
||||
"This layer norm doesn't support feature dim >= 64KB.")
|
||||
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK_N // 256, 1), 8)
|
||||
grid = (M, ngroups)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_layer_norm_fwd_1pass_kernel[grid](x,
|
||||
out,
|
||||
weight,
|
||||
bias,
|
||||
z,
|
||||
mean,
|
||||
rstd,
|
||||
x.stride(0),
|
||||
out.stride(0),
|
||||
z.stride(0) if z is not None else 0,
|
||||
M,
|
||||
group_size,
|
||||
eps,
|
||||
BLOCK_N=BLOCK_N,
|
||||
NORM_BEFORE_GATE=norm_before_gate,
|
||||
IS_RMS_NORM=is_rms_norm,
|
||||
num_warps=num_warps)
|
||||
_layer_norm_fwd_1pass_kernel[grid](
|
||||
x,
|
||||
out,
|
||||
weight,
|
||||
bias,
|
||||
z,
|
||||
mean,
|
||||
rstd,
|
||||
x.stride(0),
|
||||
out.stride(0),
|
||||
z.stride(0) if z is not None else 0,
|
||||
M,
|
||||
group_size,
|
||||
eps,
|
||||
BLOCK_N=BLOCK_N,
|
||||
NORM_BEFORE_GATE=norm_before_gate,
|
||||
IS_RMS_NORM=is_rms_norm,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
return out, mean, rstd


def rms_norm_gated(x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True):
def rms_norm_gated(
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
):
x_shape_og = x.shape
# reshape input data into 2D tensor
x = x.reshape(-1, x.shape[-1])
@@ -156,13 +158,15 @@ def rms_norm_gated(x,
weight = weight.contiguous()
if bias is not None:
bias = bias.contiguous()
y, _, _ = _layer_norm_fwd(x,
weight,
bias,
eps,
z=z,
group_size=group_size,
norm_before_gate=norm_before_gate,
is_rms_norm=True)
y, _, _ = _layer_norm_fwd(
x,
weight,
bias,
eps,
z=z,
group_size=group_size,
norm_before_gate=norm_before_gate,
is_rms_norm=True,
)

return y.reshape(x_shape_og)

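As a quick orientation for the gated RMSNorm path, the following editor's sketch (not part of this commit) shows one plausible call to rms_norm_gated based only on the signature above; the shapes, dtype, and CUDA device are illustrative assumptions, and the function is assumed to be in scope from the module being reformatted here.

# Editor's sketch: hypothetical usage of rms_norm_gated as defined above.
# A CUDA device is assumed because _layer_norm_fwd launches a Triton kernel.
import torch

hidden = 1024
x = torch.randn(8, hidden, device="cuda", dtype=torch.float16)
z = torch.randn(8, hidden, device="cuda", dtype=torch.float16)  # optional gate input
weight = torch.ones(hidden, device="cuda", dtype=torch.float16)

y = rms_norm_gated(
    x, weight, bias=None, z=z, eps=1e-6, group_size=None, norm_before_gate=True
)
assert y.shape == x.shape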
@@ -11,8 +11,7 @@ from vllm import _custom_ops as ops
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.triton_utils import HAS_TRITON, tl, triton

TRITON3 = HAS_TRITON and (version.parse(triton.__version__)
>= version.parse("3.0.0"))
TRITON3 = HAS_TRITON and (version.parse(triton.__version__) >= version.parse("3.0.0"))

if TRITON3:

@@ -28,16 +27,18 @@ else:
|
||||
return dt
|
||||
|
||||
|
||||
@triton.heuristics(
|
||||
{"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
|
||||
@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
|
||||
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
|
||||
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
|
||||
@triton.heuristics({
|
||||
"HAS_STATE_BATCH_INDICES":
|
||||
lambda args: args["state_batch_indices_ptr"] is not None
|
||||
})
|
||||
@triton.heuristics(
|
||||
{"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
|
||||
{
|
||||
"HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"]
|
||||
is not None
|
||||
}
|
||||
)
|
||||
@triton.heuristics(
|
||||
{"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}
|
||||
)
|
||||
@triton.jit
|
||||
def _selective_scan_update_kernel(
|
||||
# Pointers to matrices
|
||||
@@ -110,15 +111,16 @@ def _selective_scan_update_kernel(
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
dst_state_batch_indices_ptr += pid_b
|
||||
dst_state_batch_idx = tl.load(dst_state_batch_indices_ptr).to(tl.int64)
|
||||
dst_state_ptr = state_ptr + (dst_state_batch_idx * stride_state_batch +
|
||||
pid_h * stride_state_head)
|
||||
dst_state_ptr = state_ptr + (
|
||||
dst_state_batch_idx * stride_state_batch + pid_h * stride_state_head
|
||||
)
|
||||
state_batch_indices_ptr += pid_b
|
||||
state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64)
|
||||
state_ptr += (state_batch_idx * stride_state_batch +
|
||||
pid_h * stride_state_head)
|
||||
state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
|
||||
else:
|
||||
dst_state_ptr = state_ptr + pid_b * stride_state_batch + \
|
||||
pid_h * stride_state_head
|
||||
dst_state_ptr = (
|
||||
state_ptr + pid_b * stride_state_batch + pid_h * stride_state_head
|
||||
)
|
||||
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
|
||||
|
||||
x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
|
||||
@@ -126,28 +128,29 @@ def _selective_scan_update_kernel(
|
||||
if HAS_DT_BIAS:
|
||||
dt_bias_ptr += pid_h * stride_dt_bias_head
|
||||
A_ptr += pid_h * stride_A_head
|
||||
B_ptr += pid_b * stride_B_batch + (pid_h //
|
||||
nheads_ngroups_ratio) * stride_B_group
|
||||
C_ptr += pid_b * stride_C_batch + (pid_h //
|
||||
nheads_ngroups_ratio) * stride_C_group
|
||||
B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
|
||||
C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
|
||||
if HAS_Z:
|
||||
z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
|
||||
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
|
||||
state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim +
|
||||
offs_n[None, :] * stride_state_dstate)
|
||||
dst_state_ptrs = dst_state_ptr + (offs_m[:, None] * stride_state_dim +
|
||||
offs_n[None, :] * stride_state_dstate)
|
||||
state_ptrs = state_ptr + (
|
||||
offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
|
||||
)
|
||||
dst_state_ptrs = dst_state_ptr + (
|
||||
offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
|
||||
)
|
||||
x_ptrs = x_ptr + offs_m * stride_x_dim
|
||||
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
|
||||
if HAS_DT_BIAS:
|
||||
dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
|
||||
if HAS_D:
|
||||
D_ptr += pid_h * stride_D_head
|
||||
A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim +
|
||||
offs_n[None, :] * stride_A_dstate)
|
||||
A_ptrs = A_ptr + (
|
||||
offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
|
||||
)
|
||||
B_ptrs = B_ptr + offs_n * stride_B_dstate
|
||||
C_ptrs = C_ptr + offs_n * stride_C_dstate
|
||||
if HAS_D:
|
||||
@@ -157,20 +160,19 @@ def _selective_scan_update_kernel(
|
||||
out_ptrs = out_ptr + offs_m * stride_out_dim
|
||||
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
mask &= (state_batch_idx != pad_slot_id)
|
||||
mask &= state_batch_idx != pad_slot_id
|
||||
state = tl.load(state_ptrs, mask=mask, other=0.0)
|
||||
|
||||
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if not TIE_HDIM:
|
||||
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if HAS_DT_BIAS:
|
||||
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim,
|
||||
other=0.0).to(tl.float32)
|
||||
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if DT_SOFTPLUS:
|
||||
dt = softplus(dt)
|
||||
A = tl.load(A_ptrs,
|
||||
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
|
||||
other=0.0).to(tl.float32)
|
||||
A = tl.load(
|
||||
A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
|
||||
).to(tl.float32)
|
||||
dA = tl.exp(A * dt[:, None])
|
||||
else:
|
||||
dt = tl.load(dt_ptr).to(tl.float32)
|
||||
@@ -193,7 +195,7 @@ def _selective_scan_update_kernel(
|
||||
|
||||
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
mask &= (state_batch_idx != pad_slot_id)
|
||||
mask &= state_batch_idx != pad_slot_id
|
||||
tl.store(dst_state_ptrs, state, mask=mask)
|
||||
out = tl.sum(state * C[None, :], axis=1)
|
||||
if HAS_D:
|
||||
@@ -203,20 +205,22 @@ def _selective_scan_update_kernel(
|
||||
tl.store(out_ptrs, out, mask=offs_m < dim)
|
||||
|
||||
|
||||
def selective_state_update(state,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=None,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
dt_softplus=False,
|
||||
state_batch_indices=None,
|
||||
dst_state_batch_indices=None,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
out=None):
|
||||
def selective_state_update(
|
||||
state,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=None,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
dt_softplus=False,
|
||||
state_batch_indices=None,
|
||||
dst_state_batch_indices=None,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
out=None,
|
||||
):
|
||||
"""
Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
@@ -229,12 +233,12 @@ def selective_state_update(state,
z: (batch, dim) or (batch, nheads, dim)
dt_bias: (dim,) or (nheads, dim)
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: Preallocated ssm output tensor. Assume same shape as x.
out: Preallocated ssm output tensor. Assume same shape as x.
In-place updated.
"""
if state.dim() == 3:
|
||||
@@ -275,25 +279,33 @@ def selective_state_update(state,
|
||||
if dt_bias is not None:
|
||||
assert dt_bias.shape == (nheads, dim)
|
||||
if state_batch_indices is not None:
|
||||
assert state_batch_indices.shape == (batch, )
|
||||
assert state_batch_indices.shape == (batch,)
|
||||
if dst_state_batch_indices is not None:
|
||||
assert dst_state_batch_indices.shape == (batch, )
|
||||
assert dst_state_batch_indices.shape == (batch,)
|
||||
else:
|
||||
# revert to the default behavior of in-place state updates
|
||||
dst_state_batch_indices = state_batch_indices
|
||||
assert out.shape == x.shape
|
||||
|
||||
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
|
||||
z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
|
||||
(0, 0, 0))
|
||||
grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
|
||||
z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
|
||||
# We don't want autotune since it will overwrite the state
|
||||
# We instead tune by hand.
|
||||
BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 else
|
||||
((16, 4) if dstate <= 32 else
|
||||
((8, 4) if dstate <= 64 else
|
||||
((4, 4) if dstate <= 128 else ((4, 8))))))
|
||||
tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(
|
||||
-1) == 0 and dt_bias.stride(-1) == 0
|
||||
BLOCK_SIZE_M, num_warps = (
|
||||
(32, 4)
|
||||
if dstate <= 16
|
||||
else (
|
||||
(16, 4)
|
||||
if dstate <= 32
|
||||
else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
|
||||
)
|
||||
)
|
||||
tie_hdim = (
|
||||
A.stride(-1) == 0
|
||||
and A.stride(-2) == 0
|
||||
and dt.stride(-1) == 0
|
||||
and dt_bias.stride(-1) == 0
|
||||
)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_selective_scan_update_kernel[grid](
|
||||
state,
|
||||
@@ -324,8 +336,7 @@ def selective_state_update(state,
|
||||
dt.stride(0),
|
||||
dt.stride(1),
|
||||
dt.stride(2),
|
||||
*(dt_bias.stride(0),
|
||||
dt_bias.stride(1)) if dt_bias is not None else 0,
|
||||
*(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
|
||||
A.stride(0),
|
||||
A.stride(1),
|
||||
A.stride(2),
|
||||
@@ -349,54 +360,56 @@ def selective_state_update(state,
|
||||
)
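Between the two scan entry points, a minimal illustrative call to selective_state_update may help readers; the state, z, dt_bias and out shapes follow the docstring above, while the x, dt, A, B and C shapes, the sizes, and the CUDA device are the editor's assumptions, not part of the commit.

# Editor's sketch: hypothetical call using the non-headed (3-D state) layout.
import torch

batch, dim, dstate = 4, 128, 16
state = torch.zeros(batch, dim, dstate, device="cuda")  # updated in place
x = torch.randn(batch, dim, device="cuda")
dt = torch.rand(batch, dim, device="cuda")
dt_bias = torch.rand(dim, device="cuda")
A = -torch.rand(dim, dstate, device="cuda")
B = torch.randn(batch, dstate, device="cuda")
C = torch.randn(batch, dstate, device="cuda")
out = torch.empty_like(x)  # preallocated output, same shape as x

selective_state_update(state, x, dt, A, B, C, dt_bias=dt_bias, out=out)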
def selective_scan_fn(u,
|
||||
ssm_states,
|
||||
delta,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=None,
|
||||
z=None,
|
||||
delta_bias=None,
|
||||
delta_softplus=False,
|
||||
query_start_loc=None,
|
||||
cache_indices=None,
|
||||
has_initial_state=None,
|
||||
pad_slot_id=PAD_SLOT_ID) -> torch.Tensor:
|
||||
def selective_scan_fn(
|
||||
u,
|
||||
ssm_states,
|
||||
delta,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=None,
|
||||
z=None,
|
||||
delta_bias=None,
|
||||
delta_softplus=False,
|
||||
query_start_loc=None,
|
||||
cache_indices=None,
|
||||
has_initial_state=None,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
u: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
u: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
applies changes in place.
|
||||
ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
||||
applies changes in place.
|
||||
delta: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
A: (dim, dstate)
|
||||
B: (ngroups, dstate, total_length) for varlen or
|
||||
A: (dim, dstate)
|
||||
B: (ngroups, dstate, total_length) for varlen or
|
||||
(batch,ngroups,dstate,seqlen)
|
||||
C: (ngroups, dstate, total_length) for varlen or
|
||||
C: (ngroups, dstate, total_length) for varlen or
|
||||
(batch,ngroups,dstate,seqlen)
|
||||
D: (dim,)
|
||||
z: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
D: (dim,)
|
||||
z: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
dt_bias: (dim,) or (dim)
|
||||
query_start_loc: (batch + 1) int32
|
||||
The cumulative sequence lengths of the sequences in
|
||||
the batch, used to index into sequence. prepended with 0.
|
||||
for example: query_start_loc = torch.Tensor([0,10,16,17]),
|
||||
for example: query_start_loc = torch.Tensor([0,10,16,17]),
|
||||
x.shape=(dim,17)
|
||||
cache_indices: (batch) int32
|
||||
A tensor with each cell is a correspondent
|
||||
A tensor with each cell is a correspondent
|
||||
input and output ssm_state index
|
||||
has_initial_state: (batch) bool
|
||||
A tensor populated with ones and zeros,
|
||||
indicate if the ssm_state at the corresponding index should be
|
||||
used as initial state. Not providing argument assumes
|
||||
A tensor populated with ones and zeros,
|
||||
indicate if the ssm_state at the corresponding index should be
|
||||
used as initial state. Not providing argument assumes
|
||||
there's no initial state
|
||||
pad_slot_id: int
|
||||
if cache_indices is passed, lets the kernel identify padding entries
|
||||
that will not be processed,
|
||||
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
|
||||
if cache_indices is passed, lets the kernel identify padding entries
|
||||
that will not be processed,
|
||||
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
|
||||
in this case, the kernel will not process entries at indices 0 and 3
|
||||
returns
|
||||
output: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
output: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
supports inplace replacement
|
||||
"""
|
||||
if u.stride(-1) != 1:
|
||||
@@ -420,9 +433,22 @@ def selective_scan_fn(u,
|
||||
if C.dim() == 2 and query_start_loc is not None:
|
||||
C = C.unsqueeze(0)
|
||||
|
||||
ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus,
|
||||
query_start_loc, cache_indices, has_initial_state,
|
||||
ssm_states, pad_slot_id)
|
||||
ops.selective_scan_fwd(
|
||||
u,
|
||||
delta,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D,
|
||||
z,
|
||||
delta_bias,
|
||||
delta_softplus,
|
||||
query_start_loc,
|
||||
cache_indices,
|
||||
has_initial_state,
|
||||
ssm_states,
|
||||
pad_slot_id,
|
||||
)

if z is None:
return delta  # output written inplace to delta

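For completeness, here is an editor's sketch (not part of this commit) of a varlen call to selective_scan_fn, assembled from the docstring above, including its query_start_loc example; all sizes are illustrative and a CUDA device is assumed because the call dispatches to the fused selective_scan_fwd op.

# Editor's sketch: hypothetical varlen invocation of selective_scan_fn above.
import torch

dim, dstate, ngroups = 64, 16, 1
total_length = 17  # three sequences of 10, 6 and 1 tokens
query_start_loc = torch.tensor([0, 10, 16, 17], dtype=torch.int32, device="cuda")
batch = query_start_loc.numel() - 1

u = torch.randn(dim, total_length, device="cuda")
delta = torch.rand(dim, total_length, device="cuda")
A = -torch.rand(dim, dstate, device="cuda")
B = torch.randn(ngroups, dstate, total_length, device="cuda")
C = torch.randn(ngroups, dstate, total_length, device="cuda")
ssm_states = torch.zeros(batch, dim, dstate, device="cuda")
cache_indices = torch.arange(batch, dtype=torch.int32, device="cuda")
has_initial_state = torch.zeros(batch, dtype=torch.bool, device="cuda")

output = selective_scan_fn(
    u, ssm_states, delta, A, B, C,
    query_start_loc=query_start_loc,
    cache_indices=cache_indices,
    has_initial_state=has_initial_state,
)
# with z=None the result is written in place to `delta` and returned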
@@ -14,79 +14,52 @@ from vllm.triton_utils import tl, triton
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 256,
|
||||
'BLOCK_SIZE_K': 64
|
||||
},
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
num_stages=3,
|
||||
num_warps=8),
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 256,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
num_stages=5,
|
||||
num_warps=2),
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 32,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=5,
|
||||
num_warps=2),
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=2),
|
||||
num_warps=2,
|
||||
),
|
||||
],
|
||||
key=['chunk_size', 'K', 'IS_CAUSAL'],
|
||||
key=["chunk_size", "K", "IS_CAUSAL"],
|
||||
)
|
||||
@triton.jit
|
||||
def _bmm_chunk_fwd_kernel(
|
||||
@@ -136,24 +109,26 @@ def _bmm_chunk_fwd_kernel(
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen +
|
||||
offs_k[None, :] * stride_ak)
|
||||
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk +
|
||||
offs_n[None, :] * stride_b_seqlen)
|
||||
a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)
|
||||
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)
|
||||
chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
# compute a * b.T
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
a = tl.load(a_ptrs,
|
||||
mask=(offs_m[:, None] < chunk_size_limit) &
|
||||
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
||||
other=0.0).to(dot_dtype)
|
||||
b = tl.load(b_ptrs,
|
||||
mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) &
|
||||
(offs_n[None, :] < chunk_size_limit),
|
||||
other=0.0).to(dot_dtype)
|
||||
a = tl.load(
|
||||
a_ptrs,
|
||||
mask=(offs_m[:, None] < chunk_size_limit)
|
||||
& (offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
||||
other=0.0,
|
||||
).to(dot_dtype)
|
||||
b = tl.load(
|
||||
b_ptrs,
|
||||
mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K)
|
||||
& (offs_n[None, :] < chunk_size_limit),
|
||||
other=0.0,
|
||||
).to(dot_dtype)
|
||||
acc += tl.dot(a, b)
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
@@ -163,20 +138,15 @@ def _bmm_chunk_fwd_kernel(
|
||||
|
||||
out = acc.to(out_ptr.dtype.element_ty)
|
||||
out_ptr += pid_c * stride_out_chunk + pid_h * stride_out_head
|
||||
out_ptrs = out_ptr + (stride_outm * offs_m[:, None] +
|
||||
offs_n[None, :] * stride_outn)
|
||||
tl.store(out_ptrs,
|
||||
out,
|
||||
mask=(offs_m[:, None] < chunk_size) &
|
||||
(offs_n[None, :] < chunk_size))
|
||||
out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)
|
||||
tl.store(
|
||||
out_ptrs,
|
||||
out,
|
||||
mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size),
|
||||
)
|
||||
|
||||
|
||||
def _bmm_chunk_fwd(a,
|
||||
b,
|
||||
chunk_size,
|
||||
cu_chunk_seqlens,
|
||||
causal=False,
|
||||
output_dtype=None):
|
||||
def _bmm_chunk_fwd(a, b, chunk_size, cu_chunk_seqlens, causal=False, output_dtype=None):
|
||||
"""
|
||||
Argument:
|
||||
a: (seqlen, ngroups, k)
|
||||
@@ -198,16 +168,23 @@ def _bmm_chunk_fwd(a,
|
||||
nchunks = len(cu_chunk_seqlens) - 1
|
||||
# Allocates output.
|
||||
out_dtype = a.dtype if output_dtype is None else output_dtype
|
||||
out = torch.empty((nchunks, ngroups, chunk_size, chunk_size),
|
||||
device=a.device,
|
||||
dtype=out_dtype)
|
||||
dot_dtype = (tl.bfloat16
|
||||
if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
|
||||
(tl.float16 if a.dtype == torch.float16
|
||||
or b.dtype == torch.float16 else tl.float32))
|
||||
grid = lambda META: (triton.cdiv(
|
||||
chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(
|
||||
chunk_size, META['BLOCK_SIZE_N']), nchunks * ngroups)
|
||||
out = torch.empty(
|
||||
(nchunks, ngroups, chunk_size, chunk_size), device=a.device, dtype=out_dtype
|
||||
)
|
||||
dot_dtype = (
|
||||
tl.bfloat16
|
||||
if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16
|
||||
else (
|
||||
tl.float16
|
||||
if a.dtype == torch.float16 or b.dtype == torch.float16
|
||||
else tl.float32
|
||||
)
|
||||
)
|
||||
grid = lambda META: (
|
||||
triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(chunk_size, META["BLOCK_SIZE_N"]),
|
||||
nchunks * ngroups,
|
||||
)
|
||||
with torch.cuda.device(a.device.index):
|
||||
_bmm_chunk_fwd_kernel[grid](
|
||||
a_ptr=a,
|
||||
|
||||
@@ -10,101 +10,68 @@ from packaging import version
|
||||
|
||||
from vllm.triton_utils import tl, triton

TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")


@triton.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 256,
|
||||
'BLOCK_SIZE_K': 64
|
||||
},
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
num_stages=3,
|
||||
num_warps=8),
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 256,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 64
|
||||
},
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 64
|
||||
},
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
num_stages=5,
|
||||
num_warps=2),
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 32,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=5,
|
||||
num_warps=2),
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=2),
|
||||
num_warps=2,
|
||||
),
|
||||
],
|
||||
key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'],
|
||||
key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"],
|
||||
)
|
||||
@triton.jit
|
||||
def _chunk_scan_fwd_kernel(
|
||||
@@ -177,15 +144,16 @@ def _chunk_scan_fwd_kernel(
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n
cb_ptr += pid_c * stride_cb_chunk + (pid_h //
nheads_ngroups_ratio) * stride_cb_head
cb_ptr += pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
C_ptr += chunk_seqlen_start * stride_C_seqlen + (
pid_h // nheads_ngroups_ratio) * stride_C_head
C_ptr += (
chunk_seqlen_start * stride_C_seqlen
+ (pid_h // nheads_ngroups_ratio) * stride_C_head
)

# M-block offsets and prev states
# - logic in next block may override these if there is an active offset
@@ -193,26 +161,31 @@ def _chunk_scan_fwd_kernel(

seq_idx_ptr += pid_c * stride_seq_idx_chunk
seq_idx = tl.load(seq_idx_ptr)
seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_chunk,
mask=pid_c >= 1,
other=-1)
seq_idx_prev = tl.load(
seq_idx_ptr - stride_seq_idx_chunk, mask=pid_c >= 1, other=-1
)

if HAS_INITSTATES and (seq_idx != seq_idx_prev):
prev_states_ptr = initstates_ptr + seq_idx * stride_init_states_batch + pid_h * stride_init_states_head
prev_states_ptr = (
initstates_ptr
+ seq_idx * stride_init_states_batch
+ pid_h * stride_init_states_head
)
prev_states_hdim = stride_init_states_hdim
prev_states_dstate = stride_init_states_dstate
else:
prev_states_ptr = states_ptr + (
pid_c - 1) * stride_states_chunk + pid_h * stride_states_head
prev_states_ptr = (
states_ptr + (pid_c - 1) * stride_states_chunk + pid_h * stride_states_head
)
prev_states_hdim = stride_states_hdim
prev_states_dstate = stride_states_dstate

chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start

offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize,
mask=offs_m < chunk_size,
other=0.0).to(tl.float32)
dA_cs_m = tl.load(
dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0
).to(tl.float32)

acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

@@ -221,52 +194,66 @@ def _chunk_scan_fwd_kernel(

# Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
offs_k_dstate = tl.arange(
0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen +
offs_k_dstate[None, :] * stride_C_dstate)
0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
)
C_ptrs = C_ptr + (
offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate
)

scale_m = tl.exp(dA_cs_m)
if BLOCK_SIZE_DSTATE <= 128:
C = tl.load(C_ptrs,
mask=(offs_m[:, None] < chunk_size_limit) &
(offs_k_dstate[None, :] < dstate),
other=0.0)
C = tl.load(
C_ptrs,
mask=(offs_m[:, None] < chunk_size_limit)
& (offs_k_dstate[None, :] < dstate),
other=0.0,
)

if not HAS_INITSTATES and (seq_idx != seq_idx_prev):
# if no init states AND starting a new sequence, we need zeros
prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N),
dtype=C_ptr.dtype.element_ty)
prev_states = tl.zeros(
(BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty
)
else:
# otherwise read the previous state
prev_states_ptrs = prev_states_ptr \
+ offs_n[None, :] * prev_states_hdim \
+ offs_k_dstate[:, None] * prev_states_dstate
prev_states = tl.load(prev_states_ptrs,
mask=(offs_k_dstate[:, None] < dstate) &
(offs_n[None, :] < hdim),
other=0.0)
prev_states_ptrs = (
prev_states_ptr
+ offs_n[None, :] * prev_states_hdim
+ offs_k_dstate[:, None] * prev_states_dstate
)
prev_states = tl.load(
prev_states_ptrs,
mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim),
other=0.0,
)
prev_states = prev_states.to(C_ptr.dtype.element_ty)

acc = tl.dot(C, prev_states) * scale_m[:, None]

else:
prev_states_ptrs = prev_states_ptr \
+ offs_n[None, :] * prev_states_hdim \
+ offs_k_dstate[:, None] * prev_states_dstate
prev_states_ptrs = (
prev_states_ptr
+ offs_n[None, :] * prev_states_hdim
+ offs_k_dstate[:, None] * prev_states_dstate
)
for k in range(0, dstate, BLOCK_SIZE_K):
C = tl.load(C_ptrs,
mask=(offs_m[:, None] < chunk_size_limit) &
(offs_k_dstate[None, :] < dstate - k),
other=0.0)
C = tl.load(
C_ptrs,
mask=(offs_m[:, None] < chunk_size_limit)
& (offs_k_dstate[None, :] < dstate - k),
other=0.0,
)
if not HAS_INITSTATES and (seq_idx != seq_idx_prev):
prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K),
dtype=C_ptr.dtype.element_ty)
prev_states = tl.zeros(
(BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), dtype=C_ptr.dtype.element_ty
)
else:
prev_states = tl.load(
prev_states_ptrs,
mask=(offs_k_dstate[:, None] < dstate - k) &
(offs_n[None, :] < hdim),
other=0.0)
mask=(offs_k_dstate[:, None] < dstate - k)
& (offs_n[None, :] < hdim),
other=0.0,
)
prev_states = prev_states.to(C_ptr.dtype.element_ty)
acc += tl.dot(C, prev_states)
C_ptrs += BLOCK_SIZE_K
@@ -274,36 +261,42 @@ def _chunk_scan_fwd_kernel(
acc *= scale_m[:, None]

offs_k = tl.arange(0, BLOCK_SIZE_K)
cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m +
offs_k[None, :] * stride_cb_csize_k)
x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen +
offs_n[None, :] * stride_x_hdim)
cb_ptrs = cb_ptr + (
offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k
)
x_ptrs = x_ptr + (
offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
)
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
K_MAX = chunk_size_limit if not IS_CAUSAL else min(
(pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
K_MAX = (
chunk_size_limit
if not IS_CAUSAL
else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
)
for k in range(0, K_MAX, BLOCK_SIZE_K):
cb = tl.load(cb_ptrs,
mask=(offs_m[:, None] < chunk_size) &
(offs_k[None, :] < chunk_size - k),
other=0.0).to(tl.float32)
dA_cs_k = tl.load(dA_cumsum_ptrs,
mask=offs_k < chunk_size - k,
other=0.0).to(tl.float32)
cb = tl.load(
cb_ptrs,
mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k),
other=0.0,
).to(tl.float32)
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(
tl.float32
)
# If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
# So we don't need masking wrt seq_idx here.
cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :])
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k,
other=0.0).to(tl.float32)
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
cb *= dt_k
if IS_CAUSAL:
mask = offs_m[:, None] >= k + offs_k[None, :]
cb = tl.where(mask, cb, 0.0)
cb = cb.to(x_ptr.dtype.element_ty)
x = tl.load(x_ptrs,
mask=(offs_k[:, None] < chunk_size_limit - k) &
(offs_n[None, :] < hdim),
other=0.0)
x = tl.load(
x_ptrs,
mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim),
other=0.0,
)
acc += tl.dot(cb, x)
cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
@@ -315,35 +308,41 @@ def _chunk_scan_fwd_kernel(

if HAS_D:
if D_HAS_HDIM:
D = tl.load(D_ptr + pid_h * stride_D_head + offs_n,
mask=offs_n < hdim,
other=0.0).to(tl.float32)
D = tl.load(
D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0
).to(tl.float32)
else:
D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen +
offs_n[None, :] * stride_x_hdim),
mask=(offs_m[:, None] < chunk_size_limit) &
(offs_n[None, :] < hdim),
other=0.0).to(tl.float32)
x_residual = tl.load(
x_ptr
+ (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim),
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
other=0.0,
).to(tl.float32)
acc += x_residual * D

if HAS_Z:
z_ptr += chunk_seqlen_start * stride_z_seqlen + pid_h * stride_z_head
z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] +
stride_z_hdim * offs_out_n[None, :])
z = tl.load(z_ptrs,
mask=(offs_out_m[:, None] < chunk_size_limit) &
(offs_out_n[None, :] < hdim),
other=0.0).to(tl.float32)
z_ptrs = z_ptr + (
stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]
)
z = tl.load(
z_ptrs,
mask=(offs_out_m[:, None] < chunk_size_limit)
& (offs_out_n[None, :] < hdim),
other=0.0,
).to(tl.float32)
acc *= z * tl.sigmoid(z)

out_ptr += chunk_seqlen_start * stride_out_seqlen + pid_h * stride_out_head
out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] +
offs_out_n[None, :] * stride_out_hdim)
tl.store(out_ptrs,
acc,
mask=(offs_out_m[:, None] < chunk_size_limit) &
(offs_out_n[None, :] < hdim))
out_ptrs = out_ptr + (
stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim
)
tl.store(
out_ptrs,
acc,
mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim),
)


def _chunk_scan_fwd(
@@ -369,24 +368,32 @@ def _chunk_scan_fwd(
assert C.shape == (seqlen, ngroups, dstate)
assert cb.shape == (nchunks, ngroups, chunk_size, chunk_size)
if D is not None:
assert D.shape == (nheads, headdim) or D.shape == (nheads, )
assert D.shape == (nheads, headdim) or D.shape == (nheads,)
if z is not None:
assert z.shape == x.shape
assert dt.shape == (nheads, nchunks, chunk_size)
assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
assert states.shape == (nchunks, nheads, headdim, dstate)
assert seq_idx.shape == (nchunks, )
assert seq_idx.shape == (nchunks,)

grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton
.cdiv(headdim, META['BLOCK_SIZE_N']), nchunks, nheads)
grid = lambda META: (
triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
* triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
nchunks,
nheads,
)

z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
(0, 0, 0))
initial_states_strides = ((initial_states.stride(0),
initial_states.stride(1),
initial_states.stride(2),
initial_states.stride(3))
if initial_states is not None else (0, 0, 0, 0))
z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
initial_states_strides = (
(
initial_states.stride(0),
initial_states.stride(1),
initial_states.stride(2),
initial_states.stride(3),
)
if initial_states is not None
else (0, 0, 0, 0)
)

_chunk_scan_fwd_kernel[grid](
cb_ptr=cb,

@@ -15,14 +15,14 @@ from .mamba_ssm import softplus

@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_H': 2}),
triton.Config({'BLOCK_SIZE_H': 4}),
triton.Config({'BLOCK_SIZE_H': 8}),
triton.Config({'BLOCK_SIZE_H': 16}),
triton.Config({'BLOCK_SIZE_H': 32}),
triton.Config({'BLOCK_SIZE_H': 64}),
triton.Config({"BLOCK_SIZE_H": 2}),
triton.Config({"BLOCK_SIZE_H": 4}),
triton.Config({"BLOCK_SIZE_H": 8}),
triton.Config({"BLOCK_SIZE_H": 16}),
triton.Config({"BLOCK_SIZE_H": 32}),
triton.Config({"BLOCK_SIZE_H": 64}),
],
key=['chunk_size', 'nheads'],
key=["chunk_size", "nheads"],
)
@triton.jit
def _chunk_cumsum_fwd_kernel(
@@ -70,118 +70,99 @@ def _chunk_cumsum_fwd_kernel(

offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head +
offs_c[None, :] * stride_dt_seqlen)
dt_ptrs = dt_ptr + (
offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
)
A_ptrs = A_ptr + offs_h * stride_A_head
dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head +
offs_c[None, :] * stride_dt_out_csize)
dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head +
offs_c[None, :] * stride_dA_cs_csize)
dt_out_ptrs = dt_out_ptr + (
offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize
)
dA_cs_ptrs = dA_cumsum_ptr + (
offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize
)
chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start

dt = tl.load(dt_ptrs,
mask=(offs_h[:, None] < nheads) &
(offs_c[None, :] < chunk_size_limit),
other=0.0).to(tl.float32)
dt = tl.load(
dt_ptrs,
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
other=0.0,
).to(tl.float32)
if HAS_DT_BIAS:
dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head,
mask=offs_h < nheads,
other=0.0).to(tl.float32)
dt_bias = tl.load(
dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
).to(tl.float32)
dt += dt_bias[:, None]
if DT_SOFTPLUS:
dt = tl.where(dt <= 20.0, softplus(dt), dt)

dt = tl.clamp(dt, dt_min, dt_max)
dt = tl.where(
(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt,
0.0)
tl.store(dt_out_ptrs,
dt,
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
)
tl.store(
dt_out_ptrs,
dt,
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
)
A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
dA = dt * A[:, None]
dA_cs = tl.cumsum(dA, axis=1)
tl.store(dA_cs_ptrs,
dA_cs,
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
tl.store(
dA_cs_ptrs,
dA_cs,
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
)


@triton.autotune(
configs=[
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 64
},
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8),
num_warps=8,
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2),
num_warps=2,
),
triton.Config(
{
'BLOCK_SIZE_M': 32,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2),
num_warps=2,
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=2),
num_warps=2,
),
],
key=['hdim', 'dstate', 'chunk_size'],
key=["hdim", "dstate", "chunk_size"],
)
@triton.jit
def _chunk_state_fwd_kernel(
@@ -227,8 +208,10 @@ def _chunk_state_fwd_kernel(
pid_n = tl.program_id(axis=0) % num_pid_n
chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
b_ptr += chunk_seqlen_start * stride_b_seqlen + (
pid_h // nheads_ngroups_ratio) * stride_b_head
b_ptr += (
chunk_seqlen_start * stride_b_seqlen
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
)
x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
@@ -236,32 +219,38 @@ def _chunk_state_fwd_kernel(
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim +
offs_k[None, :] * stride_x_seqlen)
b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate +
offs_k[:, None] * stride_b_seqlen)
x_ptrs = x_ptr + (
offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
)
b_ptrs = b_ptr + (
offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
)
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
dA_cs_last = tl.load(dA_cumsum_ptr +
(chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
tl.float32
)
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize

chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start

acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
x = tl.load(x_ptrs,
mask=(offs_m[:, None] < hdim) &
(offs_k[None, :] < chunk_size_limit - k),
other=0.0)
b = tl.load(b_ptrs,
mask=(offs_k[:, None] < chunk_size_limit - k) &
(offs_n[None, :] < dstate),
other=0.0).to(tl.float32)
dA_cs_k = tl.load(dA_cumsum_ptrs,
mask=offs_k < chunk_size_limit - k,
other=0.0).to(tl.float32)
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k,
other=0.0).to(tl.float32)
x = tl.load(
x_ptrs,
mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k),
other=0.0,
)
b = tl.load(
b_ptrs,
mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate),
other=0.0,
).to(tl.float32)
dA_cs_k = tl.load(
dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
).to(tl.float32)
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
tl.float32
)
scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k
b *= scale[:, None]
b = b.to(x_ptr.dtype.element_ty)
@@ -277,8 +266,9 @@ def _chunk_state_fwd_kernel(
states_ptr += pid_c * stride_states_chunk + pid_h * stride_states_head
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim +
offs_n[None, :] * stride_states_dstate)
states_ptrs = states_ptr + (
offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
)
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
tl.store(states_ptrs, states, mask=c_mask)

@@ -286,79 +276,52 @@ def _chunk_state_fwd_kernel(
@triton.autotune(
configs=[
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 64
},
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8),
num_warps=8,
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2),
num_warps=2,
),
triton.Config(
{
'BLOCK_SIZE_M': 32,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2),
num_warps=2,
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=2),
num_warps=2,
),
],
key=['hdim', 'dstate', 'chunk_size'],
key=["hdim", "dstate", "chunk_size"],
)
@triton.jit
def _chunk_state_varlen_kernel(
@@ -414,12 +377,16 @@ def _chunk_state_varlen_kernel(
pid_n = tl.program_id(axis=0) % num_pid_n
end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
pid_c = (end_idx - 1) // chunk_size
b_ptr += pid_c * chunk_size * stride_b_seqlen + (
pid_h // nheads_ngroups_ratio) * stride_b_head
b_ptr += (
pid_c * chunk_size * stride_b_seqlen
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
)
x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
chunk_states_ptr += (
pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
)

if HAS_INITSTATES:
# if there are init states provided, we differentiate between states (which
@@ -430,13 +397,16 @@ def _chunk_state_varlen_kernel(
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim +
offs_k[None, :] * stride_x_seqlen)
b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate +
offs_k[:, None] * stride_b_seqlen)
x_ptrs = x_ptr + (
offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
)
b_ptrs = b_ptr + (
offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
)
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) *
stride_dA_cs_csize).to(tl.float32)
dA_cs_last = tl.load(
dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
).to(tl.float32)
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize

chunk_size_limit = end_idx - pid_c * chunk_size
@@ -445,24 +415,31 @@ def _chunk_state_varlen_kernel(

acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
x = tl.load(x_ptrs,
mask=(offs_m[:, None] < hdim) &
(offs_k[None, :] < chunk_size_limit - k) &
(offs_k[None, :] >= start_idx_cur - k),
other=0.0)
b = tl.load(b_ptrs,
mask=(offs_k[:, None] < chunk_size_limit - k) &
(offs_n[None, :] < dstate) &
(offs_k[:, None] >= start_idx_cur - k),
other=0.0).to(tl.float32)
dA_cs_k = tl.load(dA_cumsum_ptrs,
mask=offs_k < chunk_size_limit - k,
other=0.0).to(tl.float32)
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k,
other=0.0).to(tl.float32)
x = tl.load(
x_ptrs,
mask=(offs_m[:, None] < hdim)
& (offs_k[None, :] < chunk_size_limit - k)
& (offs_k[None, :] >= start_idx_cur - k),
other=0.0,
)
b = tl.load(
b_ptrs,
mask=(offs_k[:, None] < chunk_size_limit - k)
& (offs_n[None, :] < dstate)
& (offs_k[:, None] >= start_idx_cur - k),
other=0.0,
).to(tl.float32)
dA_cs_k = tl.load(
dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
).to(tl.float32)
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
tl.float32
)
scale = tl.where(
(offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0)
tl.exp(dA_cs_last - dA_cs_k) * dt_k,
0.0,
)
b *= scale[:, None]
b = b.to(x_ptr.dtype.element_ty)
acc += tl.dot(x, b)
@@ -475,39 +452,43 @@ def _chunk_state_varlen_kernel(
# If HAS_INITSTATES==True need to consider two possibilities
# - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs
# - if state_idx >= pid * chunk_size, then we need to insert initstates
if ((start_idx < pid_c * chunk_size) # first chunk
or (HAS_INITSTATES)):

if (
(start_idx < pid_c * chunk_size) # first chunk
or (HAS_INITSTATES)
):
dA_cs_boundary = 0.0 # default

if not HAS_INITSTATES:
past_states_ptrs = chunk_states_ptr + (
offs_m[:, None] * stride_chunk_states_hdim +
offs_n[None, :] * stride_chunk_states_dstate)
offs_m[:, None] * stride_chunk_states_hdim
+ offs_n[None, :] * stride_chunk_states_dstate
)
else:

# - this seems repetitive, buts its to help the compiler
if start_idx < pid_c * chunk_size:
past_states_ptrs = chunk_states_ptr + (
offs_m[:, None] * stride_chunk_states_hdim +
offs_n[None, :] * stride_chunk_states_dstate)
offs_m[:, None] * stride_chunk_states_hdim
+ offs_n[None, :] * stride_chunk_states_dstate
)
else:
past_states_ptrs = initstates_ptr + (
pid_b * stride_init_states_batch +
offs_m[:, None] * stride_init_states_hdim +
offs_n[None, :] * stride_init_states_dstate)
pid_b * stride_init_states_batch
+ offs_m[:, None] * stride_init_states_hdim
+ offs_n[None, :] * stride_init_states_dstate
)

# need to adjust the boundary
if start_idx > pid_c * chunk_size:
dA_cs_boundary = tl.load(dA_cumsum_ptr +
(start_idx - pid_c * chunk_size -
1) * stride_dA_cs_csize).to(
tl.float32)
dA_cs_boundary = tl.load(
dA_cumsum_ptr
+ (start_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
).to(tl.float32)

past_states = tl.load(past_states_ptrs,
mask=(offs_m[:, None] < hdim) &
(offs_n[None, :] < dstate),
other=0.0).to(tl.float32)
past_states = tl.load(
past_states_ptrs,
mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate),
other=0.0,
).to(tl.float32)

scale = tl.exp(dA_cs_last - dA_cs_boundary)
acc += past_states * scale
@@ -517,36 +498,34 @@ def _chunk_state_varlen_kernel(
states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim +
offs_n[None, :] * stride_states_dstate)
states_ptrs = states_ptr + (
offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
)
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
tl.store(states_ptrs, states, mask=c_mask)


def _chunk_cumsum_fwd(dt,
A,
chunk_size,
cu_chunk_seqlens,
dt_bias=None,
dt_softplus=False,
dt_limit=(0.0, float("inf"))):
def _chunk_cumsum_fwd(
dt,
A,
chunk_size,
cu_chunk_seqlens,
dt_bias=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
):
seqlen, nheads = dt.shape
assert A.shape == (nheads, )
assert A.shape == (nheads,)
if dt_bias is not None:
assert dt_bias.shape == (nheads, )
assert dt_bias.shape == (nheads,)
nchunks = cu_chunk_seqlens.shape[0] - 1
dt_out = torch.empty(nheads,
nchunks,
chunk_size,
device=dt.device,
dtype=torch.float32)
dA_cumsum = torch.empty(nheads,
nchunks,
chunk_size,
device=dt.device,
dtype=torch.float32)
grid_chunk_cs = lambda META: (nchunks,
triton.cdiv(nheads, META['BLOCK_SIZE_H']))
dt_out = torch.empty(
nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
)
dA_cumsum = torch.empty(
nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
)
grid_chunk_cs = lambda META: (nchunks, triton.cdiv(nheads, META["BLOCK_SIZE_H"]))
with torch.cuda.device(dt.device.index):
_chunk_cumsum_fwd_kernel[grid_chunk_cs](
dt_ptr=dt,
@@ -563,8 +542,7 @@ def _chunk_cumsum_fwd(dt,
stride_dt_seqlen=dt.stride(0),
stride_dt_head=dt.stride(1),
stride_A_head=A.stride(0),
stride_dt_bias_head=dt_bias.stride(0)
if dt_bias is not None else 0,
stride_dt_bias_head=dt_bias.stride(0) if dt_bias is not None else 0,
stride_dt_out_head=dt_out.stride(0),
stride_dt_out_chunk=dt_out.stride(1),
stride_dt_out_csize=dt_out.stride(2),
@@ -578,13 +556,9 @@ def _chunk_cumsum_fwd(dt,
return dA_cumsum, dt_out


def _chunk_state_fwd(B,
x,
dt,
dA_cumsum,
cu_chunk_seqlens,
states=None,
states_in_fp32=True):
def _chunk_state_fwd(
B, x, dt, dA_cumsum, cu_chunk_seqlens, states=None, states_in_fp32=True
):
seqlen, nheads, headdim = x.shape
_, nchunks, chunk_size = dt.shape
_, ngroups, dstate = B.shape
@@ -597,12 +571,16 @@ def _chunk_state_fwd(B,
assert states.shape == (nchunks, nheads, headdim, dstate)
else:
states_dtype = torch.float32 if states_in_fp32 else B.dtype
states = torch.empty((nchunks, nheads, headdim, dstate),
device=x.device,
dtype=states_dtype)
states = torch.empty(
(nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype
)

grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.
cdiv(dstate, META['BLOCK_SIZE_N']), nchunks, nheads)
grid = lambda META: (
triton.cdiv(headdim, META["BLOCK_SIZE_M"])
* triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
nchunks,
nheads,
)
with torch.cuda.device(x.device.index):
_chunk_state_fwd_kernel[grid](
x_ptr=x,
@@ -636,13 +614,9 @@ def _chunk_state_fwd(B,
return states


def chunk_state_varlen(B,
x,
dt,
dA_cumsum,
cu_seqlens,
chunk_states,
initial_states=None):
def chunk_state_varlen(
B, x, dt, dA_cumsum, cu_seqlens, chunk_states, initial_states=None
):
total_seqlen, nheads, headdim = x.shape
_, nchunks, chunk_size = dt.shape
_, ngroups, dstate = B.shape
@@ -657,21 +631,32 @@ def chunk_state_varlen(B,
if initial_states is not None:
assert initial_states.shape == (batch, nheads, headdim, dstate)

states = torch.empty(batch,
nheads,
headdim,
dstate,
dtype=chunk_states.dtype,
device=chunk_states.device)
states = torch.empty(
batch,
nheads,
headdim,
dstate,
dtype=chunk_states.dtype,
device=chunk_states.device,
)

initial_states_strides = ((initial_states.stride(0),
initial_states.stride(1),
initial_states.stride(2),
initial_states.stride(3))
if initial_states is not None else (0, 0, 0, 0))
initial_states_strides = (
(
initial_states.stride(0),
initial_states.stride(1),
initial_states.stride(2),
initial_states.stride(3),
)
if initial_states is not None
else (0, 0, 0, 0)
)

grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.
cdiv(dstate, META['BLOCK_SIZE_N']), batch, nheads)
grid = lambda META: (
triton.cdiv(headdim, META["BLOCK_SIZE_M"])
* triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
batch,
nheads,
)
with torch.cuda.device(x.device.index):
_chunk_state_varlen_kernel[grid](
x_ptr=x,
@@ -710,5 +695,6 @@ def chunk_state_varlen(B,
stride_init_states_head=initial_states_strides[1],
stride_init_states_hdim=initial_states_strides[2],
stride_init_states_dstate=initial_states_strides[3],
HAS_INITSTATES=initial_states is not None)
HAS_INITSTATES=initial_states is not None,
)
return states

@@ -17,63 +17,66 @@ from .ssd_chunk_scan import _chunk_scan_fwd
from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd
from .ssd_state_passing import _state_passing_fwd

TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")


def is_int_pow_2(n):
return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0


def _mamba_chunk_scan_combined_fwd(x,
dt,
A,
B,
C,
chunk_size,
out,
D=None,
z=None,
dt_bias=None,
initial_states=None,
return_intermediate_states=False,
seq_idx=None,
cu_seqlens=None,
cu_chunk_seqlens=None,
last_chunk_indices=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
state_dtype=None):
def _mamba_chunk_scan_combined_fwd(
x,
dt,
A,
B,
C,
chunk_size,
out,
D=None,
z=None,
dt_bias=None,
initial_states=None,
return_intermediate_states=False,
seq_idx=None,
cu_seqlens=None,
cu_chunk_seqlens=None,
last_chunk_indices=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
state_dtype=None,
):
assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
seqlen, nheads, headdim = x.shape
_, ngroups, dstate = B.shape
assert nheads % ngroups == 0
assert B.shape == (seqlen, ngroups, dstate)
assert dt.shape == (seqlen, nheads)
assert A.shape == (nheads, )
assert A.shape == (nheads,)
assert C.shape == B.shape
if z is not None:
assert z.shape == x.shape
if D is not None:
assert D.shape == (nheads, headdim) or D.shape == (nheads, )
assert D.shape == (nheads, headdim) or D.shape == (nheads,)
if seq_idx is not None:
assert seq_idx.shape == (cu_chunk_seqlens.shape[0] - 1, )
assert seq_idx.shape == (cu_chunk_seqlens.shape[0] - 1,)
if B.stride(-1) != 1:
B = B.contiguous()
if C.stride(-1) != 1:
C = C.contiguous()
if x.stride(-1) != 1 and x.stride(
0) != 1: # Either M or K dimension should be contiguous
if (
x.stride(-1) != 1 and x.stride(0) != 1
): # Either M or K dimension should be contiguous
x = x.contiguous()
if z is not None and z.stride(-1) != 1 and z.stride(
0) != 1: # Either M or K dimension should be contiguous
if (
z is not None and z.stride(-1) != 1 and z.stride(0) != 1
): # Either M or K dimension should be contiguous
z = z.contiguous()
if D is not None and D.stride(-1) != 1:
D = D.contiguous()
assert cu_seqlens is not None, "Assuming varlen input - must supply cu_seqlens"

if initial_states is not None:
assert initial_states.shape == (len(cu_seqlens) - 1, nheads, headdim,
dstate)
assert initial_states.shape == (len(cu_seqlens) - 1, nheads, headdim, dstate)

# This function executes 5 sub-functions for computing mamba
# - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
@@ -86,22 +89,21 @@ def _mamba_chunk_scan_combined_fwd(x,

# 1. Compute chunked cumsum of A * dt
# - here dt may go through a softplus activation
dA_cumsum, dt = _chunk_cumsum_fwd(dt,
A,
chunk_size,
cu_chunk_seqlens,
dt_bias=dt_bias,
dt_softplus=dt_softplus,
dt_limit=dt_limit)
dA_cumsum, dt = _chunk_cumsum_fwd(
dt,
A,
chunk_size,
cu_chunk_seqlens,
dt_bias=dt_bias,
dt_softplus=dt_softplus,
dt_limit=dt_limit,
)

# 2. Compute the state for each intra-chunk
# (right term of low-rank factorization of off-diagonal blocks; B terms)
states = _chunk_state_fwd(B,
x,
dt,
dA_cumsum,
cu_chunk_seqlens,
states_in_fp32=True)
states = _chunk_state_fwd(
B, x, dt, dA_cumsum, cu_chunk_seqlens, states_in_fp32=True
)

# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
@@ -114,18 +116,15 @@ def _mamba_chunk_scan_combined_fwd(x,
dA_cumsum, # (nheads, nchunks, chunk_size)
cu_chunk_seqlens,
initial_states=rearrange(initial_states, "... p n -> ... (p n)")
if initial_states is not None else
None, # (batch, nheads, headdim*dstate)
if initial_states is not None
else None, # (batch, nheads, headdim*dstate)
seq_idx=seq_idx,
out_dtype=state_dtype if state_dtype is not None else C.dtype)
out_dtype=state_dtype if state_dtype is not None else C.dtype,
)
states = rearrange(states, "... (p n) -> ... p n", n=dstate)

# 4. Compute batched matrix multiply for C_j^T B_i terms
CB = _bmm_chunk_fwd(C,
B,
chunk_size,
cu_chunk_seqlens,
output_dtype=torch.float32)
CB = _bmm_chunk_fwd(C, B, chunk_size, cu_chunk_seqlens, output_dtype=torch.float32)

# 5. Scan and compute the diagonal blocks, taking into
# account past causal states.
@@ -225,6 +224,7 @@ def mamba_chunk_scan_combined_varlen(
last_chunk_indices=last_chunk_indices,
dt_softplus=dt_softplus,
dt_limit=dt_limit,
state_dtype=state_dtype)
state_dtype=state_dtype,
)

return varlen_states

@@ -13,14 +13,14 @@ from vllm.triton_utils import tl, triton

@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE': 64}),
triton.Config({'BLOCK_SIZE': 128}),
triton.Config({'BLOCK_SIZE': 256}),
triton.Config({'BLOCK_SIZE': 512}),
triton.Config({'BLOCK_SIZE': 1024}),
triton.Config({'BLOCK_SIZE': 2048}),
triton.Config({"BLOCK_SIZE": 64}),
triton.Config({"BLOCK_SIZE": 128}),
triton.Config({"BLOCK_SIZE": 256}),
triton.Config({"BLOCK_SIZE": 512}),
triton.Config({"BLOCK_SIZE": 1024}),
triton.Config({"BLOCK_SIZE": 2048}),
],
key=['dim'],
key=["dim"],
)
@triton.jit
def _state_passing_fwd_kernel(
@@ -58,8 +58,7 @@ def _state_passing_fwd_kernel(
pid_m = tl.program_id(axis=0)

states_ptr += pid_h * stride_states_head
dA_cs_ptr += pid_h * stride_dA_cs_head + (chunk_size -
1) * stride_dA_cs_csize
dA_cs_ptr += pid_h * stride_dA_cs_head + (chunk_size - 1) * stride_dA_cs_csize
out_ptr += pid_h * stride_out_head

offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
@@ -67,31 +66,35 @@ def _state_passing_fwd_kernel(
out_ptrs = out_ptr + offs_m * stride_out_dim

if HAS_INITSTATES:
initstates_ptrs = initstates_ptr \
+ pid_h * stride_initstates_head \
initstates_ptrs = (
initstates_ptr
+ pid_h * stride_initstates_head
+ offs_m * stride_initstates_dim
)

states = tl.load(initstates_ptrs, mask=offs_m < dim,
other=0.0).to(tl.float32)
states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
else:
states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

prev_seq_idx = 0
for c in range(nchunks):
new_states = tl.load(states_ptrs, mask=offs_m < dim,
other=0.0).to(tl.float32)
new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
seq_idx = tl.load(seq_idx_ptr + c * stride_seq_idx_chunk)
# we have started a new sequence
if prev_seq_idx != seq_idx:
if HAS_INITSTATES:
initstates_ptrs = initstates_ptr + seq_idx * stride_initstates_batch \
+ pid_h * stride_initstates_head \
initstates_ptrs = (
initstates_ptr
+ seq_idx * stride_initstates_batch
+ pid_h * stride_initstates_head
+ offs_m * stride_initstates_dim
states = tl.load(initstates_ptrs, mask=offs_m < dim,
other=0.0).to(tl.float32)
)
states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(
tl.float32
)
else:
states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

prev_seq_idx = seq_idx
states = tl.exp(dA_cs) * states + new_states
@@ -115,16 +118,15 @@ def _state_passing_fwd(
assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
seqlen = seq_idx.shape[-1]
out_dtype = states.dtype if out_dtype is None else out_dtype
out = torch.empty((nchunks, nheads, dim),
device=states.device,
dtype=out_dtype)
out = torch.empty((nchunks, nheads, dim), device=states.device, dtype=out_dtype)

initial_states_strides = ((initial_states.stride(0),
initial_states.stride(1),
initial_states.stride(2))
if initial_states is not None else (0, 0, 0))
initial_states_strides = (
(initial_states.stride(0), initial_states.stride(1), initial_states.stride(2))
if initial_states is not None
else (0, 0, 0)
)

grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), nheads)
grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), nheads)
with torch.cuda.device(states.device.index):
_state_passing_fwd_kernel[grid](
states_ptr=states,

@@ -13,29 +13,35 @@ from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
causal_conv1d_fn,
causal_conv1d_update,
)
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.short_conv_attn import (
ShortConvAttentionMetadata)
from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionMetadata


@CustomOp.register("short_conv")
class ShortConv(MambaBase, CustomOp):

def __init__(self,
config,
dim: int,
layer_idx: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
prefix: str = ""):
def __init__(
self,
config,
dim: int,
layer_idx: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.layer_idx = layer_idx
@@ -72,7 +78,7 @@ class ShortConv(MambaBase, CustomOp):
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
self.kv_cache = (torch.tensor([]), )
self.kv_cache = (torch.tensor([]),)

self.model_config = model_config
self.cache_config = cache_config
@@ -121,8 +127,9 @@ class ShortConv(MambaBase, CustomOp):

B, C, x = BCx.chunk(3, dim=-1)

conv_weights = self.conv.weight.view(self.conv.weight.size(0),
self.conv.weight.size(2))
conv_weights = self.conv.weight.view(
self.conv.weight.size(0), self.conv.weight.size(2)
)

if attn_metadata is None:
# V1 profile run
@@ -163,23 +170,26 @@ class ShortConv(MambaBase, CustomOp):
dim=0,
)
query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1:] -
num_decodes if has_prefill else None)
attn_metadata.query_start_loc[-num_prefills - 1 :] - num_decodes
if has_prefill
else None
)

conv_output_list = []

if has_prefill:
Bx_p = (B_p * x_p).transpose(0, 1)
Bx = causal_conv1d_fn(Bx_p,
conv_weights,
self.conv.bias,
activation=None,
conv_states=conv_state,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
metadata=attn_metadata,
query_start_loc=query_start_loc_p).transpose(
0, 1)[:num_prefill_tokens]
Bx = causal_conv1d_fn(
Bx_p,
conv_weights,
self.conv.bias,
activation=None,
conv_states=conv_state,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
metadata=attn_metadata,
query_start_loc=query_start_loc_p,
).transpose(0, 1)[:num_prefill_tokens]

y = C_p * Bx
conv_output_list.append(y)
@@ -192,7 +202,8 @@ class ShortConv(MambaBase, CustomOp):
conv_weights,
self.conv.bias,
activation=None,
conv_state_indices=state_indices_tensor_d)
conv_state_indices=state_indices_tensor_d,
)
y = C_d * Bx
conv_output_list.insert(0, y)

@@ -222,8 +233,8 @@ class ShortConv(MambaBase, CustomOp):
return "short_conv"

def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.short_conv_attn import (
ShortConvAttentionBackend)
from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend

return ShortConvAttentionBackend
