Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -11,28 +11,40 @@ from torch import nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.distributed import (
divide,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
causal_conv1d_fn,
causal_conv1d_update,
)
from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_state_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_state_update
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined_varlen)
mamba_chunk_scan_combined_varlen,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import (
LoaderFunction, composed_weight_loader, sharded_weight_loader)
LoaderFunction,
composed_weight_loader,
sharded_weight_loader,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
@@ -43,12 +55,13 @@ from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
# Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated
@CustomOp.register("mixer2_gated_rms_norm")
class Mixer2RMSNormGated(CustomOp):
def __init__(self,
full_hidden_size: int,
full_n_groups: int,
use_rms_norm: bool = True,
eps: float = 1e-6):
def __init__(
self,
full_hidden_size: int,
full_n_groups: int,
use_rms_norm: bool = True,
eps: float = 1e-6,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
@@ -62,13 +75,13 @@ class Mixer2RMSNormGated(CustomOp):
if self.use_rms_norm:
# Register norm weight only if we're actually applying RMSNorm
self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
set_weight_attrs(self.weight,
{"weight_loader": sharded_weight_loader(0)})
set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)})
else:
# Avoid checkpoint mismatch by skipping unused parameter
self.register_parameter("weight", None)
assert (self.full_hidden_size % self.tp_size == 0
), "Tensor parallel world size must divide hidden size."
assert self.full_hidden_size % self.tp_size == 0, (
"Tensor parallel world size must divide hidden size."
)
def forward_native(
self,
@@ -111,8 +124,7 @@ class Mixer2RMSNormGated(CustomOp):
group_count = hidden_dim // self.group_size
x_grouped = x.view(*prefix_dims, group_count, self.group_size)
variance = x_grouped.pow(2).mean(-1, keepdim=True)
x_grouped = x_grouped * torch.rsqrt(variance +
self.variance_epsilon)
x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon)
x = x_grouped.view(*prefix_dims, hidden_dim)
if redundant_tp:
@@ -130,18 +142,19 @@ class Mixer2RMSNormGated(CustomOp):
input_dtype = x.dtype
if not self.use_rms_norm:
# Keep gate in float32 for numerical stability during silu
return x * nn.functional.silu(gate.to(
torch.float32)).to(input_dtype)
return x * nn.functional.silu(gate.to(torch.float32)).to(input_dtype)
if (((self.n_groups % self.tp_size) != 0) or self.n_groups != 1):
if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1:
return self.forward_native(x, gate)
return rms_norm_gated(x,
self.weight.data,
bias=None,
z=gate,
eps=self.variance_epsilon,
norm_before_gate=False)
return rms_norm_gated(
x,
self.weight.data,
bias=None,
z=gate,
eps=self.variance_epsilon,
norm_before_gate=False,
)
def mamba_v2_sharded_weight_loader(
@@ -156,7 +169,6 @@ def mamba_v2_sharded_weight_loader(
"""
def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
# - track boundary of (sharded) param, and loaded_weight, respectively
boundary, loaded_boundary = 0, 0
@@ -191,11 +203,12 @@ def mamba_v2_sharded_weight_loader(
# seem to handle slices well.
# https://github.com/python/mypy/issues/2410
param.data[
boundary:(boundary + take),
... # type: ignore[misc]
] = loaded_weight[loaded_start_idx:(loaded_start_idx +
take) # type: ignore[misc]
] # type: ignore[misc]
boundary : (boundary + take), ... # type: ignore[misc]
] = loaded_weight[
loaded_start_idx : (
loaded_start_idx + take
) # type: ignore[misc]
] # type: ignore[misc]
# move indexing boundaries
boundary += shard_size
@@ -217,23 +230,25 @@ class MambaMixer2(MambaBase, CustomOp):
**selective** state spaces)
"""
def __init__(self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation: str = "silu",
use_rms_norm: bool = True,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
def __init__(
self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation: str = "silu",
use_rms_norm: bool = True,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
# For TP, the sharding plan is as follows:
@@ -253,15 +268,18 @@ class MambaMixer2(MambaBase, CustomOp):
self.tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
assert (num_heads % self.tp_size == 0
), "Tensor parallel world size must divide num heads."
assert num_heads % self.tp_size == 0, (
"Tensor parallel world size must divide num heads."
)
assert (n_groups % self.tp_size) == 0 or n_groups == 1, (
"If tensor parallel world size does not divide num_groups, "
"then num_groups must equal 1.")
"then num_groups must equal 1."
)
assert (n_groups % self.tp_size == 0) or self.tp_size == 1 or \
quant_config is None, (
assert (
(n_groups % self.tp_size == 0) or self.tp_size == 1 or quant_config is None
), (
"Tensor parallel currently supported for quantized models only "
"if tensor parallel world size divides num groups."
)
@@ -280,7 +298,8 @@ class MambaMixer2(MambaBase, CustomOp):
# - but if n_groups cannot divide tp_size, we need to
# extend some extra groups
groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
n_groups, self.tp_size)
n_groups, self.tp_size
)
self.n_groups = n_groups + groups
self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
@@ -340,8 +359,7 @@ class MambaMixer2(MambaBase, CustomOp):
# to the head shards
group_shard_settings = (
self.groups_ssm_state_size, # expected model size
(self.n_groups - n_groups) *
self.ssm_state_size, # extra dims assigned
(self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned
n_groups == 1, # if there was only one group
)
intermediate_settings = (intermediate_size, 0, False)
@@ -355,8 +373,7 @@ class MambaMixer2(MambaBase, CustomOp):
set_weight_attrs(
self.conv1d.bias,
{
"weight_loader":
mamba_v2_sharded_weight_loader(
"weight_loader": mamba_v2_sharded_weight_loader(
[
intermediate_settings,
group_shard_settings,
@@ -372,8 +389,7 @@ class MambaMixer2(MambaBase, CustomOp):
set_weight_attrs(
self.conv1d.weight,
{
"weight_loader":
mamba_v2_sharded_weight_loader(
"weight_loader": mamba_v2_sharded_weight_loader(
[
intermediate_settings,
group_shard_settings,
@@ -391,8 +407,7 @@ class MambaMixer2(MambaBase, CustomOp):
set_weight_attrs(
self.in_proj.weight,
{
"weight_loader":
mamba_v2_sharded_weight_loader(
"weight_loader": mamba_v2_sharded_weight_loader(
[
intermediate_settings, # for gate
intermediate_settings,
@@ -418,17 +433,18 @@ class MambaMixer2(MambaBase, CustomOp):
torch.empty(
divide(num_heads, self.tp_size),
dtype=torch.float32,
))
)
)
self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))
self.use_rms_norm = use_rms_norm
set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
a_weight_loader = composed_weight_loader(
sharded_weight_loader(0), lambda x: -torch.exp(x.float()))
sharded_weight_loader(0), lambda x: -torch.exp(x.float())
)
set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
set_weight_attrs(self.dt_bias,
{"weight_loader": sharded_weight_loader(0)})
set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
self.out_proj = RowParallelLinear(
intermediate_size,
@@ -439,10 +455,9 @@ class MambaMixer2(MambaBase, CustomOp):
prefix=f"{prefix}.out_proj",
)
self.norm = Mixer2RMSNormGated(intermediate_size,
n_groups,
self.use_rms_norm,
eps=rms_norm_eps)
self.norm = Mixer2RMSNormGated(
intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps
)
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
@@ -525,8 +540,9 @@ class MambaMixer2(MambaBase, CustomOp):
dim=-1,
)
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
conv_weights = self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
)
# - get hidden_states, B and C after depthwise convolution.
split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
@@ -541,10 +557,10 @@ class MambaMixer2(MambaBase, CustomOp):
if attn_metadata is None:
# profile run
hidden_states_B_C = (hidden_states_B_C.transpose(
0, 1).clone().transpose(0, 1)).contiguous()
hidden_states, _B, _C = split_hidden_states_B_C_fn(
hidden_states_B_C)
hidden_states_B_C = (
hidden_states_B_C.transpose(0, 1).clone().transpose(0, 1)
).contiguous()
hidden_states, _B, _C = split_hidden_states_B_C_fn(hidden_states_B_C)
hidden_states = self.norm(hidden_states, gate)
out, _ = self.out_proj(hidden_states)
return out
@@ -580,11 +596,11 @@ class MambaMixer2(MambaBase, CustomOp):
# If prefix caching is enabled, retrieve the relevant variables
# for prefill and decode
last_state_idx_d, last_state_idx_p = torch.split(
attn_metadata.last_state_idx, [num_decodes, num_prefills],
dim=0)
attn_metadata.last_state_idx, [num_decodes, num_prefills], dim=0
)
current_last_idx_d, current_last_idx_p = torch.split(
attn_metadata.current_last_idx, [num_decodes, num_prefills],
dim=0)
attn_metadata.current_last_idx, [num_decodes, num_prefills], dim=0
)
# Prefill-only variables:
current_first_idx_p = attn_metadata.current_first_idx_p
context_lens_p = attn_metadata.context_lens_p
@@ -600,7 +616,7 @@ class MambaMixer2(MambaBase, CustomOp):
preallocated_ssm_out = torch.empty(
[
num_prefill_tokens + num_decodes,
(self.num_heads // self.tp_size) * self.head_dim
(self.num_heads // self.tp_size) * self.head_dim,
],
dtype=hidden_states.dtype,
device=hidden_states.device,
@@ -626,7 +642,8 @@ class MambaMixer2(MambaBase, CustomOp):
# "state_indices_tensor_p"), it will write additional cache
# states aligned at "block_size_to_align".
x = hidden_states_B_C_p.transpose(
0, 1) # this is the form that causal-conv see
0, 1
) # this is the form that causal-conv see
hidden_states_B_C_p = causal_conv1d_fn(
x,
conv_weights,
@@ -641,34 +658,34 @@ class MambaMixer2(MambaBase, CustomOp):
context_lens=context_lens_p,
block_size_to_align=mamba_block_size,
metadata=attn_metadata,
query_start_loc=query_start_loc_p).transpose(
0, 1)[:num_prefill_tokens]
query_start_loc=query_start_loc_p,
).transpose(0, 1)[:num_prefill_tokens]
hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(
hidden_states_B_C_p)
hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(hidden_states_B_C_p)
# 3. State Space Model sequence transformation
initial_states = None
if (has_initial_states_p is not None and prep_initial_states):
if has_initial_states_p is not None and prep_initial_states:
kernel_ssm_indices = state_indices_tensor_p
if prefix_caching_enabled:
kernel_ssm_indices = state_indices_tensor_p.gather(
1, last_state_idx_p.unsqueeze(1)).squeeze(1)
1, last_state_idx_p.unsqueeze(1)
).squeeze(1)
initial_states = torch.where(
has_initial_states_p[:, None, None, None],
ssm_state[kernel_ssm_indices], 0)
ssm_state[kernel_ssm_indices],
0,
)
# NOTE: final output is an in-place update of out tensor
varlen_states = mamba_chunk_scan_combined_varlen(
hidden_states_p.view(num_prefill_tokens,
self.num_heads // self.tp_size,
self.head_dim),
hidden_states_p.view(
num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim
),
dt_p,
self.A,
B_p.view(num_prefill_tokens, self.n_groups // self.tp_size,
-1),
C_p.view(num_prefill_tokens, self.n_groups // self.tp_size,
-1),
B_p.view(num_prefill_tokens, self.n_groups // self.tp_size, -1),
C_p.view(num_prefill_tokens, self.n_groups // self.tp_size, -1),
chunk_size=chunk_size,
D=self.D,
z=None,
@@ -681,18 +698,19 @@ class MambaMixer2(MambaBase, CustomOp):
return_intermediate_states=prefix_caching_enabled,
dt_softplus=True,
dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out_p.view(num_prefill_tokens, -1,
self.head_dim),
state_dtype=ssm_state.dtype)
out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, self.head_dim),
state_dtype=ssm_state.dtype,
)
if prefix_caching_enabled:
# Save states for sequences with more than just the final state:
n_blocks_to_fill = current_last_idx_p - current_first_idx_p
for seq_idx in (n_blocks_to_fill > 0).nonzero().squeeze(1):
cache_blocks_to_fill = state_indices_tensor_p[
seq_idx, current_first_idx_p[seq_idx]:
current_first_idx_p[seq_idx] +
n_blocks_to_fill[seq_idx]]
seq_idx,
current_first_idx_p[seq_idx] : current_first_idx_p[seq_idx]
+ n_blocks_to_fill[seq_idx],
]
# chunks = [0 1 2 3 4 5 6 ...]
# First aligned chunk would typically be:
# mamba_block_size = 1024, chunk_size = 256
@@ -704,22 +722,33 @@ class MambaMixer2(MambaBase, CustomOp):
# e.g. 256 // 256 -> 1 completed --> store chunk[2] (skip 2)
# e.g. 10 // 256 -> 0 completed --> store chunk[3] (skip 3)
chunk_stride = mamba_block_size // chunk_size
first_aligned_chunk = \
torch.concat([torch.zeros(1, \
dtype=last_chunk_indices_p.dtype, \
device=last_chunk_indices_p.device), \
last_chunk_indices_p + 1])[seq_idx] \
+ chunk_stride - 1 \
- last_computed_offset_p[seq_idx] // chunk_size
first_aligned_chunk = (
torch.concat(
[
torch.zeros(
1,
dtype=last_chunk_indices_p.dtype,
device=last_chunk_indices_p.device,
),
last_chunk_indices_p + 1,
]
)[seq_idx]
+ chunk_stride
- 1
- last_computed_offset_p[seq_idx] // chunk_size
)
from_where = varlen_states[
first_aligned_chunk:first_aligned_chunk +
n_blocks_to_fill[seq_idx] * chunk_stride:chunk_stride]
first_aligned_chunk : first_aligned_chunk
+ n_blocks_to_fill[seq_idx] * chunk_stride : chunk_stride
]
ssm_state[cache_blocks_to_fill] = from_where
#For all seqs, store the last state (Note: might be partial):
ssm_state[state_indices_tensor_p.gather(1,
current_last_idx_p.unsqueeze(1)).squeeze(1)] = \
varlen_states[last_chunk_indices_p]
# For all seqs, store the last state (Note: might be partial):
ssm_state[
state_indices_tensor_p.gather(
1, current_last_idx_p.unsqueeze(1)
).squeeze(1)
] = varlen_states[last_chunk_indices_p]
else:
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate)
@@ -729,13 +758,13 @@ class MambaMixer2(MambaBase, CustomOp):
# Process decode requests
if has_decode:
if prefix_caching_enabled:
state_indices_tensor_d_input = \
state_indices_tensor_d.gather(1,
last_state_idx_d.unsqueeze(1)).squeeze(1)
state_indices_tensor_d_output = \
state_indices_tensor_d.gather(1,
current_last_idx_d.unsqueeze(1)).squeeze(1)
#Note:
state_indices_tensor_d_input = state_indices_tensor_d.gather(
1, last_state_idx_d.unsqueeze(1)
).squeeze(1)
state_indices_tensor_d_output = state_indices_tensor_d.gather(
1, current_last_idx_d.unsqueeze(1)
).squeeze(1)
# Note:
# for decode always: current_first_idx_d == current_last_idx_d
# at block boundaries: current_first_idx_d > last_state_idx_d
else:
@@ -755,20 +784,23 @@ class MambaMixer2(MambaBase, CustomOp):
initial_state_idx=last_state_idx_d,
)
hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(
hidden_states_B_C_d)
hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d)
# 3. State Space Model sequence transformation
n_groups = self.n_groups // self.tp_size
A_d = self.A[:, None, ...][:, :, None].expand(
-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
A_d = (
self.A[:, None, ...][:, :, None]
.expand(-1, self.head_dim, self.ssm_state_size)
.to(dtype=torch.float32)
)
dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim)
dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
D_d = self.D[:, None, ...].expand(-1, self.head_dim)
B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups)
C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups)
hidden_states_d = hidden_states_d.view(
-1, self.num_heads // self.tp_size, self.head_dim)
-1, self.num_heads // self.tp_size, self.head_dim
)
# - the hidden is reshaped into (bs, num_heads, head_dim)
# - mamba_cache_params.ssm_state's slots will be selected
@@ -787,16 +819,14 @@ class MambaMixer2(MambaBase, CustomOp):
dt_softplus=True,
state_batch_indices=state_indices_tensor_d_input,
dst_state_batch_indices=state_indices_tensor_d_output,
out=preallocated_ssm_out_d.view(num_decodes, -1,
self.head_dim),
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
)
# 4. gated MLP
# GatedRMSNorm internally applying SiLU to the gate
# SiLU is applied internally before normalization, unlike standard
# norm usage
hidden_states = self.norm(preallocated_ssm_out,
gate[:num_actual_tokens])
hidden_states = self.norm(preallocated_ssm_out, gate[:num_actual_tokens])
# 5. Final linear projection
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
@@ -826,8 +856,8 @@ class MambaMixer2(MambaBase, CustomOp):
return "mamba2"
def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.mamba2_attn import (
Mamba2AttentionBackend)
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend
return Mamba2AttentionBackend
@@ -839,9 +869,7 @@ def mamba_mixer2(
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self.forward_cuda(hidden_states=hidden_states,
output=output,
mup_vector=mup_vector)
self.forward_cuda(hidden_states=hidden_states, output=output, mup_vector=mup_vector)
def mamba_mixer2_fake(