Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
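
This commit replaces yapf and the standalone isort hook with ruff, which covers both jobs in one tool: `ruff format` provides Black-style code formatting, and ruff's `I` lint rules provide isort-compatible import sorting. The diff below shows the resulting mechanical reformatting of one file: yapf's align-under-the-opening-parenthesis continuation style becomes a hanging indent with a magic trailing comma, short parenthesized imports collapse onto a single line, and slice colons with complex bounds gain surrounding spaces.

As a quick illustration of the formatting change, a minimal sketch with a hypothetical function (not code from this commit):

# Before: yapf aligns wrapped parameters under the opening parenthesis.
def make_layer(config,
               dim: int,
               layer_idx: int,
               prefix: str = ""):
    return (config, dim, layer_idx, prefix)

# After: ruff format breaks after "(", indents one level, adds a magic
# trailing comma, and closes the parenthesis on its own line.
def make_layer(
    config,
    dim: int,
    layer_idx: int,
    prefix: str = "",
):
    return (config, dim, layer_idx, prefix)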
@@ -13,29 +13,35 @@ from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               MergedColumnParallelLinear,
-                                               RowParallelLinear)
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    MergedColumnParallelLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.mamba.mamba_utils import (
-    MambaStateDtypeCalculator, MambaStateShapeCalculator)
+    MambaStateDtypeCalculator,
+    MambaStateShapeCalculator,
+)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
-    causal_conv1d_fn, causal_conv1d_update)
+    causal_conv1d_fn,
+    causal_conv1d_update,
+)
 from vllm.utils import direct_register_custom_op
-from vllm.v1.attention.backends.short_conv_attn import (
-    ShortConvAttentionMetadata)
+from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionMetadata


 @CustomOp.register("short_conv")
 class ShortConv(MambaBase, CustomOp):
-
-    def __init__(self,
-                 config,
-                 dim: int,
-                 layer_idx: int,
-                 model_config: Optional[ModelConfig] = None,
-                 cache_config: Optional[CacheConfig] = None,
-                 prefix: str = ""):
+    def __init__(
+        self,
+        config,
+        dim: int,
+        layer_idx: int,
+        model_config: Optional[ModelConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+        prefix: str = "",
+    ):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
@@ -72,7 +78,7 @@ class ShortConv(MambaBase, CustomOp):
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self
-        self.kv_cache = (torch.tensor([]), )
+        self.kv_cache = (torch.tensor([]),)

         self.model_config = model_config
         self.cache_config = cache_config
@@ -121,8 +127,9 @@ class ShortConv(MambaBase, CustomOp):

         B, C, x = BCx.chunk(3, dim=-1)

-        conv_weights = self.conv.weight.view(self.conv.weight.size(0),
-                                             self.conv.weight.size(2))
+        conv_weights = self.conv.weight.view(
+            self.conv.weight.size(0), self.conv.weight.size(2)
+        )

         if attn_metadata is None:
             # V1 profile run
@@ -163,23 +170,26 @@ class ShortConv(MambaBase, CustomOp):
             dim=0,
         )
         query_start_loc_p = (
-            attn_metadata.query_start_loc[-num_prefills - 1:] -
-            num_decodes if has_prefill else None)
+            attn_metadata.query_start_loc[-num_prefills - 1 :] - num_decodes
+            if has_prefill
+            else None
+        )

         conv_output_list = []

         if has_prefill:
             Bx_p = (B_p * x_p).transpose(0, 1)
-            Bx = causal_conv1d_fn(Bx_p,
-                                  conv_weights,
-                                  self.conv.bias,
-                                  activation=None,
-                                  conv_states=conv_state,
-                                  has_initial_state=has_initial_states_p,
-                                  cache_indices=state_indices_tensor_p,
-                                  metadata=attn_metadata,
-                                  query_start_loc=query_start_loc_p).transpose(
-                                      0, 1)[:num_prefill_tokens]
+            Bx = causal_conv1d_fn(
+                Bx_p,
+                conv_weights,
+                self.conv.bias,
+                activation=None,
+                conv_states=conv_state,
+                has_initial_state=has_initial_states_p,
+                cache_indices=state_indices_tensor_p,
+                metadata=attn_metadata,
+                query_start_loc=query_start_loc_p,
+            ).transpose(0, 1)[:num_prefill_tokens]

             y = C_p * Bx
             conv_output_list.append(y)
@@ -192,7 +202,8 @@ class ShortConv(MambaBase, CustomOp):
                 conv_weights,
                 self.conv.bias,
                 activation=None,
-                conv_state_indices=state_indices_tensor_d)
+                conv_state_indices=state_indices_tensor_d,
+            )
             y = C_d * Bx
             conv_output_list.insert(0, y)

@@ -222,8 +233,8 @@ class ShortConv(MambaBase, CustomOp):
         return "short_conv"

     def get_attn_backend(self) -> type["AttentionBackend"]:
-        from vllm.v1.attention.backends.short_conv_attn import (
-            ShortConvAttentionBackend)
+        from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionBackend

         return ShortConvAttentionBackend
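
One detail worth calling out in the hunks above: ruff format follows PEP 8's slice rule, so a slice colon whose bound is a complex expression gets spaces around it, which is why `[-num_prefills - 1:]` became `[-num_prefills - 1 :]`. A minimal sketch with made-up values:

# Both slices are equivalent; only the formatting differs.
xs = list(range(10))
n = 3
print(xs[-n - 1:])   # yapf left the slice colon unspaced
print(xs[-n - 1 :])  # ruff format spaces it; both print [6, 7, 8, 9]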