Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -12,20 +12,30 @@ from torch.nn.parameter import Parameter
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
causal_conv1d_fn,
causal_conv1d_update,
)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
selective_scan_fn,
selective_state_update,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
@@ -44,22 +54,24 @@ class MambaMixer(MambaBase, CustomOp):
**selective** state spaces)
"""
def __init__(self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
time_step_rank: int,
use_conv_bias: bool,
use_bias: bool,
use_rms_norm: bool,
rms_norm_has_weight: bool = True,
rms_norm_eps: float = 1e-5,
activation="silu",
is_lora_enabled: bool = False,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
prefix: str = ""):
def __init__(
self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
time_step_rank: int,
use_conv_bias: bool,
use_bias: bool,
use_rms_norm: bool,
rms_norm_has_weight: bool = True,
rms_norm_eps: float = 1e-5,
activation="silu",
is_lora_enabled: bool = False,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
):
super().__init__()
self.time_step_rank = time_step_rank
self.ssm_state_size = ssm_state_size
@@ -80,9 +92,9 @@ class MambaMixer(MambaBase, CustomOp):
# doesn't allow to override it
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
self.in_proj = MergedColumnParallelLinear(hidden_size,
[intermediate_size] * 2,
bias=use_bias)
self.in_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2, bias=use_bias
)
# selective projection used to make dt, B and C input dependent
self.x_proj = RowParallelLinear(
@@ -93,17 +105,18 @@ class MambaMixer(MambaBase, CustomOp):
# time step projection (discretization) -
# In the forward we need to apply dt_proj without the bias,
# as the bias is added in the selective scan kernel.
self.dt_proj = ColumnParallelLinear(time_step_rank,
intermediate_size,
bias=True,
skip_bias_add=True)
self.dt_proj = ColumnParallelLinear(
time_step_rank, intermediate_size, bias=True, skip_bias_add=True
)
def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
param.data.copy_(
loaded_weight.data.split(loaded_weight.shape[0] // tp_size,
dim=0)[tp_rank])
loaded_weight.data.split(loaded_weight.shape[0] // tp_size, dim=0)[
tp_rank
]
)
def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
weight_loader(param, -torch.exp(loaded_weight.float()))
@@ -114,7 +127,8 @@ class MambaMixer(MambaBase, CustomOp):
intermediate_size // tp_size,
ssm_state_size,
dtype=torch.float32,
))
)
)
self.D = nn.Parameter(torch.ones(intermediate_size // tp_size))
set_weight_attrs(self.D, {"weight_loader": weight_loader})
@@ -127,23 +141,35 @@ class MambaMixer(MambaBase, CustomOp):
input_is_parallel=True,
)
self.dt_layernorm = RMSNorm(
time_step_rank,
eps=rms_norm_eps,
has_weight=rms_norm_has_weight,
) if use_rms_norm else None
self.dt_layernorm = (
RMSNorm(
time_step_rank,
eps=rms_norm_eps,
has_weight=rms_norm_has_weight,
)
if use_rms_norm
else None
)
self.b_layernorm = RMSNorm(
ssm_state_size,
eps=rms_norm_eps,
has_weight=rms_norm_has_weight,
) if use_rms_norm else None
self.b_layernorm = (
RMSNorm(
ssm_state_size,
eps=rms_norm_eps,
has_weight=rms_norm_has_weight,
)
if use_rms_norm
else None
)
self.c_layernorm = RMSNorm(
ssm_state_size,
eps=rms_norm_eps,
has_weight=rms_norm_has_weight,
) if use_rms_norm else None
self.c_layernorm = (
RMSNorm(
ssm_state_size,
eps=rms_norm_eps,
has_weight=rms_norm_has_weight,
)
if use_rms_norm
else None
)
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
@@ -157,7 +183,7 @@ class MambaMixer(MambaBase, CustomOp):
self.prefix = prefix
def _ssm_transform(
self, x: torch.Tensor
self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if self.is_lora_enabled:
# Lora kernel requires contiguous tensor.
@@ -167,7 +193,8 @@ class MambaMixer(MambaBase, CustomOp):
time_step, B, C = torch.split(
ssm_params,
[self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
dim=-1)
dim=-1,
)
if self.use_rms_norm:
assert self.dt_layernorm is not None
assert self.b_layernorm is not None
@@ -185,8 +212,7 @@ class MambaMixer(MambaBase, CustomOp):
self.prefix,
)
def forward_native(self, hidden_states: torch.Tensor,
output: torch.Tensor):
def forward_native(self, hidden_states: torch.Tensor, output: torch.Tensor):
pass
def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor):
@@ -232,8 +258,9 @@ class MambaMixer(MambaBase, CustomOp):
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
hidden_states_BC, gate = projected_states.chunk(2, dim=-2)
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
conv_weights = self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
)
if attn_metadata is None:
# V1 profile run
@@ -281,10 +308,12 @@ class MambaMixer(MambaBase, CustomOp):
conv_states=conv_state,
has_initial_state=has_initial_states_p,
cache_indices=state_indices_tensor_p,
query_start_loc=query_start_loc_p)
query_start_loc=query_start_loc_p,
)
# 3. State Space Model sequence transformations.
discrete_time_step_p, B_p, C_p = self._ssm_transform(
conv_out_p.transpose(-2, -1))
conv_out_p.transpose(-2, -1)
)
time_proj_bias = self._time_proj_bias()
# 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
@@ -301,7 +330,8 @@ class MambaMixer(MambaBase, CustomOp):
delta_softplus=True,
cache_indices=state_indices_tensor_p,
has_initial_state=has_initial_states_p,
query_start_loc=query_start_loc_p)
query_start_loc=query_start_loc_p,
)
ssm_outputs.append(scan_out_p)
if has_decode:
@@ -312,39 +342,42 @@ class MambaMixer(MambaBase, CustomOp):
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=state_indices_tensor_d).transpose(0, 1)
conv_state_indices=state_indices_tensor_d,
).transpose(0, 1)
# 3. State Space Model sequence transformation.
discrete_time_step_d, B_d, C_d = self._ssm_transform(
conv_out_d.transpose(-2, -1))
conv_out_d.transpose(-2, -1)
)
time_proj_bias = self._time_proj_bias()
# 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
scan_outputs_d = torch.empty_like(
hidden_states_BC_d.transpose(0, 1))
selective_state_update(ssm_state,
conv_out_d.transpose(0, 1),
discrete_time_step_d.transpose(0, 1),
self.A,
B_d,
C_d,
self.D,
gate_d.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
state_batch_indices=state_indices_tensor_d,
out=scan_outputs_d)
scan_outputs_d = torch.empty_like(hidden_states_BC_d.transpose(0, 1))
selective_state_update(
ssm_state,
conv_out_d.transpose(0, 1),
discrete_time_step_d.transpose(0, 1),
self.A,
B_d,
C_d,
self.D,
gate_d.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
state_batch_indices=state_indices_tensor_d,
out=scan_outputs_d,
)
scan_outputs_d = scan_outputs_d.transpose(0, 1)
ssm_outputs.insert(0, scan_outputs_d)
scan_outputs_combined = ssm_outputs[0] if len(
ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1)
scan_outputs_combined = (
ssm_outputs[0] if len(ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1)
)
# 5. Final output projection
if self.is_lora_enabled: # Lora kernel requires contiguous tensor.
scan_outputs_combined = scan_outputs_combined.transpose(
-2, -1).contiguous()
scan_outputs_combined = scan_outputs_combined.transpose(-2, -1).contiguous()
out = self.out_proj(scan_outputs_combined)[0]
else:
out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0]
@@ -373,8 +406,8 @@ class MambaMixer(MambaBase, CustomOp):
return "mamba1"
def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.mamba1_attn import (
Mamba1AttentionBackend)
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionBackend
return Mamba1AttentionBackend
def _time_proj_bias(self) -> Optional[torch.Tensor]:
@@ -406,27 +439,34 @@ def split_batch_to_prefill_and_decode(
num_decodes: int,
num_padded_decodes: int,
) -> PrefillDecodeSplit:
num_actual_tokens = num_prefill_tokens + num_padded_decodes
# In v1, decode tokens come first, then prefill tokens.
hidden_states_BC_d, hidden_states_BC_p = torch.split(
hidden_states_BC[..., :num_actual_tokens],
[num_padded_decodes, num_prefill_tokens],
dim=-1)
gate_d, gate_p = torch.split(gate[..., :num_actual_tokens],
[num_padded_decodes, num_prefill_tokens],
dim=-1)
dim=-1,
)
gate_d, gate_p = torch.split(
gate[..., :num_actual_tokens], [num_padded_decodes, num_prefill_tokens], dim=-1
)
# num_padded_decodes accounts for CUDA graph padding when applicable
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor[:num_padded_decodes + num_prefills],
state_indices_tensor[: num_padded_decodes + num_prefills],
[num_padded_decodes, num_prefills],
dim=0)
query_start_loc_p = (query_start_loc[-num_prefills - 1:] -
num_padded_decodes if num_prefills > 0 else None)
has_initial_states_p = has_initial_states[-num_prefills:] if (
has_initial_states is not None and num_prefills > 0) else None
dim=0,
)
query_start_loc_p = (
query_start_loc[-num_prefills - 1 :] - num_padded_decodes
if num_prefills > 0
else None
)
has_initial_states_p = (
has_initial_states[-num_prefills:]
if (has_initial_states is not None and num_prefills > 0)
else None
)
return PrefillDecodeSplit(
hidden_states_BC_p=hidden_states_BC_p,