Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Committed: 2025-10-05 15:06:22 +01:00 (committed by GitHub)
Parent: 17edd8a807
Commit: d6953beb91
1508 changed files with 115244 additions and 94146 deletions
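
The hunks below are almost entirely mechanical restyling: ruff's Black-compatible formatter replaces yapf's aligned-continuation layout, and ruff's import-sorting rules take over from isort (the tool-configuration changes themselves are not part of the hunks shown in this excerpt). As a minimal sketch of the target style, assuming only default ruff formatter behaviour, the function and names below are invented for illustration and do not come from the commit:

    from typing import Optional


    def scaled_slopes(
        values: list[float],
        scale: float,
        offset: Optional[float] = None,
    ) -> list[float]:
        # Ruff (Black-style) puts one parameter per line with a trailing comma once
        # a signature wraps; yapf instead aligned continuation lines under the
        # opening parenthesis.
        head = values[: len(values) // 2]  # ":" gets a space next to a non-trivial slice bound
        shift = 0.0 if offset is None else offset
        return [v * scale + shift for v in head]


    print(scaled_slopes([1.0, 2.0, 3.0, 4.0], scale=2.0))

The same conventions recur throughout the diff: parenthesized imports broken one name per line with a trailing comma, single-quoted strings rewritten with double quotes, backslash continuations replaced by parentheses, and slices such as q[:attn_metadata.num_decode_tokens] becoming q[: attn_metadata.num_decode_tokens].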


@@ -19,16 +19,21 @@ from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.lightning_attn import (
lightning_attention, linear_decode_forward_triton)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
lightning_attention,
linear_decode_forward_triton,
)
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
@@ -47,8 +52,7 @@ class MiniMaxText01RMSNormTP(CustomOp):
super().__init__()
self.tp_world = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.weight = nn.Parameter(torch.ones(int(hidden_size /
self.tp_world)))
self.weight = nn.Parameter(torch.ones(int(hidden_size / self.tp_world)))
self.weight.weight_loader = self.weight_loader
self.variance_epsilon = eps
@@ -75,8 +79,7 @@ class MiniMaxText01RMSNormTP(CustomOp):
x = x.to(torch.float32)
variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
if self.tp_world > 1:
variance = tensor_model_parallel_all_reduce(
variance) / self.tp_world
variance = tensor_model_parallel_all_reduce(variance) / self.tp_world
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype) * self.weight
return x
@@ -91,17 +94,17 @@ class MiniMaxText01RMSNormTP(CustomOp):
class MiniMaxText01LinearKernel:
@staticmethod
def jit_linear_forward_prefix(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kv_caches: torch.Tensor,
slope_rate: torch.Tensor,
block_size: int,
layer_idx: Optional[int] = None,
**kwargs) -> torch.Tensor:
def jit_linear_forward_prefix(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kv_caches: torch.Tensor,
slope_rate: torch.Tensor,
block_size: int,
layer_idx: Optional[int] = None,
**kwargs,
) -> torch.Tensor:
slope_rate = slope_rate.to(torch.float32)
should_pad_dim = q.dim() == 3
if should_pad_dim:
@@ -111,26 +114,22 @@ class MiniMaxText01LinearKernel:
b, h, n, d = q.shape
e = d
kv_history = kv_caches.reshape(1, h, d, e).contiguous()
output, kv_history = lightning_attention(q,
k,
v,
slope_rate,
block_size=block_size,
kv_history=kv_history)
output, kv_history = lightning_attention(
q, k, v, slope_rate, block_size=block_size, kv_history=kv_history
)
kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e))
assert output.shape[0] == 1, "batch size must be 1"
return rearrange(output.squeeze(0), "h n d -> n (h d)")
class MiniMaxText01LinearAttention(nn.Module, MambaBase):
@property
def mamba_type(self) -> str:
return "linear_attention"
def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.linear_attn import (
LinearAttentionBackend)
from vllm.v1.attention.backends.linear_attn import LinearAttentionBackend
return LinearAttentionBackend
def get_state_dtype(self) -> tuple[torch.dtype]:
@@ -143,9 +142,8 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
def get_state_shape(self) -> tuple[tuple[int, int, int], ...]:
return MambaStateShapeCalculator.linear_attention_state_shape(
num_heads=self.num_heads,
tp_size=self.tp_size,
head_dim=self.head_dim)
num_heads=self.num_heads, tp_size=self.tp_size, head_dim=self.head_dim
)
def __init__(
self,
@@ -209,16 +207,16 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
eps=1e-5,
)
slope_rate = MiniMaxText01LinearAttention._build_slope_tensor(
self.num_heads)
slope_rate = MiniMaxText01LinearAttention._build_slope_tensor(self.num_heads)
if num_hidden_layer <= 1:
self.slope_rate = slope_rate * (1 + 1e-5)
else:
self.slope_rate = slope_rate * (1 - layer_idx /
(num_hidden_layer - 1) + 1e-5)
self.tp_slope = self.slope_rate[self.tp_rank *
self.tp_heads:(self.tp_rank + 1) *
self.tp_heads].contiguous()
self.slope_rate = slope_rate * (
1 - layer_idx / (num_hidden_layer - 1) + 1e-5
)
self.tp_slope = self.slope_rate[
self.tp_rank * self.tp_heads : (self.tp_rank + 1) * self.tp_heads
].contiguous()
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
@@ -226,36 +224,36 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
compilation_config.static_forward_context[prefix] = self
@staticmethod
def weight_direct_load(param: torch.Tensor,
loaded_weight: torch.Tensor) -> None:
def weight_direct_load(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight)
return
@staticmethod
def _build_slope_tensor(n_attention_heads: int):
def get_slopes(n):
def get_slopes_power_of_2(n):
start = 2**(-(2**-(math.log2(n) - 3)))
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2**math.floor(math.log2(n))
return (get_slopes_power_of_2(closest_power_of_2) + get_slopes(
2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
closest_power_of_2 = 2 ** math.floor(math.log2(n))
return (
get_slopes_power_of_2(closest_power_of_2)
+ get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)
slopes = torch.tensor(get_slopes(n_attention_heads),
dtype=torch.float32).reshape(
n_attention_heads, 1, 1)
slopes = torch.tensor(
get_slopes(n_attention_heads), dtype=torch.float32
).reshape(n_attention_heads, 1, 1)
return slopes
def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
attn_metadata):
def _prefill_and_mix_infer(
self, q, k, v, kv_cache, state_indices_tensor, attn_metadata
):
hidden = []
for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
if _prefill_idx >= len(attn_metadata.query_start_loc):
@@ -278,12 +276,13 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
slice_layer_cache,
self.tp_slope,
self.BLOCK,
layer_idx=self.layer_idx)
layer_idx=self.layer_idx,
)
hidden.append(out_slice.contiguous())
if attn_metadata.num_decode_tokens > 0:
hidden_decode = self._decode_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
hidden_decode = self._decode_infer(
q, k, v, kv_cache, state_indices_tensor, attn_metadata
)
hidden.insert(0, hidden_decode)
if not hidden:
@@ -292,18 +291,19 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
hidden = torch.concat(hidden, dim=0).contiguous()
return hidden
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
attn_metadata):
q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
slot_id = state_indices_tensor[:attn_metadata.num_decodes]
hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
slot_id, 32)
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata):
q = q[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
k = k[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
v = v[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
slot_id = state_indices_tensor[: attn_metadata.num_decodes]
hidden = linear_decode_forward_triton(
q, k, v, kv_cache, self.tp_slope, slot_id, 32
)
return hidden
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
positions: torch.Tensor) -> None:
def forward(
self, hidden_states: torch.Tensor, output: torch.Tensor, positions: torch.Tensor
) -> None:
torch.ops.vllm.linear_attention(
hidden_states,
output,
@@ -311,16 +311,18 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
self.prefix,
)
def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
positions: torch.Tensor) -> None:
def _forward(
self, hidden_states: torch.Tensor, output: torch.Tensor, positions: torch.Tensor
) -> None:
forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, LinearAttentionMetadata)
num_actual_tokens = attn_metadata.num_prefill_tokens + \
attn_metadata.num_decode_tokens
num_actual_tokens = (
attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
)
else:
num_actual_tokens = hidden_states.shape[0]
@@ -335,35 +337,39 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
num_prefills = getattr(attn_metadata, "num_prefills", 0)
if num_prefills > 0:
num_decode_tokens = getattr(attn_metadata, "num_decode_tokens",
0)
num_decode_tokens = getattr(attn_metadata, "num_decode_tokens", 0)
for prefill_idx in range(num_prefills):
q_start = attn_metadata.query_start_loc[num_decode_tokens +
prefill_idx]
q_end = attn_metadata.query_start_loc[num_decode_tokens +
prefill_idx + 1]
q_start = attn_metadata.query_start_loc[
num_decode_tokens + prefill_idx
]
q_end = attn_metadata.query_start_loc[
num_decode_tokens + prefill_idx + 1
]
query_len = q_end - q_start
context_len = attn_metadata.seq_lens[
num_decode_tokens + prefill_idx] - query_len
context_len = (
attn_metadata.seq_lens[num_decode_tokens + prefill_idx]
- query_len
)
if context_len == 0:
block_to_clear = state_indices_tensor[num_decode_tokens
+ prefill_idx]
block_to_clear = state_indices_tensor[
num_decode_tokens + prefill_idx
]
kv_cache[block_to_clear, ...] = 0
decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
if attn_metadata is None:
hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]),
device=q.device,
dtype=q.dtype)
hidden = torch.empty(
(q.shape[0], q.shape[1] * q.shape[2]), device=q.device, dtype=q.dtype
)
else:
if not decode_only:
hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
hidden = self._prefill_and_mix_infer(
q, k, v, kv_cache, state_indices_tensor, attn_metadata
)
else:
hidden = self._decode_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
hidden = self._decode_infer(
q, k, v, kv_cache, state_indices_tensor, attn_metadata
)
hidden = self.norm._forward(hidden)
gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
hidden = F.sigmoid(gate) * hidden
@@ -380,9 +386,7 @@ def linear_attention(
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self._forward(hidden_states=hidden_states,
output=output,
positions=positions)
self._forward(hidden_states=hidden_states, output=output, positions=positions)
def linear_attention_fake(


@@ -12,20 +12,30 @@ from torch.nn.parameter import Parameter
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
causal_conv1d_fn,
causal_conv1d_update,
)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
selective_scan_fn,
selective_state_update,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
@@ -44,22 +54,24 @@ class MambaMixer(MambaBase, CustomOp):
**selective** state spaces)
"""
def __init__(self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
time_step_rank: int,
use_conv_bias: bool,
use_bias: bool,
use_rms_norm: bool,
rms_norm_has_weight: bool = True,
rms_norm_eps: float = 1e-5,
activation="silu",
is_lora_enabled: bool = False,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
prefix: str = ""):
def __init__(
self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
time_step_rank: int,
use_conv_bias: bool,
use_bias: bool,
use_rms_norm: bool,
rms_norm_has_weight: bool = True,
rms_norm_eps: float = 1e-5,
activation="silu",
is_lora_enabled: bool = False,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
):
super().__init__()
self.time_step_rank = time_step_rank
self.ssm_state_size = ssm_state_size
@@ -80,9 +92,9 @@ class MambaMixer(MambaBase, CustomOp):
# doesn't allow to override it
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
self.in_proj = MergedColumnParallelLinear(hidden_size,
[intermediate_size] * 2,
bias=use_bias)
self.in_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, bias=use_bias
)
# selective projection used to make dt, B and C input dependent
self.x_proj = RowParallelLinear(
@@ -93,17 +105,18 @@ class MambaMixer(MambaBase, CustomOp):
# time step projection (discretization) -
# In the forward we need to apply dt_proj without the bias,
# as the bias is added in the selective scan kernel.
self.dt_proj = ColumnParallelLinear(time_step_rank,
intermediate_size,
bias=True,
skip_bias_add=True)
self.dt_proj = ColumnParallelLinear(
time_step_rank, intermediate_size, bias=True, skip_bias_add=True
)
def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
param.data.copy_(
loaded_weight.data.split(loaded_weight.shape[0] // tp_size,
dim=0)[tp_rank])
loaded_weight.data.split(loaded_weight.shape[0] // tp_size, dim=0)[
tp_rank
]
)
def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
weight_loader(param, -torch.exp(loaded_weight.float()))
@@ -114,7 +127,8 @@ class MambaMixer(MambaBase, CustomOp):
intermediate_size // tp_size,
ssm_state_size,
dtype=torch.float32,
))
)
)
self.D = nn.Parameter(torch.ones(intermediate_size // tp_size))
set_weight_attrs(self.D, {"weight_loader": weight_loader})
@@ -127,23 +141,35 @@ class MambaMixer(MambaBase, CustomOp):
input_is_parallel=True,
)
self.dt_layernorm = RMSNorm(
time_step_rank,
eps=rms_norm_eps,
has_weight=rms_norm_has_weight,
) if use_rms_norm else None
self.dt_layernorm = (
RMSNorm(
time_step_rank,
eps=rms_norm_eps,
has_weight=rms_norm_has_weight,
)
if use_rms_norm
else None
)
self.b_layernorm = RMSNorm(
ssm_state_size,
eps=rms_norm_eps,
has_weight=rms_norm_has_weight,
) if use_rms_norm else None
self.b_layernorm = (
RMSNorm(
ssm_state_size,
eps=rms_norm_eps,
has_weight=rms_norm_has_weight,
)
if use_rms_norm
else None
)
self.c_layernorm = RMSNorm(
ssm_state_size,
eps=rms_norm_eps,
has_weight=rms_norm_has_weight,
) if use_rms_norm else None
self.c_layernorm = (
RMSNorm(
ssm_state_size,
eps=rms_norm_eps,
has_weight=rms_norm_has_weight,
)
if use_rms_norm
else None
)
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
@@ -157,7 +183,7 @@ class MambaMixer(MambaBase, CustomOp):
self.prefix = prefix
def _ssm_transform(
self, x: torch.Tensor
self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if self.is_lora_enabled:
# Lora kernel requires contiguous tensor.
@@ -167,7 +193,8 @@ class MambaMixer(MambaBase, CustomOp):
time_step, B, C = torch.split(
ssm_params,
[self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
dim=-1)
dim=-1,
)
if self.use_rms_norm:
assert self.dt_layernorm is not None
assert self.b_layernorm is not None
@@ -185,8 +212,7 @@ class MambaMixer(MambaBase, CustomOp):
self.prefix,
)
def forward_native(self, hidden_states: torch.Tensor,
output: torch.Tensor):
def forward_native(self, hidden_states: torch.Tensor, output: torch.Tensor):
pass
def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor):
@@ -232,8 +258,9 @@ class MambaMixer(MambaBase, CustomOp):
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
hidden_states_BC, gate = projected_states.chunk(2, dim=-2)
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
conv_weights = self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
)
if attn_metadata is None:
# V1 profile run
@@ -281,10 +308,12 @@ class MambaMixer(MambaBase, CustomOp):
conv_states=conv_state,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
query_start_loc=query_start_loc_p)
query_start_loc=query_start_loc_p,
)
# 3. State Space Model sequence transformations.
discrete_time_step_p, B_p, C_p = self._ssm_transform(
conv_out_p.transpose(-2, -1))
conv_out_p.transpose(-2, -1)
)
time_proj_bias = self._time_proj_bias()
# 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
@@ -301,7 +330,8 @@ class MambaMixer(MambaBase, CustomOp):
delta_softplus=True,
cache_indices=state_indices_tensor_p,
has_initial_state=has_initial_states_p,
query_start_loc=query_start_loc_p)
query_start_loc=query_start_loc_p,
)
ssm_outputs.append(scan_out_p)
if has_decode:
@@ -312,39 +342,42 @@ class MambaMixer(MambaBase, CustomOp):
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=state_indices_tensor_d).transpose(0, 1)
conv_state_indices=state_indices_tensor_d,
).transpose(0, 1)
# 3. State Space Model sequence transformation.
discrete_time_step_d, B_d, C_d = self._ssm_transform(
conv_out_d.transpose(-2, -1))
conv_out_d.transpose(-2, -1)
)
time_proj_bias = self._time_proj_bias()
# 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
scan_outputs_d = torch.empty_like(
hidden_states_BC_d.transpose(0, 1))
selective_state_update(ssm_state,
conv_out_d.transpose(0, 1),
discrete_time_step_d.transpose(0, 1),
self.A,
B_d,
C_d,
self.D,
gate_d.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
state_batch_indices=state_indices_tensor_d,
out=scan_outputs_d)
scan_outputs_d = torch.empty_like(hidden_states_BC_d.transpose(0, 1))
selective_state_update(
ssm_state,
conv_out_d.transpose(0, 1),
discrete_time_step_d.transpose(0, 1),
self.A,
B_d,
C_d,
self.D,
gate_d.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
state_batch_indices=state_indices_tensor_d,
out=scan_outputs_d,
)
scan_outputs_d = scan_outputs_d.transpose(0, 1)
ssm_outputs.insert(0, scan_outputs_d)
scan_outputs_combined = ssm_outputs[0] if len(
ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1)
scan_outputs_combined = (
ssm_outputs[0] if len(ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1)
)
# 5. Final output projection
if self.is_lora_enabled: # Lora kernel requires contiguous tensor.
scan_outputs_combined = scan_outputs_combined.transpose(
-2, -1).contiguous()
scan_outputs_combined = scan_outputs_combined.transpose(-2, -1).contiguous()
out = self.out_proj(scan_outputs_combined)[0]
else:
out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0]
@@ -373,8 +406,8 @@ class MambaMixer(MambaBase, CustomOp):
return "mamba1"
def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.mamba1_attn import (
Mamba1AttentionBackend)
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
return Mamba1AttentionBackend
def _time_proj_bias(self) -> Optional[torch.Tensor]:
@@ -406,27 +439,34 @@ def split_batch_to_prefill_and_decode(
num_decodes: int,
num_padded_decodes: int,
) -> PrefillDecodeSplit:
num_actual_tokens = num_prefill_tokens + num_padded_decodes
# In v1, decode tokens come first, then prefill tokens.
hidden_states_BC_d, hidden_states_BC_p = torch.split(
hidden_states_BC[..., :num_actual_tokens],
[num_padded_decodes, num_prefill_tokens],
dim=-1)
gate_d, gate_p = torch.split(gate[..., :num_actual_tokens],
[num_padded_decodes, num_prefill_tokens],
dim=-1)
dim=-1,
)
gate_d, gate_p = torch.split(
gate[..., :num_actual_tokens], [num_padded_decodes, num_prefill_tokens], dim=-1
)
# num_padded_decodes accounts for CUDA graph padding when applicable
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor[:num_padded_decodes + num_prefills],
state_indices_tensor[: num_padded_decodes + num_prefills],
[num_padded_decodes, num_prefills],
dim=0)
query_start_loc_p = (query_start_loc[-num_prefills - 1:] -
num_padded_decodes if num_prefills > 0 else None)
has_initial_states_p = has_initial_states[-num_prefills:] if (
has_initial_states is not None and num_prefills > 0) else None
dim=0,
)
query_start_loc_p = (
query_start_loc[-num_prefills - 1 :] - num_padded_decodes
if num_prefills > 0
else None
)
has_initial_states_p = (
has_initial_states[-num_prefills:]
if (has_initial_states is not None and num_prefills > 0)
else None
)
return PrefillDecodeSplit(
hidden_states_BC_p=hidden_states_BC_p,


@@ -11,28 +11,40 @@ from torch import nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.distributed import (
divide,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
causal_conv1d_fn,
causal_conv1d_update,
)
from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_state_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_state_update
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined_varlen)
mamba_chunk_scan_combined_varlen,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import (
LoaderFunction, composed_weight_loader, sharded_weight_loader)
LoaderFunction,
composed_weight_loader,
sharded_weight_loader,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
@@ -43,12 +55,13 @@ from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
# Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated
@CustomOp.register("mixer2_gated_rms_norm")
class Mixer2RMSNormGated(CustomOp):
def __init__(self,
full_hidden_size: int,
full_n_groups: int,
use_rms_norm: bool = True,
eps: float = 1e-6):
def __init__(
self,
full_hidden_size: int,
full_n_groups: int,
use_rms_norm: bool = True,
eps: float = 1e-6,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
@@ -62,13 +75,13 @@ class Mixer2RMSNormGated(CustomOp):
if self.use_rms_norm:
# Register norm weight only if we're actually applying RMSNorm
self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
set_weight_attrs(self.weight,
{"weight_loader": sharded_weight_loader(0)})
set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)})
else:
# Avoid checkpoint mismatch by skipping unused parameter
self.register_parameter("weight", None)
assert (self.full_hidden_size % self.tp_size == 0
), "Tensor parallel world size must divide hidden size."
assert self.full_hidden_size % self.tp_size == 0, (
"Tensor parallel world size must divide hidden size."
)
def forward_native(
self,
@@ -111,8 +124,7 @@ class Mixer2RMSNormGated(CustomOp):
group_count = hidden_dim // self.group_size
x_grouped = x.view(*prefix_dims, group_count, self.group_size)
variance = x_grouped.pow(2).mean(-1, keepdim=True)
x_grouped = x_grouped * torch.rsqrt(variance +
self.variance_epsilon)
x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon)
x = x_grouped.view(*prefix_dims, hidden_dim)
if redundant_tp:
@@ -130,18 +142,19 @@ class Mixer2RMSNormGated(CustomOp):
input_dtype = x.dtype
if not self.use_rms_norm:
# Keep gate in float32 for numerical stability during silu
return x * nn.functional.silu(gate.to(
torch.float32)).to(input_dtype)
return x * nn.functional.silu(gate.to(torch.float32)).to(input_dtype)
if (((self.n_groups % self.tp_size) != 0) or self.n_groups != 1):
if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1:
return self.forward_native(x, gate)
return rms_norm_gated(x,
self.weight.data,
bias=None,
z=gate,
eps=self.variance_epsilon,
norm_before_gate=False)
return rms_norm_gated(
x,
self.weight.data,
bias=None,
z=gate,
eps=self.variance_epsilon,
norm_before_gate=False,
)
def mamba_v2_sharded_weight_loader(
@@ -156,7 +169,6 @@ def mamba_v2_sharded_weight_loader(
"""
def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
# - track boundary of (sharded) param, and loaded_weight, respectively
boundary, loaded_boundary = 0, 0
@@ -191,11 +203,12 @@ def mamba_v2_sharded_weight_loader(
# seem to handle slices well.
# https://github.com/python/mypy/issues/2410
param.data[
boundary:(boundary + take),
... # type: ignore[misc]
] = loaded_weight[loaded_start_idx:(loaded_start_idx +
take) # type: ignore[misc]
] # type: ignore[misc]
boundary : (boundary + take), ... # type: ignore[misc]
] = loaded_weight[
loaded_start_idx : (
loaded_start_idx + take
) # type: ignore[misc]
] # type: ignore[misc]
# move indexing boundaries
boundary += shard_size
@@ -217,23 +230,25 @@ class MambaMixer2(MambaBase, CustomOp):
**selective** state spaces)
"""
def __init__(self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation: str = "silu",
use_rms_norm: bool = True,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
def __init__(
self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation: str = "silu",
use_rms_norm: bool = True,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
# For TP, the sharding plan is as follows:
@@ -253,15 +268,18 @@ class MambaMixer2(MambaBase, CustomOp):
self.tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
assert (num_heads % self.tp_size == 0
), "Tensor parallel world size must divide num heads."
assert num_heads % self.tp_size == 0, (
"Tensor parallel world size must divide num heads."
)
assert (n_groups % self.tp_size) == 0 or n_groups == 1, (
"If tensor parallel world size does not divide num_groups, "
"then num_groups must equal 1.")
"then num_groups must equal 1."
)
assert (n_groups % self.tp_size == 0) or self.tp_size == 1 or \
quant_config is None, (
assert (
(n_groups % self.tp_size == 0) or self.tp_size == 1 or quant_config is None
), (
"Tensor parallel currently supported for quantized models only "
"if tensor parallel world size divides num groups."
)
@@ -280,7 +298,8 @@ class MambaMixer2(MambaBase, CustomOp):
# - but if n_groups cannot divide tp_size, we need to
# extend some extra groups
groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
n_groups, self.tp_size)
n_groups, self.tp_size
)
self.n_groups = n_groups + groups
self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
@@ -340,8 +359,7 @@ class MambaMixer2(MambaBase, CustomOp):
# to the head shards
group_shard_settings = (
self.groups_ssm_state_size, # expected model size
(self.n_groups - n_groups) *
self.ssm_state_size, # extra dims assigned
(self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned
n_groups == 1, # if there was only one group
)
intermediate_settings = (intermediate_size, 0, False)
@@ -355,8 +373,7 @@ class MambaMixer2(MambaBase, CustomOp):
set_weight_attrs(
self.conv1d.bias,
{
"weight_loader":
mamba_v2_sharded_weight_loader(
"weight_loader": mamba_v2_sharded_weight_loader(
[
intermediate_settings,
group_shard_settings,
@@ -372,8 +389,7 @@ class MambaMixer2(MambaBase, CustomOp):
set_weight_attrs(
self.conv1d.weight,
{
"weight_loader":
mamba_v2_sharded_weight_loader(
"weight_loader": mamba_v2_sharded_weight_loader(
[
intermediate_settings,
group_shard_settings,
@@ -391,8 +407,7 @@ class MambaMixer2(MambaBase, CustomOp):
set_weight_attrs(
self.in_proj.weight,
{
"weight_loader":
mamba_v2_sharded_weight_loader(
"weight_loader": mamba_v2_sharded_weight_loader(
[
intermediate_settings, # for gate
intermediate_settings,
@@ -418,17 +433,18 @@ class MambaMixer2(MambaBase, CustomOp):
torch.empty(
divide(num_heads, self.tp_size),
dtype=torch.float32,
))
)
)
self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))
self.use_rms_norm = use_rms_norm
set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
a_weight_loader = composed_weight_loader(
sharded_weight_loader(0), lambda x: -torch.exp(x.float()))
sharded_weight_loader(0), lambda x: -torch.exp(x.float())
)
set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
set_weight_attrs(self.dt_bias,
{"weight_loader": sharded_weight_loader(0)})
set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
self.out_proj = RowParallelLinear(
intermediate_size,
@@ -439,10 +455,9 @@ class MambaMixer2(MambaBase, CustomOp):
prefix=f"{prefix}.out_proj",
)
self.norm = Mixer2RMSNormGated(intermediate_size,
n_groups,
self.use_rms_norm,
eps=rms_norm_eps)
self.norm = Mixer2RMSNormGated(
intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps
)
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
@@ -525,8 +540,9 @@ class MambaMixer2(MambaBase, CustomOp):
dim=-1,
)
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
conv_weights = self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
)
# - get hidden_states, B and C after depthwise convolution.
split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
@@ -541,10 +557,10 @@ class MambaMixer2(MambaBase, CustomOp):
if attn_metadata is None:
# profile run
hidden_states_B_C = (hidden_states_B_C.transpose(
0, 1).clone().transpose(0, 1)).contiguous()
hidden_states, _B, _C = split_hidden_states_B_C_fn(
hidden_states_B_C)
hidden_states_B_C = (
hidden_states_B_C.transpose(0, 1).clone().transpose(0, 1)
).contiguous()
hidden_states, _B, _C = split_hidden_states_B_C_fn(hidden_states_B_C)
hidden_states = self.norm(hidden_states, gate)
out, _ = self.out_proj(hidden_states)
return out
@@ -580,11 +596,11 @@ class MambaMixer2(MambaBase, CustomOp):
# If prefix caching is enabled, retrieve the relevant variables
# for prefill and decode
last_state_idx_d, last_state_idx_p = torch.split(
attn_metadata.last_state_idx, [num_decodes, num_prefills],
dim=0)
attn_metadata.last_state_idx, [num_decodes, num_prefills], dim=0
)
current_last_idx_d, current_last_idx_p = torch.split(
attn_metadata.current_last_idx, [num_decodes, num_prefills],
dim=0)
attn_metadata.current_last_idx, [num_decodes, num_prefills], dim=0
)
# Prefill-only variables:
current_first_idx_p = attn_metadata.current_first_idx_p
context_lens_p = attn_metadata.context_lens_p
@@ -600,7 +616,7 @@ class MambaMixer2(MambaBase, CustomOp):
preallocated_ssm_out = torch.empty(
[
num_prefill_tokens + num_decodes,
(self.num_heads // self.tp_size) * self.head_dim
(self.num_heads // self.tp_size) * self.head_dim,
],
dtype=hidden_states.dtype,
device=hidden_states.device,
@@ -626,7 +642,8 @@ class MambaMixer2(MambaBase, CustomOp):
# "state_indices_tensor_p"), it will write additional cache
# states aligned at "block_size_to_align".
x = hidden_states_B_C_p.transpose(
0, 1) # this is the form that causal-conv see
0, 1
) # this is the form that causal-conv see
hidden_states_B_C_p = causal_conv1d_fn(
x,
conv_weights,
@@ -641,34 +658,34 @@ class MambaMixer2(MambaBase, CustomOp):
context_lens=context_lens_p,
block_size_to_align=mamba_block_size,
metadata=attn_metadata,
query_start_loc=query_start_loc_p).transpose(
0, 1)[:num_prefill_tokens]
query_start_loc=query_start_loc_p,
).transpose(0, 1)[:num_prefill_tokens]
hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(
hidden_states_B_C_p)
hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(hidden_states_B_C_p)
# 3. State Space Model sequence transformation
initial_states = None
if (has_initial_states_p is not None and prep_initial_states):
if has_initial_states_p is not None and prep_initial_states:
kernel_ssm_indices = state_indices_tensor_p
if prefix_caching_enabled:
kernel_ssm_indices = state_indices_tensor_p.gather(
1, last_state_idx_p.unsqueeze(1)).squeeze(1)
1, last_state_idx_p.unsqueeze(1)
).squeeze(1)
initial_states = torch.where(
has_initial_states_p[:, None, None, None],
ssm_state[kernel_ssm_indices], 0)
ssm_state[kernel_ssm_indices],
0,
)
# NOTE: final output is an in-place update of out tensor
varlen_states = mamba_chunk_scan_combined_varlen(
hidden_states_p.view(num_prefill_tokens,
self.num_heads // self.tp_size,
self.head_dim),
hidden_states_p.view(
num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim
),
dt_p,
self.A,
B_p.view(num_prefill_tokens, self.n_groups // self.tp_size,
-1),
C_p.view(num_prefill_tokens, self.n_groups // self.tp_size,
-1),
B_p.view(num_prefill_tokens, self.n_groups // self.tp_size, -1),
C_p.view(num_prefill_tokens, self.n_groups // self.tp_size, -1),
chunk_size=chunk_size,
D=self.D,
z=None,
@@ -681,18 +698,19 @@ class MambaMixer2(MambaBase, CustomOp):
return_intermediate_states=prefix_caching_enabled,
dt_softplus=True,
dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out_p.view(num_prefill_tokens, -1,
self.head_dim),
state_dtype=ssm_state.dtype)
out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, self.head_dim),
state_dtype=ssm_state.dtype,
)
if prefix_caching_enabled:
# Save states for sequences with more than just the final state:
n_blocks_to_fill = current_last_idx_p - current_first_idx_p
for seq_idx in (n_blocks_to_fill > 0).nonzero().squeeze(1):
cache_blocks_to_fill = state_indices_tensor_p[
seq_idx, current_first_idx_p[seq_idx]:
current_first_idx_p[seq_idx] +
n_blocks_to_fill[seq_idx]]
seq_idx,
current_first_idx_p[seq_idx] : current_first_idx_p[seq_idx]
+ n_blocks_to_fill[seq_idx],
]
# chunks = [0 1 2 3 4 5 6 ...]
# First aligned chunk would typically be:
# mamba_block_size = 1024, chunk_size = 256
@@ -704,22 +722,33 @@ class MambaMixer2(MambaBase, CustomOp):
# e.g. 256 // 256 -> 1 completed --> store chunk[2] (skip 2)
# e.g. 10 // 256 -> 0 completed --> store chunk[3] (skip 3)
chunk_stride = mamba_block_size // chunk_size
first_aligned_chunk = \
torch.concat([torch.zeros(1, \
dtype=last_chunk_indices_p.dtype, \
device=last_chunk_indices_p.device), \
last_chunk_indices_p + 1])[seq_idx] \
+ chunk_stride - 1 \
- last_computed_offset_p[seq_idx] // chunk_size
first_aligned_chunk = (
torch.concat(
[
torch.zeros(
1,
dtype=last_chunk_indices_p.dtype,
device=last_chunk_indices_p.device,
),
last_chunk_indices_p + 1,
]
)[seq_idx]
+ chunk_stride
- 1
- last_computed_offset_p[seq_idx] // chunk_size
)
from_where = varlen_states[
first_aligned_chunk:first_aligned_chunk +
n_blocks_to_fill[seq_idx] * chunk_stride:chunk_stride]
first_aligned_chunk : first_aligned_chunk
+ n_blocks_to_fill[seq_idx] * chunk_stride : chunk_stride
]
ssm_state[cache_blocks_to_fill] = from_where
#For all seqs, store the last state (Note: might be partial):
ssm_state[state_indices_tensor_p.gather(1,
current_last_idx_p.unsqueeze(1)).squeeze(1)] = \
varlen_states[last_chunk_indices_p]
# For all seqs, store the last state (Note: might be partial):
ssm_state[
state_indices_tensor_p.gather(
1, current_last_idx_p.unsqueeze(1)
).squeeze(1)
] = varlen_states[last_chunk_indices_p]
else:
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate)
@@ -729,13 +758,13 @@ class MambaMixer2(MambaBase, CustomOp):
# Process decode requests
if has_decode:
if prefix_caching_enabled:
state_indices_tensor_d_input = \
state_indices_tensor_d.gather(1,
last_state_idx_d.unsqueeze(1)).squeeze(1)
state_indices_tensor_d_output = \
state_indices_tensor_d.gather(1,
current_last_idx_d.unsqueeze(1)).squeeze(1)
#Note:
state_indices_tensor_d_input = state_indices_tensor_d.gather(
1, last_state_idx_d.unsqueeze(1)
).squeeze(1)
state_indices_tensor_d_output = state_indices_tensor_d.gather(
1, current_last_idx_d.unsqueeze(1)
).squeeze(1)
# Note:
# for decode always: current_first_idx_d == current_last_idx_d
# at block boundaries: current_first_idx_d > last_state_idx_d
else:
@@ -755,20 +784,23 @@ class MambaMixer2(MambaBase, CustomOp):
initial_state_idx=last_state_idx_d,
)
hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(
hidden_states_B_C_d)
hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d)
# 3. State Space Model sequence transformation
n_groups = self.n_groups // self.tp_size
A_d = self.A[:, None, ...][:, :, None].expand(
-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
A_d = (
self.A[:, None, ...][:, :, None]
.expand(-1, self.head_dim, self.ssm_state_size)
.to(dtype=torch.float32)
)
dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim)
dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
D_d = self.D[:, None, ...].expand(-1, self.head_dim)
B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups)
C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups)
hidden_states_d = hidden_states_d.view(
-1, self.num_heads // self.tp_size, self.head_dim)
-1, self.num_heads // self.tp_size, self.head_dim
)
# - the hidden is reshaped into (bs, num_heads, head_dim)
# - mamba_cache_params.ssm_state's slots will be selected
@@ -787,16 +819,14 @@ class MambaMixer2(MambaBase, CustomOp):
dt_softplus=True,
state_batch_indices=state_indices_tensor_d_input,
dst_state_batch_indices=state_indices_tensor_d_output,
out=preallocated_ssm_out_d.view(num_decodes, -1,
self.head_dim),
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
)
# 4. gated MLP
# GatedRMSNorm internally applying SiLU to the gate
# SiLU is applied internally before normalization, unlike standard
# norm usage
hidden_states = self.norm(preallocated_ssm_out,
gate[:num_actual_tokens])
hidden_states = self.norm(preallocated_ssm_out, gate[:num_actual_tokens])
# 5. Final linear projection
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
@@ -826,8 +856,8 @@ class MambaMixer2(MambaBase, CustomOp):
return "mamba2"
def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.mamba2_attn import (
Mamba2AttentionBackend)
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
return Mamba2AttentionBackend
@@ -839,9 +869,7 @@ def mamba_mixer2(
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self.forward_cuda(hidden_states=hidden_states,
output=output,
mup_vector=mup_vector)
self.forward_cuda(hidden_states=hidden_states, output=output, mup_vector=mup_vector)
def mamba_mixer2_fake(


@@ -10,7 +10,6 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_kv_cache_torch_dtype
class MambaStateDtypeCalculator:
@classmethod
def linear_attention_state_dtype(
cls,
@@ -21,7 +20,7 @@ class MambaStateDtypeCalculator:
if mamba_cache_dtype == "float32":
raise ValueError("fp32 state for minimax is not yet supported")
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
return (state_dtype, )
return (state_dtype,)
@classmethod
def mamba1_state_dtype(
@@ -30,8 +29,9 @@ class MambaStateDtypeCalculator:
mamba_cache_dtype: MambaDType,
mamba_ssm_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
return cls._mamba_state_dtype(model_dtype, mamba_cache_dtype,
mamba_ssm_cache_dtype)
return cls._mamba_state_dtype(
model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype
)
@classmethod
def mamba2_state_dtype(
@@ -40,8 +40,9 @@ class MambaStateDtypeCalculator:
mamba_cache_dtype: MambaDType,
mamba_ssm_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
return cls._mamba_state_dtype(model_dtype, mamba_cache_dtype,
mamba_ssm_cache_dtype)
return cls._mamba_state_dtype(
model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype
)
@classmethod
def _mamba_state_dtype(
@@ -50,13 +51,11 @@ class MambaStateDtypeCalculator:
mamba_cache_dtype: MambaDType,
mamba_ssm_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype,
model_dtype)
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
if mamba_ssm_cache_dtype == "auto":
temporal_state_dtype = conv_state_dtype
else:
temporal_state_dtype = (
STR_DTYPE_TO_TORCH_DTYPE[mamba_ssm_cache_dtype])
temporal_state_dtype = STR_DTYPE_TO_TORCH_DTYPE[mamba_ssm_cache_dtype]
return (conv_state_dtype, temporal_state_dtype)
@@ -66,9 +65,8 @@ class MambaStateDtypeCalculator:
model_dtype: Union[ModelDType, torch.dtype],
mamba_cache_dtype: MambaDType,
) -> tuple[torch.dtype, ...]:
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype,
model_dtype)
return (conv_state_dtype, )
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
return (conv_state_dtype,)
@classmethod
def gated_delta_net_state_dtype(
@@ -81,7 +79,6 @@ class MambaStateDtypeCalculator:
class MambaStateShapeCalculator:
@classmethod
def linear_attention_state_shape(
cls,
@@ -89,9 +86,8 @@ class MambaStateShapeCalculator:
tp_size: int,
head_dim: int,
) -> tuple[tuple[int, int, int], ...]:
state_shape = (num_heads // tp_size, head_dim, head_dim)
return (state_shape, )
return (state_shape,)
@classmethod
def mamba1_state_shape(
@@ -101,11 +97,9 @@ class MambaStateShapeCalculator:
state_size: int,
conv_kernel: int,
) -> tuple[tuple[int, int], tuple[int, int]]:
conv_state_shape = (divide(intermediate_size,
tp_world_size), conv_kernel - 1)
conv_state_shape = (divide(intermediate_size, tp_world_size), conv_kernel - 1)
temporal_state_shape = (divide(intermediate_size,
tp_world_size), state_size)
temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size)
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
@@ -124,8 +118,7 @@ class MambaStateShapeCalculator:
) -> tuple[tuple[int, int], tuple[int, int, int]]:
# if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head is sharded along with it
n_groups = n_groups + cls.extra_groups_for_head_shards(
n_groups, tp_world_size)
n_groups = n_groups + cls.extra_groups_for_head_shards(n_groups, tp_world_size)
# heads and n_groups are TP-ed
conv_dim = intermediate_size + 2 * n_groups * state_size
@@ -135,8 +128,7 @@ class MambaStateShapeCalculator:
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
# e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
temporal_state_shape = (divide(num_heads,
tp_world_size), head_dim, state_size)
temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size)
return conv_state_shape, temporal_state_shape
@classmethod
@@ -148,7 +140,7 @@ class MambaStateShapeCalculator:
) -> tuple[tuple[int, int]]:
conv_dim = divide(intermediate_size, tp_world_size)
conv_state_shape = (conv_kernel - 1, conv_dim)
return (conv_state_shape, )
return (conv_state_shape,)
@classmethod
def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
@@ -173,7 +165,7 @@ class MambaStateShapeCalculator:
conv_kernel_size: int,
num_spec: int = 0,
):
conv_dim = (head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads)
conv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads
conv_state_shape = (
divide(conv_dim, tp_world_size),
conv_kernel_size - 1 + num_spec,
@@ -181,6 +173,9 @@ class MambaStateShapeCalculator:
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
temporal_state_shape = (divide(num_v_heads,
tp_world_size), head_k_dim, head_v_dim)
temporal_state_shape = (
divide(num_v_heads, tp_world_size),
head_k_dim,
head_v_dim,
)
return conv_state_shape, temporal_state_shape


@@ -38,8 +38,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching
num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines
# Strides
stride_x_dim: tl.constexpr, # stride to get to next feature-value,
stride_x_token: tl.
constexpr, # stride to get to next token (same feature-index, same sequence-index)
stride_x_token: tl.constexpr, # stride to get to next token (same feature-index, same sequence-index)
stride_w_dim: tl.constexpr, # stride to get to next dim-axis value
stride_w_width: tl.constexpr, # stride to get to next width-axis value
stride_istate_seq: tl.constexpr,
@@ -66,7 +65,9 @@ def _causal_conv1d_fwd_kernel( # continuous batching
stride_conv_state_seq = stride_istate_seq
stride_conv_state_dim = stride_istate_dim
stride_conv_state_tok = stride_istate_token
state_len = KERNEL_WIDTH - 1 # can be passed via argument if it's not the same as this value
state_len = (
KERNEL_WIDTH - 1
) # can be passed via argument if it's not the same as this value
# one program handles one chunk in a single sequence
# rather than mixing sequences - to make updating initial_states across sequences efficiently
@@ -86,7 +87,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching
# find the actual sequence length
seqlen = sequence_end_index - sequence_start_index
B_size: tl.constexpr = (stride_block_m * BLOCK_M)
B_size: tl.constexpr = stride_block_m * BLOCK_M
if IS_APC_ENABLED:
# Handle the case if prefix caching is enabled.
@@ -124,20 +125,24 @@ def _causal_conv1d_fwd_kernel( # continuous batching
segment_len = min(BLOCK_M, seqlen - token_offset)
# base of the sequence
x_base = x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim # [BLOCK_N,]
x_base = (
x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim
) # [BLOCK_N,]
# cache_idx
conv_states_input_coord = tl.load(conv_state_indices_ptr +
idx_seq * stride_cache_indices +
conv_state_init_index).to(tl.int64)
conv_states_input_coord = tl.load(
conv_state_indices_ptr + idx_seq * stride_cache_indices + conv_state_init_index
).to(tl.int64)
if USE_PAD_SLOT: # noqa
if conv_states_input_coord == pad_slot_id:
# not processing as this is not the actual sequence
return
conv_states_base = (conv_states_ptr +
(conv_states_input_coord * stride_conv_state_seq) +
(idx_feats * stride_conv_state_dim)) # [BLOCK_N,]
conv_states_base = (
conv_states_ptr
+ (conv_states_input_coord * stride_conv_state_seq)
+ (idx_feats * stride_conv_state_dim)
) # [BLOCK_N,]
w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,]
@@ -149,8 +154,7 @@ def _causal_conv1d_fwd_kernel( # continuous batching
load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(tl.int1)
if load_init_state:
# load from conv_states
prior_tokens = conv_states_base + (state_len -
1) * stride_conv_state_tok
prior_tokens = conv_states_base + (state_len - 1) * stride_conv_state_tok
mask_w = idx_feats < dim
if KERNEL_WIDTH == 2:
conv_states_ptrs = prior_tokens # [BLOCK_N]
@@ -180,46 +184,54 @@ def _causal_conv1d_fwd_kernel( # continuous batching
# prior-tokens are zeros
if KERNEL_WIDTH >= 2: # STRATEGY1
# first chunk and does not have prior-token, so just set to 0
col0 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty)
col0 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)
if KERNEL_WIDTH >= 3: # STRATEGY1
col1 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty)
col1 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)
if KERNEL_WIDTH >= 4: # STRATEGY1
col2 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty)
col2 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)
if KERNEL_WIDTH >= 5: # STRATEGY1
col3 = tl.zeros((BLOCK_N, ), dtype=x_ptr.dtype.element_ty)
col3 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty)
# STEP 2:
# here prepare data for updating conv_state
if state_len <= seqlen: # SMALL_CACHE=True (only move part of 'x' into conv_state cache)
if (
state_len <= seqlen
): # SMALL_CACHE=True (only move part of 'x' into conv_state cache)
# just read from 'x'
# copy 'x' data to conv_state
# load only 'x' data (and set 0 before 'x' if seqlen < state_len)
idx_tokens_last = (seqlen - state_len) + tl.arange(
0, NP2_STATELEN) # [BLOCK_M]
x_ptrs = x_ptr + (
(sequence_start_index + idx_tokens_last) *
stride_x_token)[:, None] + (
idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,]
mask_x = ((idx_tokens_last >= 0)[:, None] &
(idx_tokens_last < seqlen)[:, None] &
(idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
0, NP2_STATELEN
) # [BLOCK_M]
x_ptrs = (
x_ptr
+ ((sequence_start_index + idx_tokens_last) * stride_x_token)[:, None]
+ (idx_feats * stride_x_dim)[None, :]
) # [BLOCK_M,BLOCK_N,]
mask_x = (
(idx_tokens_last >= 0)[:, None]
& (idx_tokens_last < seqlen)[:, None]
& (idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
# Compute the offset where the last block should be written in the conv_states
conv_states_output_coord = tl.load(conv_state_indices_ptr +
idx_seq * stride_cache_indices +
current_last_index).to(tl.int64)
conv_states_output_coord = tl.load(
conv_state_indices_ptr
+ idx_seq * stride_cache_indices
+ current_last_index
).to(tl.int64)
conv_states_ptrs_target = (
conv_states_ptr + (conv_states_output_coord *
stride_conv_state_seq) + # Offset from seq
(idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,]
idx_tokens_conv * stride_conv_state_tok)[:, None]
conv_states_ptr
+ (conv_states_output_coord * stride_conv_state_seq) # Offset from seq
+ (idx_feats * stride_conv_state_dim)
)[None, :] + ( # [BLOCK_N,]
idx_tokens_conv * stride_conv_state_tok
)[:, None]
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats
< dim)[None, :]
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :]
tl.debug_barrier() # NOTE: use this due to bug in Triton compiler
tl.store(conv_states_ptrs_target, loaded_x, mask)
@@ -229,39 +241,43 @@ def _causal_conv1d_fwd_kernel( # continuous batching
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
conv_states_ptrs_source = (
conv_states_ptr +
(conv_states_input_coord * stride_conv_state_seq) +
(idx_feats * stride_conv_state_dim)[None, :] +
((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:,
None]
conv_states_ptr
+ (conv_states_input_coord * stride_conv_state_seq)
+ (idx_feats * stride_conv_state_dim)[None, :]
+ ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None]
) # [BLOCK_M, BLOCK_N]
mask = ((conv_states_input_coord < num_cache_lines)
& ((idx_tokens_conv + seqlen) < state_len)[:, None]
& (idx_feats < dim)[None, :])
mask = (
(conv_states_input_coord < num_cache_lines)
& ((idx_tokens_conv + seqlen) < state_len)[:, None]
& (idx_feats < dim)[None, :]
)
conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0)
VAL = state_len - seqlen
x_ptrs = x_base[None, :] + (
(idx_tokens_conv - VAL) *
stride_x_token)[:, None] # [BLOCK_M, BLOCK_N]
x_ptrs = (
x_base[None, :]
+ ((idx_tokens_conv - VAL) * stride_x_token)[:, None]
) # [BLOCK_M, BLOCK_N]
mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] &
(idx_tokens_conv - VAL < seqlen)[:, None] &
(idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
mask_x = (
(idx_tokens_conv - VAL >= 0)[:, None]
& (idx_tokens_conv - VAL < seqlen)[:, None]
& (idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
tl.debug_barrier(
) # need this due to the bug in tl.where not enforcing this when data is the result of another tl.load
tl.debug_barrier() # need this due to the bug in tl.where not enforcing this when data is the result of another tl.load
new_conv_state = tl.where(
mask, conv_state, loaded_x
) # BUG in 'tl.where' which requires a barrier before this
conv_states_ptrs_target = conv_states_base + (
idx_tokens_conv *
stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N]
mask = (idx_tokens_conv
< state_len)[:, None] & (idx_feats < dim)[None, :]
conv_states_ptrs_target = (
conv_states_base
+ (idx_tokens_conv * stride_conv_state_tok)[:, None]
) # [BLOCK_M, BLOCK_N]
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[
None, :
]
tl.store(conv_states_ptrs_target, new_conv_state, mask)
else: # load_init_state == False
# update conv_state by shifting left, BUT
@@ -270,21 +286,25 @@ def _causal_conv1d_fwd_kernel( # continuous batching
VAL = state_len - seqlen
x_ptrs = x_base[None, :] + (
(idx_tokens_conv - VAL) *
stride_x_token)[:, None] # [BLOCK_M, BLOCK_N]
x_ptrs = (
x_base[None, :]
+ ((idx_tokens_conv - VAL) * stride_x_token)[:, None]
) # [BLOCK_M, BLOCK_N]
mask_x = ((idx_tokens_conv - VAL >= 0)[:, None] &
(idx_tokens_conv - VAL < seqlen)[:, None] &
(idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
mask_x = (
(idx_tokens_conv - VAL >= 0)[:, None]
& (idx_tokens_conv - VAL < seqlen)[:, None]
& (idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
new_conv_state = tl.load(x_ptrs, mask_x, 0.0)
conv_states_ptrs_target = conv_states_base + (
idx_tokens_conv *
stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N]
mask = (idx_tokens_conv
< state_len)[:, None] & (idx_feats < dim)[None, :]
conv_states_ptrs_target = (
conv_states_base
+ (idx_tokens_conv * stride_conv_state_tok)[:, None]
) # [BLOCK_M, BLOCK_N]
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[
None, :
]
tl.store(conv_states_ptrs_target, new_conv_state, mask)
else: # chunk_offset > 0
@@ -294,29 +314,29 @@ def _causal_conv1d_fwd_kernel( # continuous batching
mask_w = idx_feats < dim
if KERNEL_WIDTH == 2:
conv_states_ptrs = prior_tokens # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
if KERNEL_WIDTH == 3:
conv_states_ptrs = prior_tokens # [BLOCK_N]
col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
if KERNEL_WIDTH == 4:
conv_states_ptrs = prior_tokens # [BLOCK_N]
col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N]
col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
if KERNEL_WIDTH == 5:
# ruff: noqa: F841
conv_states_ptrs = prior_tokens # [BLOCK_N]
col3 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
col3 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N]
col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N]
col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
conv_states_ptrs = prior_tokens - 3 * stride_x_token # [BLOCK_N]
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier='.ca')
col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca")
# Store intermediate states aligned with stride_block_m
# The additional states are cached starting from the last stride_block_m.
@@ -327,43 +347,51 @@ def _causal_conv1d_fwd_kernel( # continuous batching
# For example chunk_offset = n_block_to_fill stores the state at last_full_block
if (chunk_offset - 1) < n_block_to_fill:
# Store the states at the chunk boundaries from the start of the sequence
idx_tokens_last = (last_full_block_token_index -
(n_block_to_fill - chunk_offset) * B_size -
state_len) + tl.arange(
0, NP2_STATELEN) # [BLOCK_M]
x_ptrs = x_ptr + (idx_tokens_last * stride_x_token)[:, None] + (
idx_feats * stride_x_dim)[None, :] # [BLOCK_M,BLOCK_N,]
idx_tokens_last = (
last_full_block_token_index
- (n_block_to_fill - chunk_offset) * B_size
- state_len
) + tl.arange(0, NP2_STATELEN) # [BLOCK_M]
x_ptrs = (
x_ptr
+ (idx_tokens_last * stride_x_token)[:, None]
+ (idx_feats * stride_x_dim)[None, :]
) # [BLOCK_M,BLOCK_N,]
mask_x = (
(idx_tokens_last >= 0)[:, None] & (idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
mask_x = (idx_tokens_last >= 0)[:, None] & (idx_feats < dim)[
None, :
] # token-index # token-index # feature-index
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M]
# cache_idx
conv_states_output_coord = tl.load(conv_state_indices_ptr +
idx_seq * stride_cache_indices +
current_first_index +
(chunk_offset - 1)).to(tl.int64)
conv_states_output_coord = tl.load(
conv_state_indices_ptr
+ idx_seq * stride_cache_indices
+ current_first_index
+ (chunk_offset - 1)
).to(tl.int64)
conv_states_ptrs_target = (
conv_states_ptr + (conv_states_output_coord *
stride_conv_state_seq) + # Offset from seq
(idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,]
idx_tokens_conv * stride_conv_state_tok)[:, None]
conv_states_ptr
+ (conv_states_output_coord * stride_conv_state_seq) # Offset from seq
+ (idx_feats * stride_conv_state_dim)
)[None, :] + ( # [BLOCK_N,]
idx_tokens_conv * stride_conv_state_tok
)[:, None]
mask = (idx_tokens_conv < state_len)[:, None] & \
(idx_feats < dim)[None, :]
mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :]
tl.debug_barrier() # NOTE: use this due to bug in Triton compiler
tl.store(conv_states_ptrs_target, loaded_x, mask)
if HAS_BIAS:
bias = bias_ptr + idx_feats
mask_bias = idx_feats < dim
acc_preload = tl.load(bias, mask=mask_bias,
other=0.0).to(tl.float32) # [BLOCK_N]
acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(
tl.float32
) # [BLOCK_N]
else:
acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32)
acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32)
x_base_1d = x_base + token_offset * stride_x_token # starting of chunk
@@ -387,7 +415,6 @@ def _causal_conv1d_fwd_kernel( # continuous batching
matrix_w = w_col0
matrix_x = col0
for j in tl.static_range(KERNEL_WIDTH):
if KERNEL_WIDTH == 2:
if j == 1: # KERNEL_WIDTH-1:
matrix_w = w_col1
@@ -428,9 +455,13 @@ def _causal_conv1d_fwd_kernel( # continuous batching
if SILU_ACTIVATION:
acc = acc / (1 + tl.exp(-acc))
mask_1d = (idx_token < segment_len) & (
idx_feats < dim) # token-index # feature-index
o_ptrs = o_ptr + (sequence_start_index + token_offset + idx_token
) * stride_o_token + (idx_feats * stride_o_dim)
idx_feats < dim
) # token-index # feature-index
o_ptrs = (
o_ptr
+ (sequence_start_index + token_offset + idx_token) * stride_o_token
+ (idx_feats * stride_o_dim)
)
tl.store(o_ptrs, acc, mask=mask_1d)
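# --- Orientation note (not part of the diff) ---------------------------------
# The fwd kernel above fuses a causal depthwise 1D convolution with an optional
# bias and SiLU activation. As a rough, unverified sketch of what it is
# understood to compute for a single dense sequence (no conv-state caching, no
# varlen batching), the helper below uses plain PyTorch; the name
# `_reference_causal_conv1d` is illustrative only and not part of vLLM's API.
def _reference_causal_conv1d(x, weight, bias=None, activation=None):
    # x: (dim, seqlen), weight: (dim, width), bias: (dim,) or None
    import torch
    import torch.nn.functional as F

    dim, width = weight.shape
    # Left-pad by width - 1 so each output position only sees current/past tokens.
    x_padded = F.pad(x.unsqueeze(0), (width - 1, 0))  # (1, dim, seqlen + width - 1)
    out = F.conv1d(x_padded, weight.unsqueeze(1), bias=bias, groups=dim).squeeze(0)
    if activation == "silu":
        out = out * torch.sigmoid(out)  # same as acc / (1 + exp(-acc)) in the kernel
    return out  # (dim, seqlen)
# ------------------------------------------------------------------------------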
@@ -518,21 +549,15 @@ def causal_conv1d_fn(
batch_ptr = metadata.batch_ptr
token_chunk_offset_ptr = metadata.token_chunk_offset_ptr
else:
seqlens = query_start_loc.diff().to('cpu')
seqlens = query_start_loc.diff().to("cpu")
args = seqlens
MAX_NUM_PROGRAMS = 1024
batch_ptr = torch.full(
(MAX_NUM_PROGRAMS, ),
PAD_SLOT_ID,
dtype=torch.int32,
device=x.device
(MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device
) # tracking which seq-idx the Triton program is handling
token_chunk_offset_ptr = torch.full(
(MAX_NUM_PROGRAMS, ),
PAD_SLOT_ID,
dtype=torch.int32,
device=x.device
(MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device
) # tracking BLOCK_M-based index in the sequence the Triton program is handling
is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1)
@@ -558,9 +583,11 @@ def causal_conv1d_fn(
# 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx]
# 4. computation can be skipped if cache_indices[idx] == pad_slot_id
num_cache_lines = conv_states.size(0)
assert (num_cache_lines == conv_states.shape[0]
and dim == conv_states.shape[1]
and width - 1 <= conv_states.shape[2])
assert (
num_cache_lines == conv_states.shape[0]
and dim == conv_states.shape[1]
and width - 1 <= conv_states.shape[2]
)
stride_istate_seq = conv_states.stride(0)
stride_istate_dim = conv_states.stride(1)
stride_istate_token = conv_states.stride(2)
@@ -571,8 +598,7 @@ def causal_conv1d_fn(
else:
stride_o_dim = out.stride(1)
stride_o_token = out.stride(2)
stride_cache_indices = cache_indices.stride(
0) if cache_indices is not None else 0
stride_cache_indices = cache_indices.stride(0) if cache_indices is not None else 0
if validate_data:
assert x.dim() == 2
@@ -586,15 +612,17 @@ def causal_conv1d_fn(
assert cache_indices.dim() == 1
assert padded_batch == cache_indices.size(0)
if has_initial_state is not None:
assert has_initial_state.size() == (padded_batch, )
assert conv_states is not None, "ERROR: `has_initial_state` is used, which needs also `conv_states`"
assert has_initial_state.size() == (padded_batch,)
assert conv_states is not None, (
"ERROR: `has_initial_state` is used, which needs also `conv_states`"
)
assert weight.stride(1) == 1
assert (dim, width) == weight.shape
assert is_channel_last, "Need to run in channel-last layout"
if block_size_to_align is not None and block_size_to_align > 0:
assert (
block_size_to_align % BLOCK_M
) == 0, "The mamba block size needs to be divisible by the BLOCK_M"
assert (block_size_to_align % BLOCK_M) == 0, (
"The mamba block size needs to be divisible by the BLOCK_M"
)
else:
block_size_to_align = BLOCK_M
@@ -618,44 +646,45 @@ def causal_conv1d_fn(
if META["batch_ptr"].nelement() < len(mlist):
newlen = len(mlist) + 1
META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
META["token_chunk_offset_ptr"].resize_(newlen).fill_(
PAD_SLOT_ID)
META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
if META["batch_ptr"].nelement() >= len(mlist):
META["batch_ptr"][0:len(mlist)].copy_(
torch.from_numpy(np.array(mlist)))
META["token_chunk_offset_ptr"][0:len(mlist)].copy_(
torch.from_numpy(np.array(offsetlist)))
META["batch_ptr"][0 : len(mlist)].copy_(
torch.from_numpy(np.array(mlist))
)
META["token_chunk_offset_ptr"][0 : len(mlist)].copy_(
torch.from_numpy(np.array(offsetlist))
)
META["batch_ptr"] = META["batch_ptr"].to(META["x_ptr"].device)
META["token_chunk_offset_ptr"] = META["token_chunk_offset_ptr"].to(
META["x_ptr"].device)
META["x_ptr"].device
)
return tot
else:
def num_program(META, nums_dict):
tot = nums_dict[META["BLOCK_M"]]['tot']
tot = nums_dict[META["BLOCK_M"]]["tot"]
mlist = nums_dict[META["BLOCK_M"]]['mlist']
mlist_len = nums_dict[META["BLOCK_M"]]['mlist_len']
mlist = nums_dict[META["BLOCK_M"]]["mlist"]
mlist_len = nums_dict[META["BLOCK_M"]]["mlist_len"]
offsetlist = nums_dict[META["BLOCK_M"]]['offsetlist']
offsetlist = nums_dict[META["BLOCK_M"]]["offsetlist"]
if nums_dict[META["BLOCK_M"]]["batch_ptr"] is not None:
META["batch_ptr"] = nums_dict[META["BLOCK_M"]]["batch_ptr"]
META["token_chunk_offset_ptr"] = nums_dict[
META["BLOCK_M"]]["token_chunk_offset_ptr"]
META["token_chunk_offset_ptr"] = nums_dict[META["BLOCK_M"]][
"token_chunk_offset_ptr"
]
else:
if META["batch_ptr"].nelement() < mlist_len:
newlen = mlist_len + 1
META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
META["token_chunk_offset_ptr"].resize_(newlen).fill_(
PAD_SLOT_ID)
META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
if META["batch_ptr"].nelement() >= mlist_len:
META["batch_ptr"][0:mlist_len].copy_(mlist)
META["token_chunk_offset_ptr"][0:mlist_len].copy_(
offsetlist)
META["token_chunk_offset_ptr"][0:mlist_len].copy_(offsetlist)
return tot
def grid(META):
@@ -709,7 +738,7 @@ def causal_conv1d_fn(
IS_APC_ENABLED=current_last_idx is not None,
USE_PAD_SLOT=pad_slot_id is not None,
NP2_STATELEN=np2_statelen,
#launch_cooperative_grid=True
# launch_cooperative_grid=True
BLOCK_M=BLOCK_M,
BLOCK_N=256,
num_stages=2,
@@ -728,7 +757,7 @@ def _causal_conv1d_update_kernel(
num_accepted_tokens_ptr,
query_start_loc_ptr, # (batch + 1)
current_last_idx, # (batch,)
initial_state_idx, #(batch,)
initial_state_idx, # (batch,)
o_ptr, # (batch, dim, seqlen)
# Matrix dimensions
batch: int,
@@ -779,9 +808,9 @@ def _causal_conv1d_update_kernel(
current_last_index = 0
# cache_idx
conv_states_input_coord = tl.load(conv_state_indices_ptr +
idx_seq * stride_state_indices +
conv_state_init).to(tl.int64)
conv_states_input_coord = tl.load(
conv_state_indices_ptr + idx_seq * stride_state_indices + conv_state_init
).to(tl.int64)
if USE_PAD_SLOT: # noqa
if conv_states_input_coord == pad_slot_id:
@@ -790,11 +819,9 @@ def _causal_conv1d_update_kernel(
if IS_VARLEN:
query_start_index = tl.load(query_start_loc_ptr + idx_seq).to(tl.int64)
query_end_index = tl.load(query_start_loc_ptr + (idx_seq + 1)).to(
tl.int64)
query_end_index = tl.load(query_start_loc_ptr + (idx_seq + 1)).to(tl.int64)
# revise state_len and seqlen
state_len = state_len - (seqlen -
(query_end_index - query_start_index))
state_len = state_len - (seqlen - (query_end_index - query_start_index))
seqlen = query_end_index - query_start_index
x_offset = query_start_index * stride_x_token
o_offset = query_start_index * stride_o_token
@@ -822,14 +849,17 @@ def _causal_conv1d_update_kernel(
# - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
# - and so on.
conv_state_token_offset = (
tl.load(num_accepted_tokens_ptr + idx_seq).to(tl.int64) - 1)
tl.load(num_accepted_tokens_ptr + idx_seq).to(tl.int64) - 1
)
else:
conv_state_token_offset = 0
# STEP 1: READ init_state data
conv_states_base = (conv_state_ptr +
(conv_states_input_coord * stride_conv_state_seq) +
(idx_feats * stride_conv_state_dim))
conv_states_base = (
conv_state_ptr
+ (conv_states_input_coord * stride_conv_state_seq)
+ (idx_feats * stride_conv_state_dim)
)
mask_w = idx_feats < dim
prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok
@@ -856,25 +886,33 @@ def _causal_conv1d_update_kernel(
# window manner, at each forward pass, the tokens are shift by 1, so we
# load since idx_tokens + 1.
conv_state_ptrs_source = (
conv_state_ptr + (conv_states_input_coord * stride_conv_state_seq) +
conv_state_token_offset * stride_conv_state_tok +
(idx_feats * stride_conv_state_dim)[None, :] +
((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) *
stride_conv_state_tok)[:, None]) # [BLOCK_M, BLOCK_N]
mask = ((conv_states_input_coord < num_cache_lines)
& ((idx_tokens + seqlen) < state_len)[:, None]
& (idx_feats < dim)[None, :])
conv_state_ptr
+ (conv_states_input_coord * stride_conv_state_seq)
+ conv_state_token_offset * stride_conv_state_tok
+ (idx_feats * stride_conv_state_dim)[None, :]
+ ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[
:, None
]
) # [BLOCK_M, BLOCK_N]
mask = (
(conv_states_input_coord < num_cache_lines)
& ((idx_tokens + seqlen) < state_len)[:, None]
& (idx_feats < dim)[None, :]
)
conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)
VAL = state_len - seqlen
x_base = x_ptr + x_offset + (idx_feats * stride_x_dim) # [BLOCK_N]
x_ptrs = x_base[None, :] + (
(idx_tokens - VAL) * stride_x_token)[:, None] # [BLOCK_M, BLOCK_N]
x_ptrs = (
x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None]
) # [BLOCK_M, BLOCK_N]
mask_x = ((idx_tokens - VAL >= 0)[:, None] &
(idx_tokens - VAL < seqlen)[:, None] & (idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
mask_x = (
(idx_tokens - VAL >= 0)[:, None]
& (idx_tokens - VAL < seqlen)[:, None]
& (idx_feats < dim)[None, :]
) # token-index # token-index # feature-index
loaded_x = tl.load(x_ptrs, mask_x, 0.0)
tl.debug_barrier()
@@ -882,14 +920,16 @@ def _causal_conv1d_update_kernel(
# Get the state from the initial_state_idx
# cache_idx
conv_states_offset = tl.load(conv_state_indices_ptr +
idx_seq * stride_state_indices +
current_last_index).to(tl.int64)
conv_states_offset = tl.load(
conv_state_indices_ptr + idx_seq * stride_state_indices + current_last_index
).to(tl.int64)
conv_state_ptrs_target = (
conv_state_ptr +
(conv_states_offset * stride_conv_state_seq) + # Offset from seq
(idx_feats * stride_conv_state_dim))[None, :] + ( # [BLOCK_N,]
idx_tokens * stride_conv_state_tok)[:, None]
conv_state_ptr
+ (conv_states_offset * stride_conv_state_seq) # Offset from seq
+ (idx_feats * stride_conv_state_dim)
)[None, :] + ( # [BLOCK_N,]
idx_tokens * stride_conv_state_tok
)[:, None]
mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :]
tl.store(conv_state_ptrs_target, new_conv_state, mask)
@@ -897,10 +937,11 @@ def _causal_conv1d_update_kernel(
if HAS_BIAS:
bias = bias_ptr + idx_feats
mask_bias = idx_feats < dim
acc_preload = tl.load(bias, mask=mask_bias,
other=0.0).to(tl.float32) # [BLOCK_N]
acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(
tl.float32
) # [BLOCK_N]
else:
acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32)
acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32)
# STEP 4:
# PRE-LOAD WEIGHTS
@@ -1016,10 +1057,12 @@ def _causal_conv1d_update_kernel(
if SILU_ACTIVATION:
acc = acc / (1 + tl.exp(-acc))
mask_1d = (idx_token < seqlen) & (idx_feats < dim
) # token-index # feature-index
o_ptrs = o_ptr + o_offset + idx_token * stride_o_token + (idx_feats *
stride_o_dim)
mask_1d = (idx_token < seqlen) & (
idx_feats < dim
) # token-index # feature-index
o_ptrs = (
o_ptr + o_offset + idx_token * stride_o_token + (idx_feats * stride_o_dim)
)
tl.store(o_ptrs, acc, mask=mask_1d)
@@ -1104,16 +1147,16 @@ def causal_conv1d_update(
if validate_data:
assert dim == weight.size(0)
assert conv_state.stride(
-2
) == 1, f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})"
assert conv_state.stride(-2) == 1, (
f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})"
)
assert state_len >= width - 1
# when above happens, we don't shift-left to keep any records in conv_state
assert dim == conv_state.size(1)
if conv_state_indices is None:
assert conv_state.size(0) >= batch
else:
assert (batch, ) == conv_state_indices.shape
assert (batch,) == conv_state_indices.shape
assert num_cache_lines >= batch
assert weight.stride(1) == 1 # Need this
@@ -1133,10 +1176,10 @@ def causal_conv1d_update(
stride_o_token, stride_o_dim = out.stride()
stride_o_seq = 0
stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride(
stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride()
stride_state_indices = (
conv_state_indices.stride(0) if conv_state_indices is not None else 0
)
stride_state_indices = conv_state_indices.stride(
0) if conv_state_indices is not None else 0
if num_accepted_tokens is not None:
state_len = width - 1 + (seqlen - 1) # effective state_len needed
else:

@@ -46,17 +46,17 @@ def _layer_norm_fwd_1pass_kernel(
B += group * N
# Compute mean and variance
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_Z and not NORM_BEFORE_GATE:
z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
x *= z * tl.sigmoid(z)
if not IS_RMS_NORM:
mean = tl.sum(x, axis=0) / N
tl.store(Mean + row, mean)
xbar = tl.where(cols < N, x - mean, 0.)
xbar = tl.where(cols < N, x - mean, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
else:
xbar = tl.where(cols < N, x, 0.)
xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
tl.store(Rstd + row, rstd)
@@ -74,15 +74,17 @@ def _layer_norm_fwd_1pass_kernel(
tl.store(Y + cols, y, mask=mask)
def _layer_norm_fwd(x,
weight,
bias,
eps,
z=None,
out=None,
group_size=None,
norm_before_gate=True,
is_rms_norm=False):
def _layer_norm_fwd(
x,
weight,
bias,
eps,
z=None,
out=None,
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
):
M, N = x.shape
if group_size is None:
group_size = N
@@ -92,57 +94,57 @@ def _layer_norm_fwd(x,
if z is not None:
assert z.stride(-1) == 1
assert z.shape == (M, N)
assert weight.shape == (N, )
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N, )
assert bias.shape == (N,)
# allocate output
if out is not None:
assert out.shape == x.shape
else:
out = torch.empty_like(x)
assert out.stride(-1) == 1
mean = torch.empty((ngroups * M, ), dtype=torch.float32,
device=x.device) if not is_rms_norm else None
rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
mean = (
torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
if not is_rms_norm
else None
)
rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
if group_size > BLOCK_N:
raise RuntimeError(
"This layer norm doesn't support feature dim >= 64KB.")
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_N // 256, 1), 8)
grid = (M, ngroups)
with torch.cuda.device(x.device.index):
_layer_norm_fwd_1pass_kernel[grid](x,
out,
weight,
bias,
z,
mean,
rstd,
x.stride(0),
out.stride(0),
z.stride(0) if z is not None else 0,
M,
group_size,
eps,
BLOCK_N=BLOCK_N,
NORM_BEFORE_GATE=norm_before_gate,
IS_RMS_NORM=is_rms_norm,
num_warps=num_warps)
_layer_norm_fwd_1pass_kernel[grid](
x,
out,
weight,
bias,
z,
mean,
rstd,
x.stride(0),
out.stride(0),
z.stride(0) if z is not None else 0,
M,
group_size,
eps,
BLOCK_N=BLOCK_N,
NORM_BEFORE_GATE=norm_before_gate,
IS_RMS_NORM=is_rms_norm,
num_warps=num_warps,
)
return out, mean, rstd
def rms_norm_gated(x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True):
def rms_norm_gated(
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
):
x_shape_og = x.shape
# reshape input data into 2D tensor
x = x.reshape(-1, x.shape[-1])
@@ -156,13 +158,15 @@ def rms_norm_gated(x,
weight = weight.contiguous()
if bias is not None:
bias = bias.contiguous()
y, _, _ = _layer_norm_fwd(x,
weight,
bias,
eps,
z=z,
group_size=group_size,
norm_before_gate=norm_before_gate,
is_rms_norm=True)
y, _, _ = _layer_norm_fwd(
x,
weight,
bias,
eps,
z=z,
group_size=group_size,
norm_before_gate=norm_before_gate,
is_rms_norm=True,
)
return y.reshape(x_shape_og)
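# --- Usage sketch (not part of the diff) --------------------------------------
# A minimal, assumed usage of rms_norm_gated above. With norm_before_gate=False
# the gate z is applied as x * z * sigmoid(z) before the RMS statistics are
# computed, matching the HAS_Z / NORM_BEFORE_GATE branch in the kernel; shapes
# and device are illustrative only.
#
#     import torch
#     x = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
#     z = torch.randn_like(x)
#     weight = torch.ones(1024, device="cuda", dtype=torch.float16)
#     y = rms_norm_gated(x, weight, bias=None, z=z, norm_before_gate=False)
#     assert y.shape == x.shape
# ------------------------------------------------------------------------------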

@@ -11,8 +11,7 @@ from vllm import _custom_ops as ops
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.triton_utils import HAS_TRITON, tl, triton
TRITON3 = HAS_TRITON and (version.parse(triton.__version__)
>= version.parse("3.0.0"))
TRITON3 = HAS_TRITON and (version.parse(triton.__version__) >= version.parse("3.0.0"))
if TRITON3:
@@ -28,16 +27,18 @@ else:
return dt
@triton.heuristics(
{"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
@triton.heuristics({
"HAS_STATE_BATCH_INDICES":
lambda args: args["state_batch_indices_ptr"] is not None
})
@triton.heuristics(
{"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
{
"HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"]
is not None
}
)
@triton.heuristics(
{"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}
)
@triton.jit
def _selective_scan_update_kernel(
# Pointers to matrices
@@ -110,15 +111,16 @@ def _selective_scan_update_kernel(
if HAS_STATE_BATCH_INDICES:
dst_state_batch_indices_ptr += pid_b
dst_state_batch_idx = tl.load(dst_state_batch_indices_ptr).to(tl.int64)
dst_state_ptr = state_ptr + (dst_state_batch_idx * stride_state_batch +
pid_h * stride_state_head)
dst_state_ptr = state_ptr + (
dst_state_batch_idx * stride_state_batch + pid_h * stride_state_head
)
state_batch_indices_ptr += pid_b
state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64)
state_ptr += (state_batch_idx * stride_state_batch +
pid_h * stride_state_head)
state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
else:
dst_state_ptr = state_ptr + pid_b * stride_state_batch + \
pid_h * stride_state_head
dst_state_ptr = (
state_ptr + pid_b * stride_state_batch + pid_h * stride_state_head
)
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
@@ -126,28 +128,29 @@ def _selective_scan_update_kernel(
if HAS_DT_BIAS:
dt_bias_ptr += pid_h * stride_dt_bias_head
A_ptr += pid_h * stride_A_head
B_ptr += pid_b * stride_B_batch + (pid_h //
nheads_ngroups_ratio) * stride_B_group
C_ptr += pid_b * stride_C_batch + (pid_h //
nheads_ngroups_ratio) * stride_C_group
B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
if HAS_Z:
z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim +
offs_n[None, :] * stride_state_dstate)
dst_state_ptrs = dst_state_ptr + (offs_m[:, None] * stride_state_dim +
offs_n[None, :] * stride_state_dstate)
state_ptrs = state_ptr + (
offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
)
dst_state_ptrs = dst_state_ptr + (
offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
)
x_ptrs = x_ptr + offs_m * stride_x_dim
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
if HAS_DT_BIAS:
dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
if HAS_D:
D_ptr += pid_h * stride_D_head
A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim +
offs_n[None, :] * stride_A_dstate)
A_ptrs = A_ptr + (
offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
)
B_ptrs = B_ptr + offs_n * stride_B_dstate
C_ptrs = C_ptr + offs_n * stride_C_dstate
if HAS_D:
@@ -157,20 +160,19 @@ def _selective_scan_update_kernel(
out_ptrs = out_ptr + offs_m * stride_out_dim
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
if HAS_STATE_BATCH_INDICES:
mask &= (state_batch_idx != pad_slot_id)
mask &= state_batch_idx != pad_slot_id
state = tl.load(state_ptrs, mask=mask, other=0.0)
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if not TIE_HDIM:
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if HAS_DT_BIAS:
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim,
other=0.0).to(tl.float32)
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
if DT_SOFTPLUS:
dt = softplus(dt)
A = tl.load(A_ptrs,
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
other=0.0).to(tl.float32)
A = tl.load(
A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
).to(tl.float32)
dA = tl.exp(A * dt[:, None])
else:
dt = tl.load(dt_ptr).to(tl.float32)
@@ -193,7 +195,7 @@ def _selective_scan_update_kernel(
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
if HAS_STATE_BATCH_INDICES:
mask &= (state_batch_idx != pad_slot_id)
mask &= state_batch_idx != pad_slot_id
tl.store(dst_state_ptrs, state, mask=mask)
out = tl.sum(state * C[None, :], axis=1)
if HAS_D:
@@ -203,20 +205,22 @@ def _selective_scan_update_kernel(
tl.store(out_ptrs, out, mask=offs_m < dim)
def selective_state_update(state,
x,
dt,
A,
B,
C,
D=None,
z=None,
dt_bias=None,
dt_softplus=False,
state_batch_indices=None,
dst_state_batch_indices=None,
pad_slot_id=PAD_SLOT_ID,
out=None):
def selective_state_update(
state,
x,
dt,
A,
B,
C,
D=None,
z=None,
dt_bias=None,
dt_softplus=False,
state_batch_indices=None,
dst_state_batch_indices=None,
pad_slot_id=PAD_SLOT_ID,
out=None,
):
"""
Argument:
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
@@ -229,12 +233,12 @@ def selective_state_update(state,
z: (batch, dim) or (batch, nheads, dim)
dt_bias: (dim,) or (nheads, dim)
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
if cache_indices is passed, lets the kernel identify padded
entries that will not be processed,
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
in this case, the kernel will not process entries at
indices 0 and 3
out: Preallocated ssm output tensor. Assume same shape as x.
out: Preallocated ssm output tensor. Assume same shape as x.
In-place updated.
"""
if state.dim() == 3:
@@ -275,25 +279,33 @@ def selective_state_update(state,
if dt_bias is not None:
assert dt_bias.shape == (nheads, dim)
if state_batch_indices is not None:
assert state_batch_indices.shape == (batch, )
assert state_batch_indices.shape == (batch,)
if dst_state_batch_indices is not None:
assert dst_state_batch_indices.shape == (batch, )
assert dst_state_batch_indices.shape == (batch,)
else:
# revert to the default behavior of in-place state updates
dst_state_batch_indices = state_batch_indices
assert out.shape == x.shape
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
(0, 0, 0))
grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
# We don't want autotune since it will overwrite the state
# We instead tune by hand.
BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 else
((16, 4) if dstate <= 32 else
((8, 4) if dstate <= 64 else
((4, 4) if dstate <= 128 else ((4, 8))))))
tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(
-1) == 0 and dt_bias.stride(-1) == 0
BLOCK_SIZE_M, num_warps = (
(32, 4)
if dstate <= 16
else (
(16, 4)
if dstate <= 32
else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
)
)
tie_hdim = (
A.stride(-1) == 0
and A.stride(-2) == 0
and dt.stride(-1) == 0
and dt_bias.stride(-1) == 0
)
with torch.cuda.device(x.device.index):
_selective_scan_update_kernel[grid](
state,
@@ -324,8 +336,7 @@ def selective_state_update(state,
dt.stride(0),
dt.stride(1),
dt.stride(2),
*(dt_bias.stride(0),
dt_bias.stride(1)) if dt_bias is not None else 0,
*(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
A.stride(0),
A.stride(1),
A.stride(2),
@@ -349,54 +360,56 @@ def selective_state_update(state,
)
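# --- Orientation note (not part of the diff) ----------------------------------
# Rough, unverified reference for one selective-state-update step in the simple
# (batch, dim, dstate) layout handled by selective_state_update above (no heads,
# no batch indirection, no z gating). The recurrence mirrors the kernel:
#   state <- state * exp(dt * A) + (dt * x) * B,   out <- (state * C).sum(-1) + D * x
# The name `_reference_selective_state_update` is illustrative only.
def _reference_selective_state_update(state, x, dt, A, B, C, D=None):
    # state: (batch, dim, dstate); x, dt: (batch, dim)
    # A: (dim, dstate); B, C: (batch, dstate); D: (dim,) or None
    import torch

    dA = torch.exp(dt.unsqueeze(-1) * A)                       # (batch, dim, dstate)
    dBx = dt.unsqueeze(-1) * B.unsqueeze(1) * x.unsqueeze(-1)  # (batch, dim, dstate)
    state = state * dA + dBx
    out = (state * C.unsqueeze(1)).sum(dim=-1)                 # (batch, dim)
    if D is not None:
        out = out + D * x
    return state, out
# ------------------------------------------------------------------------------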
def selective_scan_fn(u,
ssm_states,
delta,
A,
B,
C,
D=None,
z=None,
delta_bias=None,
delta_softplus=False,
query_start_loc=None,
cache_indices=None,
has_initial_state=None,
pad_slot_id=PAD_SLOT_ID) -> torch.Tensor:
def selective_scan_fn(
u,
ssm_states,
delta,
A,
B,
C,
D=None,
z=None,
delta_bias=None,
delta_softplus=False,
query_start_loc=None,
cache_indices=None,
has_initial_state=None,
pad_slot_id=PAD_SLOT_ID,
) -> torch.Tensor:
"""
u: (dim, total_length) for varlen or (batch, dim, seqlen)
u: (dim, total_length) for varlen or (batch, dim, seqlen)
applies changes in place.
ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate)
applies changes in place.
delta: (dim, total_length) for varlen or (batch, dim, seqlen)
A: (dim, dstate)
B: (ngroups, dstate, total_length) for varlen or
A: (dim, dstate)
B: (ngroups, dstate, total_length) for varlen or
(batch,ngroups,dstate,seqlen)
C: (ngroups, dstate, total_length) for varlen or
C: (ngroups, dstate, total_length) for varlen or
(batch,ngroups,dstate,seqlen)
D: (dim,)
z: (dim, total_length) for varlen or (batch, dim, seqlen)
D: (dim,)
z: (dim, total_length) for varlen or (batch, dim, seqlen)
dt_bias: (dim,) or (dim)
query_start_loc: (batch + 1) int32
The cumulative sequence lengths of the sequences in
the batch, used to index into sequence. prepended with 0.
for example: query_start_loc = torch.Tensor([0,10,16,17]),
for example: query_start_loc = torch.Tensor([0,10,16,17]),
x.shape=(dim,17)
cache_indices: (batch) int32
A tensor with each cell is a correspondent
A tensor with each cell is a correspondent
input and output ssm_state index
has_initial_state: (batch) bool
A tensor populated with ones and zeros,
indicate if the ssm_state at the corresponding index should be
used as initial state. Not providing argument assumes
A tensor populated with ones and zeros,
indicate if the ssm_state at the corresponding index should be
used as initial state. Not providing argument assumes
there's no initial state
pad_slot_id: int
if cache_indices is passed, lets the kernel identify padding entries
that will not be processed,
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
if cache_indices is passed, lets the kernel identify padding entries
that will not be processed,
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
in this case, the kernel will not process entries at indices 0 and 3
returns
output: (dim, total_length) for varlen or (batch, dim, seqlen)
output: (dim, total_length) for varlen or (batch, dim, seqlen)
supports inplace replacement
"""
if u.stride(-1) != 1:
@@ -420,9 +433,22 @@ def selective_scan_fn(u,
if C.dim() == 2 and query_start_loc is not None:
C = C.unsqueeze(0)
ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus,
query_start_loc, cache_indices, has_initial_state,
ssm_states, pad_slot_id)
ops.selective_scan_fwd(
u,
delta,
A,
B,
C,
D,
z,
delta_bias,
delta_softplus,
query_start_loc,
cache_indices,
has_initial_state,
ssm_states,
pad_slot_id,
)
if z is None:
return delta # output written inplace to delta
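# --- Usage sketch (not part of the diff) --------------------------------------
# Illustrative wiring of the varlen arguments described in the selective_scan_fn
# docstring above (values are made up): three sequences of lengths 10, 6 and 1
# packed along the last dim, where sequence 0 maps to a padded cache slot that
# is skipped and sequence 2 continues from an existing ssm_state.
#
#     query_start_loc   = torch.tensor([0, 10, 16, 17], dtype=torch.int32)
#     cache_indices     = torch.tensor([PAD_SLOT_ID, 1, 20], dtype=torch.int32)
#     has_initial_state = torch.tensor([False, False, True])
#
# Entries whose cache index equals pad_slot_id are not processed, and only
# sequence 2 reads its prior state before the scan.
# ------------------------------------------------------------------------------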

@@ -14,79 +14,52 @@ from vllm.triton_utils import tl, triton
@triton.autotune(
configs=[
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 64
},
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8),
num_warps=8,
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2),
num_warps=2,
),
triton.Config(
{
'BLOCK_SIZE_M': 32,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2),
num_warps=2,
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=2),
num_warps=2,
),
],
key=['chunk_size', 'K', 'IS_CAUSAL'],
key=["chunk_size", "K", "IS_CAUSAL"],
)
@triton.jit
def _bmm_chunk_fwd_kernel(
@@ -136,24 +109,26 @@ def _bmm_chunk_fwd_kernel(
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen +
offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk +
offs_n[None, :] * stride_b_seqlen)
a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)
chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
# compute a * b.T
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs,
mask=(offs_m[:, None] < chunk_size_limit) &
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0).to(dot_dtype)
b = tl.load(b_ptrs,
mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) &
(offs_n[None, :] < chunk_size_limit),
other=0.0).to(dot_dtype)
a = tl.load(
a_ptrs,
mask=(offs_m[:, None] < chunk_size_limit)
& (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
).to(dot_dtype)
b = tl.load(
b_ptrs,
mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K)
& (offs_n[None, :] < chunk_size_limit),
other=0.0,
).to(dot_dtype)
acc += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
@@ -163,20 +138,15 @@ def _bmm_chunk_fwd_kernel(
out = acc.to(out_ptr.dtype.element_ty)
out_ptr += pid_c * stride_out_chunk + pid_h * stride_out_head
out_ptrs = out_ptr + (stride_outm * offs_m[:, None] +
offs_n[None, :] * stride_outn)
tl.store(out_ptrs,
out,
mask=(offs_m[:, None] < chunk_size) &
(offs_n[None, :] < chunk_size))
out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)
tl.store(
out_ptrs,
out,
mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size),
)
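# --- Orientation note (not part of the diff) ----------------------------------
# For orientation only: per chunk c and group g, the kernel above computes the
# chunk-local score matrix out[c, g] = a_c @ b_c.T over the feature dim K
# (lower-triangular when IS_CAUSAL), where a_c and b_c are the rows of a and b
# belonging to chunk c. A dense, unbatched sketch with illustrative names and
# assumed equivalence:
#
#     scores = torch.einsum("mk,nk->mn", a_chunk, b_chunk)
#     if causal:
#         scores = torch.tril(scores)
# ------------------------------------------------------------------------------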
def _bmm_chunk_fwd(a,
b,
chunk_size,
cu_chunk_seqlens,
causal=False,
output_dtype=None):
def _bmm_chunk_fwd(a, b, chunk_size, cu_chunk_seqlens, causal=False, output_dtype=None):
"""
Argument:
a: (seqlen, ngroups, k)
@@ -198,16 +168,23 @@ def _bmm_chunk_fwd(a,
nchunks = len(cu_chunk_seqlens) - 1
# Allocates output.
out_dtype = a.dtype if output_dtype is None else output_dtype
out = torch.empty((nchunks, ngroups, chunk_size, chunk_size),
device=a.device,
dtype=out_dtype)
dot_dtype = (tl.bfloat16
if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
(tl.float16 if a.dtype == torch.float16
or b.dtype == torch.float16 else tl.float32))
grid = lambda META: (triton.cdiv(
chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(
chunk_size, META['BLOCK_SIZE_N']), nchunks * ngroups)
out = torch.empty(
(nchunks, ngroups, chunk_size, chunk_size), device=a.device, dtype=out_dtype
)
dot_dtype = (
tl.bfloat16
if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16
else (
tl.float16
if a.dtype == torch.float16 or b.dtype == torch.float16
else tl.float32
)
)
grid = lambda META: (
triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
* triton.cdiv(chunk_size, META["BLOCK_SIZE_N"]),
nchunks * ngroups,
)
with torch.cuda.device(a.device.index):
_bmm_chunk_fwd_kernel[grid](
a_ptr=a,

@@ -10,101 +10,68 @@ from packaging import version
from vllm.triton_utils import tl, triton
TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
@triton.autotune(
configs=[
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 64
},
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8),
num_warps=8,
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 64
},
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 64
},
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2),
num_warps=2,
),
triton.Config(
{
'BLOCK_SIZE_M': 32,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2),
num_warps=2,
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=2),
num_warps=2,
),
],
key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'],
key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"],
)
@triton.jit
def _chunk_scan_fwd_kernel(
@@ -177,15 +144,16 @@ def _chunk_scan_fwd_kernel(
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n
cb_ptr += pid_c * stride_cb_chunk + (pid_h //
nheads_ngroups_ratio) * stride_cb_head
cb_ptr += pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
C_ptr += chunk_seqlen_start * stride_C_seqlen + (
pid_h // nheads_ngroups_ratio) * stride_C_head
C_ptr += (
chunk_seqlen_start * stride_C_seqlen
+ (pid_h // nheads_ngroups_ratio) * stride_C_head
)
# M-block offsets and prev states
# - logic in next block may override these if there is an active offset
@@ -193,26 +161,31 @@ def _chunk_scan_fwd_kernel(
seq_idx_ptr += pid_c * stride_seq_idx_chunk
seq_idx = tl.load(seq_idx_ptr)
seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_chunk,
mask=pid_c >= 1,
other=-1)
seq_idx_prev = tl.load(
seq_idx_ptr - stride_seq_idx_chunk, mask=pid_c >= 1, other=-1
)
if HAS_INITSTATES and (seq_idx != seq_idx_prev):
prev_states_ptr = initstates_ptr + seq_idx * stride_init_states_batch + pid_h * stride_init_states_head
prev_states_ptr = (
initstates_ptr
+ seq_idx * stride_init_states_batch
+ pid_h * stride_init_states_head
)
prev_states_hdim = stride_init_states_hdim
prev_states_dstate = stride_init_states_dstate
else:
prev_states_ptr = states_ptr + (
pid_c - 1) * stride_states_chunk + pid_h * stride_states_head
prev_states_ptr = (
states_ptr + (pid_c - 1) * stride_states_chunk + pid_h * stride_states_head
)
prev_states_hdim = stride_states_hdim
prev_states_dstate = stride_states_dstate
chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize,
mask=offs_m < chunk_size,
other=0.0).to(tl.float32)
dA_cs_m = tl.load(
dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0
).to(tl.float32)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
@@ -221,52 +194,66 @@ def _chunk_scan_fwd_kernel(
# Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
offs_k_dstate = tl.arange(
0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen +
offs_k_dstate[None, :] * stride_C_dstate)
0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
)
C_ptrs = C_ptr + (
offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate
)
scale_m = tl.exp(dA_cs_m)
if BLOCK_SIZE_DSTATE <= 128:
C = tl.load(C_ptrs,
mask=(offs_m[:, None] < chunk_size_limit) &
(offs_k_dstate[None, :] < dstate),
other=0.0)
C = tl.load(
C_ptrs,
mask=(offs_m[:, None] < chunk_size_limit)
& (offs_k_dstate[None, :] < dstate),
other=0.0,
)
if not HAS_INITSTATES and (seq_idx != seq_idx_prev):
# if no init states AND starting a new sequence, we need zeros
prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N),
dtype=C_ptr.dtype.element_ty)
prev_states = tl.zeros(
(BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty
)
else:
# otherwise read the previous state
prev_states_ptrs = prev_states_ptr \
+ offs_n[None, :] * prev_states_hdim \
+ offs_k_dstate[:, None] * prev_states_dstate
prev_states = tl.load(prev_states_ptrs,
mask=(offs_k_dstate[:, None] < dstate) &
(offs_n[None, :] < hdim),
other=0.0)
prev_states_ptrs = (
prev_states_ptr
+ offs_n[None, :] * prev_states_hdim
+ offs_k_dstate[:, None] * prev_states_dstate
)
prev_states = tl.load(
prev_states_ptrs,
mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim),
other=0.0,
)
prev_states = prev_states.to(C_ptr.dtype.element_ty)
acc = tl.dot(C, prev_states) * scale_m[:, None]
else:
prev_states_ptrs = prev_states_ptr \
+ offs_n[None, :] * prev_states_hdim \
+ offs_k_dstate[:, None] * prev_states_dstate
prev_states_ptrs = (
prev_states_ptr
+ offs_n[None, :] * prev_states_hdim
+ offs_k_dstate[:, None] * prev_states_dstate
)
for k in range(0, dstate, BLOCK_SIZE_K):
C = tl.load(C_ptrs,
mask=(offs_m[:, None] < chunk_size_limit) &
(offs_k_dstate[None, :] < dstate - k),
other=0.0)
C = tl.load(
C_ptrs,
mask=(offs_m[:, None] < chunk_size_limit)
& (offs_k_dstate[None, :] < dstate - k),
other=0.0,
)
if not HAS_INITSTATES and (seq_idx != seq_idx_prev):
prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_K),
dtype=C_ptr.dtype.element_ty)
prev_states = tl.zeros(
(BLOCK_SIZE_DSTATE, BLOCK_SIZE_K), dtype=C_ptr.dtype.element_ty
)
else:
prev_states = tl.load(
prev_states_ptrs,
mask=(offs_k_dstate[:, None] < dstate - k) &
(offs_n[None, :] < hdim),
other=0.0)
mask=(offs_k_dstate[:, None] < dstate - k)
& (offs_n[None, :] < hdim),
other=0.0,
)
prev_states = prev_states.to(C_ptr.dtype.element_ty)
acc += tl.dot(C, prev_states)
C_ptrs += BLOCK_SIZE_K
@@ -274,36 +261,42 @@ def _chunk_scan_fwd_kernel(
acc *= scale_m[:, None]
offs_k = tl.arange(0, BLOCK_SIZE_K)
cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m +
offs_k[None, :] * stride_cb_csize_k)
x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen +
offs_n[None, :] * stride_x_hdim)
cb_ptrs = cb_ptr + (
offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k
)
x_ptrs = x_ptr + (
offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
)
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
K_MAX = chunk_size_limit if not IS_CAUSAL else min(
(pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
K_MAX = (
chunk_size_limit
if not IS_CAUSAL
else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
)
for k in range(0, K_MAX, BLOCK_SIZE_K):
cb = tl.load(cb_ptrs,
mask=(offs_m[:, None] < chunk_size) &
(offs_k[None, :] < chunk_size - k),
other=0.0).to(tl.float32)
dA_cs_k = tl.load(dA_cumsum_ptrs,
mask=offs_k < chunk_size - k,
other=0.0).to(tl.float32)
cb = tl.load(
cb_ptrs,
mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k),
other=0.0,
).to(tl.float32)
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(
tl.float32
)
# If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
# So we don't need masking wrt seq_idx here.
cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :])
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k,
other=0.0).to(tl.float32)
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
cb *= dt_k
if IS_CAUSAL:
mask = offs_m[:, None] >= k + offs_k[None, :]
cb = tl.where(mask, cb, 0.0)
cb = cb.to(x_ptr.dtype.element_ty)
x = tl.load(x_ptrs,
mask=(offs_k[:, None] < chunk_size_limit - k) &
(offs_n[None, :] < hdim),
other=0.0)
x = tl.load(
x_ptrs,
mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim),
other=0.0,
)
acc += tl.dot(cb, x)
cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
@@ -315,35 +308,41 @@ def _chunk_scan_fwd_kernel(
if HAS_D:
if D_HAS_HDIM:
D = tl.load(D_ptr + pid_h * stride_D_head + offs_n,
mask=offs_n < hdim,
other=0.0).to(tl.float32)
D = tl.load(
D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0
).to(tl.float32)
else:
D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen +
offs_n[None, :] * stride_x_hdim),
mask=(offs_m[:, None] < chunk_size_limit) &
(offs_n[None, :] < hdim),
other=0.0).to(tl.float32)
x_residual = tl.load(
x_ptr
+ (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim),
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
other=0.0,
).to(tl.float32)
acc += x_residual * D
if HAS_Z:
z_ptr += chunk_seqlen_start * stride_z_seqlen + pid_h * stride_z_head
z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] +
stride_z_hdim * offs_out_n[None, :])
z = tl.load(z_ptrs,
mask=(offs_out_m[:, None] < chunk_size_limit) &
(offs_out_n[None, :] < hdim),
other=0.0).to(tl.float32)
z_ptrs = z_ptr + (
stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]
)
z = tl.load(
z_ptrs,
mask=(offs_out_m[:, None] < chunk_size_limit)
& (offs_out_n[None, :] < hdim),
other=0.0,
).to(tl.float32)
acc *= z * tl.sigmoid(z)
out_ptr += chunk_seqlen_start * stride_out_seqlen + pid_h * stride_out_head
out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] +
offs_out_n[None, :] * stride_out_hdim)
tl.store(out_ptrs,
acc,
mask=(offs_out_m[:, None] < chunk_size_limit) &
(offs_out_n[None, :] < hdim))
out_ptrs = out_ptr + (
stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim
)
tl.store(
out_ptrs,
acc,
mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim),
)
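# --- Orientation note (not part of the diff) ----------------------------------
# Sketch of what the chunk-scan kernel above accumulates for output row m of a
# chunk (read off the code, not a verified derivation):
#
#     y[m]  = exp(A_cs[m]) * (C[m] @ prev_state)                                  # inter-chunk term
#           + sum_{k <= m} exp(A_cs[m] - A_cs[k]) * dt[k] * (C[m] . B[k]) * x[k]  # intra-chunk term, CB precomputed
#           + D * x[m]                                                            # optional residual
#     y[m] *= z[m] * sigmoid(z[m])                                                # optional gate
#
# where A_cs is the per-chunk cumulative sum of dt * A, (C . B) is the cb input
# from the bmm kernel, and prev_state is either the previous chunk's state or
# the provided initial state at a sequence boundary.
# ------------------------------------------------------------------------------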
def _chunk_scan_fwd(
@@ -369,24 +368,32 @@ def _chunk_scan_fwd(
assert C.shape == (seqlen, ngroups, dstate)
assert cb.shape == (nchunks, ngroups, chunk_size, chunk_size)
if D is not None:
assert D.shape == (nheads, headdim) or D.shape == (nheads, )
assert D.shape == (nheads, headdim) or D.shape == (nheads,)
if z is not None:
assert z.shape == x.shape
assert dt.shape == (nheads, nchunks, chunk_size)
assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
assert states.shape == (nchunks, nheads, headdim, dstate)
assert seq_idx.shape == (nchunks, )
assert seq_idx.shape == (nchunks,)
grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton
.cdiv(headdim, META['BLOCK_SIZE_N']), nchunks, nheads)
grid = lambda META: (
triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
* triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
nchunks,
nheads,
)
z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
(0, 0, 0))
initial_states_strides = ((initial_states.stride(0),
initial_states.stride(1),
initial_states.stride(2),
initial_states.stride(3))
if initial_states is not None else (0, 0, 0, 0))
z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
initial_states_strides = (
(
initial_states.stride(0),
initial_states.stride(1),
initial_states.stride(2),
initial_states.stride(3),
)
if initial_states is not None
else (0, 0, 0, 0)
)
_chunk_scan_fwd_kernel[grid](
cb_ptr=cb,

@@ -15,14 +15,14 @@ from .mamba_ssm import softplus
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_H': 2}),
triton.Config({'BLOCK_SIZE_H': 4}),
triton.Config({'BLOCK_SIZE_H': 8}),
triton.Config({'BLOCK_SIZE_H': 16}),
triton.Config({'BLOCK_SIZE_H': 32}),
triton.Config({'BLOCK_SIZE_H': 64}),
triton.Config({"BLOCK_SIZE_H": 2}),
triton.Config({"BLOCK_SIZE_H": 4}),
triton.Config({"BLOCK_SIZE_H": 8}),
triton.Config({"BLOCK_SIZE_H": 16}),
triton.Config({"BLOCK_SIZE_H": 32}),
triton.Config({"BLOCK_SIZE_H": 64}),
],
key=['chunk_size', 'nheads'],
key=["chunk_size", "nheads"],
)
@triton.jit
def _chunk_cumsum_fwd_kernel(
@@ -70,118 +70,99 @@ def _chunk_cumsum_fwd_kernel(
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head +
offs_c[None, :] * stride_dt_seqlen)
dt_ptrs = dt_ptr + (
offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
)
A_ptrs = A_ptr + offs_h * stride_A_head
dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head +
offs_c[None, :] * stride_dt_out_csize)
dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head +
offs_c[None, :] * stride_dA_cs_csize)
dt_out_ptrs = dt_out_ptr + (
offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize
)
dA_cs_ptrs = dA_cumsum_ptr + (
offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize
)
chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start
dt = tl.load(dt_ptrs,
mask=(offs_h[:, None] < nheads) &
(offs_c[None, :] < chunk_size_limit),
other=0.0).to(tl.float32)
dt = tl.load(
dt_ptrs,
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
other=0.0,
).to(tl.float32)
if HAS_DT_BIAS:
dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head,
mask=offs_h < nheads,
other=0.0).to(tl.float32)
dt_bias = tl.load(
dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
).to(tl.float32)
dt += dt_bias[:, None]
if DT_SOFTPLUS:
dt = tl.where(dt <= 20.0, softplus(dt), dt)
dt = tl.clamp(dt, dt_min, dt_max)
dt = tl.where(
(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt,
0.0)
tl.store(dt_out_ptrs,
dt,
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
)
tl.store(
dt_out_ptrs,
dt,
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
)
A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
dA = dt * A[:, None]
dA_cs = tl.cumsum(dA, axis=1)
tl.store(dA_cs_ptrs,
dA_cs,
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
tl.store(
dA_cs_ptrs,
dA_cs,
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
)
@triton.autotune(
configs=[
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 64
},
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8),
num_warps=8,
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2),
num_warps=2,
),
triton.Config(
{
'BLOCK_SIZE_M': 32,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2),
num_warps=2,
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=2),
num_warps=2,
),
],
key=['hdim', 'dstate', 'chunk_size'],
key=["hdim", "dstate", "chunk_size"],
)
@triton.jit
def _chunk_state_fwd_kernel(
@@ -227,8 +208,10 @@ def _chunk_state_fwd_kernel(
pid_n = tl.program_id(axis=0) % num_pid_n
chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
b_ptr += chunk_seqlen_start * stride_b_seqlen + (
pid_h // nheads_ngroups_ratio) * stride_b_head
b_ptr += (
chunk_seqlen_start * stride_b_seqlen
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
)
x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
@@ -236,32 +219,38 @@ def _chunk_state_fwd_kernel(
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim +
offs_k[None, :] * stride_x_seqlen)
b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate +
offs_k[:, None] * stride_b_seqlen)
x_ptrs = x_ptr + (
offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
)
b_ptrs = b_ptr + (
offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
)
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
dA_cs_last = tl.load(dA_cumsum_ptr +
(chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
tl.float32
)
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
x = tl.load(x_ptrs,
mask=(offs_m[:, None] < hdim) &
(offs_k[None, :] < chunk_size_limit - k),
other=0.0)
b = tl.load(b_ptrs,
mask=(offs_k[:, None] < chunk_size_limit - k) &
(offs_n[None, :] < dstate),
other=0.0).to(tl.float32)
dA_cs_k = tl.load(dA_cumsum_ptrs,
mask=offs_k < chunk_size_limit - k,
other=0.0).to(tl.float32)
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k,
other=0.0).to(tl.float32)
x = tl.load(
x_ptrs,
mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k),
other=0.0,
)
b = tl.load(
b_ptrs,
mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate),
other=0.0,
).to(tl.float32)
dA_cs_k = tl.load(
dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
).to(tl.float32)
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
tl.float32
)
scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k
b *= scale[:, None]
b = b.to(x_ptr.dtype.element_ty)
@@ -277,8 +266,9 @@ def _chunk_state_fwd_kernel(
states_ptr += pid_c * stride_states_chunk + pid_h * stride_states_head
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim +
offs_n[None, :] * stride_states_dstate)
states_ptrs = states_ptr + (
offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
)
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
tl.store(states_ptrs, states, mask=c_mask)
@@ -286,79 +276,52 @@ def _chunk_state_fwd_kernel(
@triton.autotune(
configs=[
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 64
},
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
num_stages=3,
num_warps=8),
num_warps=8,
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=4),
num_warps=4,
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2),
num_warps=2,
),
triton.Config(
{
'BLOCK_SIZE_M': 32,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=5,
num_warps=2),
num_warps=2,
),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
num_stages=4,
num_warps=2),
num_warps=2,
),
],
key=['hdim', 'dstate', 'chunk_size'],
key=["hdim", "dstate", "chunk_size"],
)
@triton.jit
def _chunk_state_varlen_kernel(
@@ -414,12 +377,16 @@ def _chunk_state_varlen_kernel(
pid_n = tl.program_id(axis=0) % num_pid_n
end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
pid_c = (end_idx - 1) // chunk_size
b_ptr += pid_c * chunk_size * stride_b_seqlen + (
pid_h // nheads_ngroups_ratio) * stride_b_head
b_ptr += (
pid_c * chunk_size * stride_b_seqlen
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
)
x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
chunk_states_ptr += (
pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
)
if HAS_INITSTATES:
# if there are init states provided, we differentiate between states (which
@@ -430,13 +397,16 @@ def _chunk_state_varlen_kernel(
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim +
offs_k[None, :] * stride_x_seqlen)
b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate +
offs_k[:, None] * stride_b_seqlen)
x_ptrs = x_ptr + (
offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
)
b_ptrs = b_ptr + (
offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
)
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) *
stride_dA_cs_csize).to(tl.float32)
dA_cs_last = tl.load(
dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
).to(tl.float32)
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
chunk_size_limit = end_idx - pid_c * chunk_size
@@ -445,24 +415,31 @@ def _chunk_state_varlen_kernel(
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
x = tl.load(x_ptrs,
mask=(offs_m[:, None] < hdim) &
(offs_k[None, :] < chunk_size_limit - k) &
(offs_k[None, :] >= start_idx_cur - k),
other=0.0)
b = tl.load(b_ptrs,
mask=(offs_k[:, None] < chunk_size_limit - k) &
(offs_n[None, :] < dstate) &
(offs_k[:, None] >= start_idx_cur - k),
other=0.0).to(tl.float32)
dA_cs_k = tl.load(dA_cumsum_ptrs,
mask=offs_k < chunk_size_limit - k,
other=0.0).to(tl.float32)
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k,
other=0.0).to(tl.float32)
x = tl.load(
x_ptrs,
mask=(offs_m[:, None] < hdim)
& (offs_k[None, :] < chunk_size_limit - k)
& (offs_k[None, :] >= start_idx_cur - k),
other=0.0,
)
b = tl.load(
b_ptrs,
mask=(offs_k[:, None] < chunk_size_limit - k)
& (offs_n[None, :] < dstate)
& (offs_k[:, None] >= start_idx_cur - k),
other=0.0,
).to(tl.float32)
dA_cs_k = tl.load(
dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
).to(tl.float32)
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
tl.float32
)
scale = tl.where(
(offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0)
tl.exp(dA_cs_last - dA_cs_k) * dt_k,
0.0,
)
b *= scale[:, None]
b = b.to(x_ptr.dtype.element_ty)
acc += tl.dot(x, b)
@@ -475,39 +452,43 @@ def _chunk_state_varlen_kernel(
# If HAS_INITSTATES==True need to consider two possibilities
# - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs
# - if start_idx >= pid_c * chunk_size, then we need to insert initstates
if ((start_idx < pid_c * chunk_size) # first chunk
or (HAS_INITSTATES)):
if (
(start_idx < pid_c * chunk_size) # first chunk
or (HAS_INITSTATES)
):
dA_cs_boundary = 0.0 # default
if not HAS_INITSTATES:
past_states_ptrs = chunk_states_ptr + (
offs_m[:, None] * stride_chunk_states_hdim +
offs_n[None, :] * stride_chunk_states_dstate)
offs_m[:, None] * stride_chunk_states_hdim
+ offs_n[None, :] * stride_chunk_states_dstate
)
else:
# - this seems repetitive, but it's to help the compiler
if start_idx < pid_c * chunk_size:
past_states_ptrs = chunk_states_ptr + (
offs_m[:, None] * stride_chunk_states_hdim +
offs_n[None, :] * stride_chunk_states_dstate)
offs_m[:, None] * stride_chunk_states_hdim
+ offs_n[None, :] * stride_chunk_states_dstate
)
else:
past_states_ptrs = initstates_ptr + (
pid_b * stride_init_states_batch +
offs_m[:, None] * stride_init_states_hdim +
offs_n[None, :] * stride_init_states_dstate)
pid_b * stride_init_states_batch
+ offs_m[:, None] * stride_init_states_hdim
+ offs_n[None, :] * stride_init_states_dstate
)
# need to adjust the boundary
if start_idx > pid_c * chunk_size:
dA_cs_boundary = tl.load(dA_cumsum_ptr +
(start_idx - pid_c * chunk_size -
1) * stride_dA_cs_csize).to(
tl.float32)
dA_cs_boundary = tl.load(
dA_cumsum_ptr
+ (start_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
).to(tl.float32)
past_states = tl.load(past_states_ptrs,
mask=(offs_m[:, None] < hdim) &
(offs_n[None, :] < dstate),
other=0.0).to(tl.float32)
past_states = tl.load(
past_states_ptrs,
mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate),
other=0.0,
).to(tl.float32)
scale = tl.exp(dA_cs_last - dA_cs_boundary)
acc += past_states * scale
@@ -517,36 +498,34 @@ def _chunk_state_varlen_kernel(
states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim +
offs_n[None, :] * stride_states_dstate)
states_ptrs = states_ptr + (
offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
)
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
tl.store(states_ptrs, states, mask=c_mask)
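# Sketch of how the varlen kernel assembles a sequence's final state: the
# partial-chunk sum above plus the carried state (previous chunk's state or the
# supplied initial state) rescaled by the decay accumulated since the carry
# point. Names are illustrative.
import math

def final_state_reference(intra_chunk_state, carried_state, dA_cs_last, dA_cs_boundary=0.0):
    # dA_cs_boundary is 0.0 when a full previous chunk is carried; otherwise it
    # is the cumulative decay already accounted for at the sequence's start
    return intra_chunk_state + math.exp(dA_cs_last - dA_cs_boundary) * carried_state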
def _chunk_cumsum_fwd(dt,
A,
chunk_size,
cu_chunk_seqlens,
dt_bias=None,
dt_softplus=False,
dt_limit=(0.0, float("inf"))):
def _chunk_cumsum_fwd(
dt,
A,
chunk_size,
cu_chunk_seqlens,
dt_bias=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
):
seqlen, nheads = dt.shape
assert A.shape == (nheads, )
assert A.shape == (nheads,)
if dt_bias is not None:
assert dt_bias.shape == (nheads, )
assert dt_bias.shape == (nheads,)
nchunks = cu_chunk_seqlens.shape[0] - 1
dt_out = torch.empty(nheads,
nchunks,
chunk_size,
device=dt.device,
dtype=torch.float32)
dA_cumsum = torch.empty(nheads,
nchunks,
chunk_size,
device=dt.device,
dtype=torch.float32)
grid_chunk_cs = lambda META: (nchunks,
triton.cdiv(nheads, META['BLOCK_SIZE_H']))
dt_out = torch.empty(
nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
)
dA_cumsum = torch.empty(
nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
)
grid_chunk_cs = lambda META: (nchunks, triton.cdiv(nheads, META["BLOCK_SIZE_H"]))
with torch.cuda.device(dt.device.index):
_chunk_cumsum_fwd_kernel[grid_chunk_cs](
dt_ptr=dt,
@@ -563,8 +542,7 @@ def _chunk_cumsum_fwd(dt,
stride_dt_seqlen=dt.stride(0),
stride_dt_head=dt.stride(1),
stride_A_head=A.stride(0),
stride_dt_bias_head=dt_bias.stride(0)
if dt_bias is not None else 0,
stride_dt_bias_head=dt_bias.stride(0) if dt_bias is not None else 0,
stride_dt_out_head=dt_out.stride(0),
stride_dt_out_chunk=dt_out.stride(1),
stride_dt_out_csize=dt_out.stride(2),
@@ -578,13 +556,9 @@ def _chunk_cumsum_fwd(dt,
return dA_cumsum, dt_out
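# Minimal eager-mode sketch of the chunked cumsum (step 1 of the combined scan),
# shown for a single chunk; the Triton kernel does this per (chunk, head) block
# and also writes the transformed dt back out. Names are illustrative.
import torch
import torch.nn.functional as F

def chunk_cumsum_reference(dt_chunk, A, dt_bias=None, dt_softplus=False,
                           dt_limit=(0.0, float("inf"))):
    # dt_chunk: (chunk_len, nheads), A: (nheads,)
    dt = dt_chunk.float()
    if dt_bias is not None:
        dt = dt + dt_bias
    if dt_softplus:
        dt = F.softplus(dt)
    dt = dt.clamp(min=dt_limit[0], max=dt_limit[1])
    dA_cumsum = torch.cumsum(A * dt, dim=0)  # (chunk_len, nheads)
    return dA_cumsum, dt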
def _chunk_state_fwd(B,
x,
dt,
dA_cumsum,
cu_chunk_seqlens,
states=None,
states_in_fp32=True):
def _chunk_state_fwd(
B, x, dt, dA_cumsum, cu_chunk_seqlens, states=None, states_in_fp32=True
):
seqlen, nheads, headdim = x.shape
_, nchunks, chunk_size = dt.shape
_, ngroups, dstate = B.shape
@@ -597,12 +571,16 @@ def _chunk_state_fwd(B,
assert states.shape == (nchunks, nheads, headdim, dstate)
else:
states_dtype = torch.float32 if states_in_fp32 else B.dtype
states = torch.empty((nchunks, nheads, headdim, dstate),
device=x.device,
dtype=states_dtype)
states = torch.empty(
(nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype
)
grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.
cdiv(dstate, META['BLOCK_SIZE_N']), nchunks, nheads)
grid = lambda META: (
triton.cdiv(headdim, META["BLOCK_SIZE_M"])
* triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
nchunks,
nheads,
)
with torch.cuda.device(x.device.index):
_chunk_state_fwd_kernel[grid](
x_ptr=x,
@@ -636,13 +614,9 @@ def _chunk_state_fwd(B,
return states
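# The launch grid flattens the (hdim, dstate) tile coordinates into axis 0; the
# kernel recovers them with a divmod (see pid_m / pid_n at the top of
# _chunk_state_fwd_kernel). A small, hypothetical illustration:
cdiv = lambda a, b: -(a // -b)                        # ceiling division, as in triton.cdiv
headdim, dstate = 64, 128
BLOCK_SIZE_M, BLOCK_SIZE_N = 64, 32
num_pid_n = cdiv(dstate, BLOCK_SIZE_N)                # 4 tiles along dstate
for pid in range(cdiv(headdim, BLOCK_SIZE_M) * num_pid_n):
    pid_m, pid_n = pid // num_pid_n, pid % num_pid_n  # (0,0), (0,1), (0,2), (0,3)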
def chunk_state_varlen(B,
x,
dt,
dA_cumsum,
cu_seqlens,
chunk_states,
initial_states=None):
def chunk_state_varlen(
B, x, dt, dA_cumsum, cu_seqlens, chunk_states, initial_states=None
):
total_seqlen, nheads, headdim = x.shape
_, nchunks, chunk_size = dt.shape
_, ngroups, dstate = B.shape
@@ -657,21 +631,32 @@ def chunk_state_varlen(B,
if initial_states is not None:
assert initial_states.shape == (batch, nheads, headdim, dstate)
states = torch.empty(batch,
nheads,
headdim,
dstate,
dtype=chunk_states.dtype,
device=chunk_states.device)
states = torch.empty(
batch,
nheads,
headdim,
dstate,
dtype=chunk_states.dtype,
device=chunk_states.device,
)
initial_states_strides = ((initial_states.stride(0),
initial_states.stride(1),
initial_states.stride(2),
initial_states.stride(3))
if initial_states is not None else (0, 0, 0, 0))
initial_states_strides = (
(
initial_states.stride(0),
initial_states.stride(1),
initial_states.stride(2),
initial_states.stride(3),
)
if initial_states is not None
else (0, 0, 0, 0)
)
grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.
cdiv(dstate, META['BLOCK_SIZE_N']), batch, nheads)
grid = lambda META: (
triton.cdiv(headdim, META["BLOCK_SIZE_M"])
* triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
batch,
nheads,
)
with torch.cuda.device(x.device.index):
_chunk_state_varlen_kernel[grid](
x_ptr=x,
@@ -710,5 +695,6 @@ def chunk_state_varlen(B,
stride_init_states_head=initial_states_strides[1],
stride_init_states_hdim=initial_states_strides[2],
stride_init_states_dstate=initial_states_strides[3],
HAS_INITSTATES=initial_states is not None)
HAS_INITSTATES=initial_states is not None,
)
return states

View File

@@ -17,63 +17,66 @@ from .ssd_chunk_scan import _chunk_scan_fwd
from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd
from .ssd_state_passing import _state_passing_fwd
TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
def is_int_pow_2(n):
return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0
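# e.g. is_int_pow_2(256) -> True  (256 & 255 == 0)
#      is_int_pow_2(96)  -> False (96 & 95 == 64)
#      is_int_pow_2(0)   -> False; is_int_pow_2(256.0) -> False (not an int)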
def _mamba_chunk_scan_combined_fwd(x,
dt,
A,
B,
C,
chunk_size,
out,
D=None,
z=None,
dt_bias=None,
initial_states=None,
return_intermediate_states=False,
seq_idx=None,
cu_seqlens=None,
cu_chunk_seqlens=None,
last_chunk_indices=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
state_dtype=None):
def _mamba_chunk_scan_combined_fwd(
x,
dt,
A,
B,
C,
chunk_size,
out,
D=None,
z=None,
dt_bias=None,
initial_states=None,
return_intermediate_states=False,
seq_idx=None,
cu_seqlens=None,
cu_chunk_seqlens=None,
last_chunk_indices=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
state_dtype=None,
):
assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
seqlen, nheads, headdim = x.shape
_, ngroups, dstate = B.shape
assert nheads % ngroups == 0
assert B.shape == (seqlen, ngroups, dstate)
assert dt.shape == (seqlen, nheads)
assert A.shape == (nheads, )
assert A.shape == (nheads,)
assert C.shape == B.shape
if z is not None:
assert z.shape == x.shape
if D is not None:
assert D.shape == (nheads, headdim) or D.shape == (nheads, )
assert D.shape == (nheads, headdim) or D.shape == (nheads,)
if seq_idx is not None:
assert seq_idx.shape == (cu_chunk_seqlens.shape[0] - 1, )
assert seq_idx.shape == (cu_chunk_seqlens.shape[0] - 1,)
if B.stride(-1) != 1:
B = B.contiguous()
if C.stride(-1) != 1:
C = C.contiguous()
if x.stride(-1) != 1 and x.stride(
0) != 1: # Either M or K dimension should be contiguous
if (
x.stride(-1) != 1 and x.stride(0) != 1
): # Either M or K dimension should be contiguous
x = x.contiguous()
if z is not None and z.stride(-1) != 1 and z.stride(
0) != 1: # Either M or K dimension should be contiguous
if (
z is not None and z.stride(-1) != 1 and z.stride(0) != 1
): # Either M or K dimension should be contiguous
z = z.contiguous()
if D is not None and D.stride(-1) != 1:
D = D.contiguous()
assert cu_seqlens is not None, "Assuming varlen input - must supply cu_seqlens"
if initial_states is not None:
assert initial_states.shape == (len(cu_seqlens) - 1, nheads, headdim,
dstate)
assert initial_states.shape == (len(cu_seqlens) - 1, nheads, headdim, dstate)
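# Hypothetical shape example for the varlen layout checked above: with
# cu_seqlens = [0, 300, 1024] (2 sequences, 1024 total tokens), nheads = 32,
# headdim = 64, ngroups = 4, dstate = 128:
#   x: (1024, 32, 64)    dt: (1024, 32)    A: (32,)
#   B, C: (1024, 4, 128)  initial_states (if given): (2, 32, 64, 128)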
# This function executes 5 sub-functions for computing mamba
# - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
@@ -86,22 +89,21 @@ def _mamba_chunk_scan_combined_fwd(x,
# 1. Compute chunked cumsum of A * dt
# - here dt may go through a softplus activation
dA_cumsum, dt = _chunk_cumsum_fwd(dt,
A,
chunk_size,
cu_chunk_seqlens,
dt_bias=dt_bias,
dt_softplus=dt_softplus,
dt_limit=dt_limit)
dA_cumsum, dt = _chunk_cumsum_fwd(
dt,
A,
chunk_size,
cu_chunk_seqlens,
dt_bias=dt_bias,
dt_softplus=dt_softplus,
dt_limit=dt_limit,
)
# 2. Compute the state for each intra-chunk
# (right term of low-rank factorization of off-diagonal blocks; B terms)
states = _chunk_state_fwd(B,
x,
dt,
dA_cumsum,
cu_chunk_seqlens,
states_in_fp32=True)
states = _chunk_state_fwd(
B, x, dt, dA_cumsum, cu_chunk_seqlens, states_in_fp32=True
)
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
@@ -114,18 +116,15 @@ def _mamba_chunk_scan_combined_fwd(x,
dA_cumsum, # (nheads, nchunks, chunk_size)
cu_chunk_seqlens,
initial_states=rearrange(initial_states, "... p n -> ... (p n)")
if initial_states is not None else
None, # (batch, nheads, headdim*dstate)
if initial_states is not None
else None, # (batch, nheads, headdim*dstate)
seq_idx=seq_idx,
out_dtype=state_dtype if state_dtype is not None else C.dtype)
out_dtype=state_dtype if state_dtype is not None else C.dtype,
)
states = rearrange(states, "... (p n) -> ... p n", n=dstate)
# 4. Compute batched matrix multiply for C_j^T B_i terms
CB = _bmm_chunk_fwd(C,
B,
chunk_size,
cu_chunk_seqlens,
output_dtype=torch.float32)
CB = _bmm_chunk_fwd(C, B, chunk_size, cu_chunk_seqlens, output_dtype=torch.float32)
# 5. Scan and compute the diagonal blocks, taking into
# account past causal states.
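# The five stages above chunk the sequential SSD recurrence
#     h_t = exp(A * dt_t) * h_{t-1} + dt_t * x_t B_t^T ,   y_t = h_t C_t (+ D * x_t)
# so that intra-chunk work becomes matmuls and only per-chunk states are passed
# sequentially. A naive, unchunked reference for one head, with hypothetical
# shapes, purely for orientation:
import torch

def ssd_reference(x, dt, A, B, C, D=None, h0=None):
    # x: (seqlen, headdim), dt: (seqlen,), A: scalar, B, C: (seqlen, dstate)
    seqlen, headdim = x.shape
    dstate = B.shape[-1]
    h = torch.zeros(headdim, dstate) if h0 is None else h0.clone()
    ys = []
    for t in range(seqlen):
        h = torch.exp(A * dt[t]) * h + dt[t] * torch.outer(x[t], B[t])
        y = h @ C[t]  # (headdim,)
        if D is not None:
            y = y + D * x[t]
        ys.append(y)
    return torch.stack(ys), h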
@@ -225,6 +224,7 @@ def mamba_chunk_scan_combined_varlen(
last_chunk_indices=last_chunk_indices,
dt_softplus=dt_softplus,
dt_limit=dt_limit,
state_dtype=state_dtype)
state_dtype=state_dtype,
)
return varlen_states

View File

@@ -13,14 +13,14 @@ from vllm.triton_utils import tl, triton
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE': 64}),
triton.Config({'BLOCK_SIZE': 128}),
triton.Config({'BLOCK_SIZE': 256}),
triton.Config({'BLOCK_SIZE': 512}),
triton.Config({'BLOCK_SIZE': 1024}),
triton.Config({'BLOCK_SIZE': 2048}),
triton.Config({"BLOCK_SIZE": 64}),
triton.Config({"BLOCK_SIZE": 128}),
triton.Config({"BLOCK_SIZE": 256}),
triton.Config({"BLOCK_SIZE": 512}),
triton.Config({"BLOCK_SIZE": 1024}),
triton.Config({"BLOCK_SIZE": 2048}),
],
key=['dim'],
key=["dim"],
)
@triton.jit
def _state_passing_fwd_kernel(
@@ -58,8 +58,7 @@ def _state_passing_fwd_kernel(
pid_m = tl.program_id(axis=0)
states_ptr += pid_h * stride_states_head
dA_cs_ptr += pid_h * stride_dA_cs_head + (chunk_size -
1) * stride_dA_cs_csize
dA_cs_ptr += pid_h * stride_dA_cs_head + (chunk_size - 1) * stride_dA_cs_csize
out_ptr += pid_h * stride_out_head
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
@@ -67,31 +66,35 @@ def _state_passing_fwd_kernel(
out_ptrs = out_ptr + offs_m * stride_out_dim
if HAS_INITSTATES:
initstates_ptrs = initstates_ptr \
+ pid_h * stride_initstates_head \
initstates_ptrs = (
initstates_ptr
+ pid_h * stride_initstates_head
+ offs_m * stride_initstates_dim
)
states = tl.load(initstates_ptrs, mask=offs_m < dim,
other=0.0).to(tl.float32)
states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
else:
states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
prev_seq_idx = 0
for c in range(nchunks):
new_states = tl.load(states_ptrs, mask=offs_m < dim,
other=0.0).to(tl.float32)
new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
seq_idx = tl.load(seq_idx_ptr + c * stride_seq_idx_chunk)
# we have started a new sequence
if prev_seq_idx != seq_idx:
if HAS_INITSTATES:
initstates_ptrs = initstates_ptr + seq_idx * stride_initstates_batch \
+ pid_h * stride_initstates_head \
initstates_ptrs = (
initstates_ptr
+ seq_idx * stride_initstates_batch
+ pid_h * stride_initstates_head
+ offs_m * stride_initstates_dim
states = tl.load(initstates_ptrs, mask=offs_m < dim,
other=0.0).to(tl.float32)
)
states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(
tl.float32
)
else:
states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
prev_seq_idx = seq_idx
states = tl.exp(dA_cs) * states + new_states
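# Plain-Python sketch of the carry-and-reset recurrence in the loop above; the
# exact tensor written out at each step is omitted, only the running state is
# shown. Names and shapes are illustrative.
import torch

def state_passing_reference(chunk_states, dA_cs_last, seq_idx, init_states=None):
    # chunk_states: (nchunks, dim), dA_cs_last: (nchunks,), seq_idx: (nchunks,)
    state = init_states[0] if init_states is not None else torch.zeros_like(chunk_states[0])
    prev = 0
    for c in range(chunk_states.shape[0]):
        if seq_idx[c] != prev:  # chunk starts a new sequence: reset the carry
            state = init_states[seq_idx[c]] if init_states is not None else torch.zeros_like(state)
            prev = seq_idx[c]
        state = torch.exp(dA_cs_last[c]) * state + chunk_states[c]
    return state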
@@ -115,16 +118,15 @@ def _state_passing_fwd(
assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
seqlen = seq_idx.shape[-1]
out_dtype = states.dtype if out_dtype is None else out_dtype
out = torch.empty((nchunks, nheads, dim),
device=states.device,
dtype=out_dtype)
out = torch.empty((nchunks, nheads, dim), device=states.device, dtype=out_dtype)
initial_states_strides = ((initial_states.stride(0),
initial_states.stride(1),
initial_states.stride(2))
if initial_states is not None else (0, 0, 0))
initial_states_strides = (
(initial_states.stride(0), initial_states.stride(1), initial_states.stride(2))
if initial_states is not None
else (0, 0, 0)
)
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), nheads)
grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), nheads)
with torch.cuda.device(states.device.index):
_state_passing_fwd_kernel[grid](
states_ptr=states,

View File

@@ -13,29 +13,35 @@ from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
causal_conv1d_fn,
causal_conv1d_update,
)
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.short_conv_attn import (
ShortConvAttentionMetadata)
from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionMetadata
@CustomOp.register("short_conv")
class ShortConv(MambaBase, CustomOp):
def __init__(self,
config,
dim: int,
layer_idx: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
prefix: str = ""):
def __init__(
self,
config,
dim: int,
layer_idx: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
):
super().__init__()
self.config = config
self.layer_idx = layer_idx
@@ -72,7 +78,7 @@ class ShortConv(MambaBase, CustomOp):
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
self.kv_cache = (torch.tensor([]), )
self.kv_cache = (torch.tensor([]),)
self.model_config = model_config
self.cache_config = cache_config
@@ -121,8 +127,9 @@ class ShortConv(MambaBase, CustomOp):
B, C, x = BCx.chunk(3, dim=-1)
conv_weights = self.conv.weight.view(self.conv.weight.size(0),
self.conv.weight.size(2))
conv_weights = self.conv.weight.view(
self.conv.weight.size(0), self.conv.weight.size(2)
)
if attn_metadata is None:
# V1 profile run
@@ -163,23 +170,26 @@ class ShortConv(MambaBase, CustomOp):
dim=0,
)
query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1:] -
num_decodes if has_prefill else None)
attn_metadata.query_start_loc[-num_prefills - 1 :] - num_decodes
if has_prefill
else None
)
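# Hypothetical example: with 3 decode requests followed by 2 prefills of 4 and
# 5 tokens, query_start_loc = [0, 1, 2, 3, 7, 12]; the slice above keeps the
# last num_prefills + 1 = 3 entries and shifts out the decode tokens:
#   [3, 7, 12] - 3  ->  [0, 4, 9]  (prefill-local start offsets)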
conv_output_list = []
if has_prefill:
Bx_p = (B_p * x_p).transpose(0, 1)
Bx = causal_conv1d_fn(Bx_p,
conv_weights,
self.conv.bias,
activation=None,
conv_states=conv_state,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
metadata=attn_metadata,
query_start_loc=query_start_loc_p).transpose(
0, 1)[:num_prefill_tokens]
Bx = causal_conv1d_fn(
Bx_p,
conv_weights,
self.conv.bias,
activation=None,
conv_states=conv_state,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
metadata=attn_metadata,
query_start_loc=query_start_loc_p,
).transpose(0, 1)[:num_prefill_tokens]
y = C_p * Bx
conv_output_list.append(y)
@@ -192,7 +202,8 @@ class ShortConv(MambaBase, CustomOp):
conv_weights,
self.conv.bias,
activation=None,
conv_state_indices=state_indices_tensor_d)
conv_state_indices=state_indices_tensor_d,
)
y = C_d * Bx
conv_output_list.insert(0, y)
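# Minimal sketch of the per-token update that the decode path performs
# conceptually for each channel: keep the last W inputs as a rolling buffer,
# shift in the new token, and take a dot product with the kernel. Hypothetical
# helper, not the library implementation of causal_conv1d_update.
import torch

def causal_conv1d_step(x_t, conv_state, weight, bias=None):
    # x_t: (dim,) new token, conv_state: (dim, W) rolling buffer, weight: (dim, W)
    conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
    conv_state[:, -1] = x_t
    y = (conv_state * weight).sum(dim=-1)
    if bias is not None:
        y = y + bias
    return y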
@@ -222,8 +233,8 @@ class ShortConv(MambaBase, CustomOp):
return "short_conv"
def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.short_conv_attn import (
ShortConvAttentionBackend)
from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend
return ShortConvAttentionBackend