Implements dual-chunk-flash-attn backend for dual chunk attention with sparse attention support (#11844)
This commit is contained in:
@@ -150,6 +150,101 @@ def merge_attn_states(output: torch.Tensor,
|
||||
prefix_lse, suffix_output, suffix_lse)
|
||||
|
||||
|
||||
def convert_vertical_slash_indexes(
    q_seqlens: torch.Tensor,  # [BATCH, ]
    kv_seqlens: torch.Tensor,  # [BATCH, ]
    vertical_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_V]
    slash_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_S]
    context_size: int,
    block_size_M: int,
    block_size_N: int,
    causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Convert per-head vertical/slash sparse indices into block-sparse
    layout tensors by invoking the ``_C.convert_vertical_slash_indexes``
    custom kernel.

    Returns:
        Tuple of ``(block_count, block_offset, column_count, column_index)``
        allocated on the same device/dtype as ``q_seqlens`` and filled
        in place by the kernel.
    """
    batch_size, num_heads, nnz_slash = slash_indexes.shape
    nnz_vertical = vertical_indexes.size(-1)
    # Number of query-block rows of size block_size_M needed to cover
    # the context (ceiling division).
    num_rows = (context_size + block_size_M - 1) // block_size_M

    out_opts = {"dtype": q_seqlens.dtype, "device": q_seqlens.device}
    # Zero-initialized so untouched entries are well-defined counts/offsets.
    block_count = torch.zeros(batch_size, num_heads, num_rows, **out_opts)
    block_offset = torch.zeros(batch_size, num_heads, num_rows, nnz_slash,
                               **out_opts)
    column_count = torch.zeros(batch_size, num_heads, num_rows, **out_opts)
    column_index = torch.zeros(batch_size, num_heads, num_rows, nnz_vertical,
                               **out_opts)

    # The custom op writes its results into the four tensors above.
    torch.ops._C.convert_vertical_slash_indexes(
        block_count, block_offset, column_count, column_index, q_seqlens,
        kv_seqlens, vertical_indexes, slash_indexes, context_size,
        block_size_M, block_size_N, causal)
    return block_count, block_offset, column_count, column_index
|
||||
|
||||
|
||||
def convert_vertical_slash_indexes_mergehead(
    q_seqlens: torch.Tensor,  # [BATCH, ]
    kv_seqlens: torch.Tensor,  # [BATCH, ]
    vertical_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_V]
    slash_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_S]
    # [N_HEADS] : different head use different number of indices
    vertical_indices_count: torch.Tensor,
    slash_indices_count: torch.Tensor,
    context_size: int,
    block_size_M: int,
    block_size_N: int,
    causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Merge-head variant of :func:`convert_vertical_slash_indexes` where
    each head may use a different number of vertical/slash indices
    (given by ``vertical_indices_count`` / ``slash_indices_count``).

    Returns:
        Tuple of ``(block_count, block_offset, column_count, column_index)``
        filled in place by the ``_C.convert_vertical_slash_indexes_mergehead``
        custom kernel.
    """
    batch_size, num_heads, nnz_slash = slash_indexes.shape
    nnz_vertical = vertical_indexes.size(-1)
    # Ceiling division: rows of block_size_M query blocks over the context.
    num_rows = (context_size + block_size_M - 1) // block_size_M

    # NOTE(review): outputs use torch.empty here, while the non-mergehead
    # variant uses torch.zeros — presumably this kernel initializes every
    # entry it later reads; confirm against the kernel implementation.
    out_opts = {"dtype": q_seqlens.dtype, "device": q_seqlens.device}
    block_count = torch.empty(batch_size, num_heads, num_rows, **out_opts)
    block_offset = torch.empty(batch_size, num_heads, num_rows, nnz_slash,
                               **out_opts)
    column_count = torch.empty(batch_size, num_heads, num_rows, **out_opts)
    column_index = torch.empty(batch_size, num_heads, num_rows, nnz_vertical,
                               **out_opts)

    # The custom op writes its results into the four tensors above.
    torch.ops._C.convert_vertical_slash_indexes_mergehead(
        block_count, block_offset, column_count, column_index, q_seqlens,
        kv_seqlens, vertical_indexes, slash_indexes, vertical_indices_count,
        slash_indices_count, context_size, block_size_M, block_size_N, causal)
    return block_count, block_offset, column_count, column_index
|
||||
|
||||
|
||||
# pos encoding ops
|
||||
def rotary_embedding(
|
||||
positions: torch.Tensor,
|
||||
|
||||
1494
vllm/attention/backends/dual_chunk_flash_attn.py
Normal file
1494
vllm/attention/backends/dual_chunk_flash_attn.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -929,6 +929,23 @@ class ModelConfig:
|
||||
"Number of experts in the model must be greater than 0 "
|
||||
"when expert parallelism is enabled.")
|
||||
|
||||
def verify_dual_chunk_attention_config(
    self,
    load_config: "LoadConfig",
) -> None:
    """Attach the model's sparse attention config, if any, to the HF
    dual-chunk-attention config.

    No-op unless ``self.hf_config`` carries a
    ``dual_chunk_attention_config``. When a sparse attention config file
    is found, it is stored under the ``"sparse_attention_config"`` key and
    sparse attention is enabled by default (unless the flag was already
    set explicitly).
    """
    if not hasattr(self.hf_config, "dual_chunk_attention_config"):
        return
    # Try loading the sparse attention config
    from vllm.model_executor.model_loader.weight_utils import (
        get_sparse_attention_config)
    sparse_attn_config = get_sparse_attention_config(self, load_config)
    if not sparse_attn_config:
        return
    dca_config = self.hf_config.dual_chunk_attention_config
    dca_config["sparse_attention_config"] = sparse_attn_config
    # Default-enable sparse attention when a config is present; an
    # explicit user-provided value is left untouched.
    dca_config.setdefault("sparse_attention_enabled", True)
|
||||
|
||||
def verify_async_output_proc(self, parallel_config, speculative_config,
|
||||
device_config) -> None:
|
||||
if not self.use_async_output_proc:
|
||||
@@ -4187,6 +4204,8 @@ class VllmConfig:
|
||||
self.speculative_config,
|
||||
self.device_config)
|
||||
self.model_config.verify_with_parallel_config(self.parallel_config)
|
||||
self.model_config.verify_dual_chunk_attention_config(
|
||||
self.load_config)
|
||||
|
||||
if self.cache_config is not None:
|
||||
self.cache_config.verify_with_parallel_config(self.parallel_config)
|
||||
|
||||
@@ -37,8 +37,8 @@ from vllm.reasoning import ReasoningParserManager
|
||||
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
|
||||
from vllm.transformers_utils.utils import check_gguf_file
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import (FlexibleArgumentParser, GiB_bytes, is_in_doc_build,
|
||||
is_in_ray_actor)
|
||||
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
|
||||
GiB_bytes, is_in_doc_build, is_in_ray_actor)
|
||||
|
||||
# yapf: enable
|
||||
|
||||
@@ -983,6 +983,17 @@ class EngineArgs:
|
||||
|
||||
assert self.enable_chunked_prefill is not None
|
||||
|
||||
if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]:
|
||||
assert self.enforce_eager, (
|
||||
"Cuda graph is not supported with DualChunkFlashAttention. "
|
||||
"To run the model in eager mode, set 'enforce_eager=True' "
|
||||
"or use '--enforce-eager' in the CLI.")
|
||||
assert current_platform.is_cuda(), (
|
||||
"DualChunkFlashAttention is only supported on CUDA platform.")
|
||||
assert not use_v1, (
|
||||
"DualChunkFlashAttention is not supported on V1 engine. "
|
||||
"To run the model in V0 engine, try set 'VLLM_USE_V1=0'")
|
||||
|
||||
cache_config = CacheConfig(
|
||||
block_size=self.block_size,
|
||||
gpu_memory_utilization=self.gpu_memory_utilization,
|
||||
|
||||
@@ -1486,6 +1486,184 @@ class MRotaryEmbedding(RotaryEmbedding):
|
||||
return updates
|
||||
|
||||
|
||||
@CustomOp.register("dual_chunk_rotary_embedding")
class DualChunkRotaryEmbedding(CustomOp):
    """Rotary positional embedding for Dual Chunk Attention.

    Precomputes five cos/sin caches over positions within a chunk:

    * ``q``: intra-chunk query positions ``[0, chunk_len)``.
    * ``qc``: successor-chunk query positions, shifted by ``chunk_len``
      and clamped to ``chunk_size``.
    * ``k``: key positions wrapped modulo ``chunk_len``.
    * ``qc_no_clamp``: like ``qc`` but without the clamp.
    * ``q_inter``: query positions shifted by ``chunk_size``.

    where ``chunk_len = chunk_size - local_size``. ``forward`` applies
    these to produce one key tensor and five rotated query variants
    concatenated along the last dim.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
        chunk_size: int,
        local_size: int,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.chunk_size = chunk_size
        self.local_size = local_size
        self.dtype = dtype
        # Caches are built directly on the current CUDA device (CUDA-only).
        self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
        (q_cache, qc_cache, k_cache, qc_no_clamp_cache,
         q_inter_cache) = self._compute_cos_sin_cache()

        # Non-persistent buffers: moved with the module but excluded from
        # the state dict (they are cheap to recompute).
        self.register_buffer("cos_sin_q_cache", q_cache, persistent=False)
        self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False)
        self.register_buffer("cos_sin_k_cache", k_cache, persistent=False)
        self.register_buffer("cos_sin_qc_no_clamp_cache",
                             qc_no_clamp_cache,
                             persistent=False)
        self.register_buffer("cos_sin_q_inter_cache",
                             q_inter_cache,
                             persistent=False)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Compute the inverse frequency."""
        # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
        # However, we use `torch.arange(..., dtype=torch.float)` instead to
        # avoid numerical issues with large base values (e.g., 10000000).
        # This may cause a slight numerical difference between the HF
        # implementation and ours.
        # NOTE(woosuk): To exactly match the HF implementation, we need to
        # use CPU to compute the cache and then move it to GPU. However, we
        # create the cache on GPU for faster initialization. This may cause
        # a slight numerical difference between the HF implementation and ours.
        inv_freq = 1.0 / (base**(torch.arange(
            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
        return inv_freq

    def _compute_cos_sin_cache(
        self
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor]:
        """Compute the cos and sin cache.

        Returns five ``[positions, rotary_dim]`` tensors (cos and sin
        halves concatenated on the last dim), in the order:
        ``(q, qc, k, qc_no_clamp, q_inter)``.
        """
        inv_freq = self._compute_inv_freq(self.base)
        chunk_len = self.chunk_size - self.local_size
        # Intra-chunk positions [0, chunk_len).
        q_t = torch.arange(chunk_len, dtype=torch.float)
        # Successor-chunk positions, clamped so they never exceed chunk_size.
        qc_t = (torch.arange(chunk_len, dtype=torch.float) +
                chunk_len).clamp(max=self.chunk_size)
        # Key positions wrap around within a chunk.
        k_t = torch.arange(self.max_position_embeddings,
                           dtype=torch.float) % chunk_len

        # count from chunk_len, no clamp(self.chunk_size) restriction
        qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len
        # count from self.chunk_size for q_inter's rope
        q_inter_t = torch.arange(chunk_len,
                                 dtype=torch.float) + self.chunk_size

        # Outer product position x inv_freq -> per-position angles.
        q_freqs = torch.outer(q_t, inv_freq)
        qc_freqs = torch.outer(qc_t, inv_freq)
        k_freqs = torch.outer(k_t, inv_freq)
        qc_no_clamp_freqs = torch.outer(qc_no_clamp_t, inv_freq)
        q_inter_freqs = torch.outer(q_inter_t, inv_freq)

        q_cos = q_freqs.cos()
        q_sin = q_freqs.sin()
        qc_cos = qc_freqs.cos()
        qc_sin = qc_freqs.sin()
        k_cos = k_freqs.cos()
        k_sin = k_freqs.sin()

        qc_no_clamp_cos = qc_no_clamp_freqs.cos()
        qc_no_clamp_sin = qc_no_clamp_freqs.sin()
        q_inter_cos = q_inter_freqs.cos()
        q_inter_sin = q_inter_freqs.sin()

        # Concatenate cos and sin halves; cast/move to the target
        # dtype/device once at the end.
        q_cache = torch.cat((q_cos, q_sin), dim=-1).to(dtype=self.dtype,
                                                       device=self.device)
        qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to(dtype=self.dtype,
                                                          device=self.device)
        k_cache = torch.cat((k_cos, k_sin), dim=-1).to(dtype=self.dtype,
                                                       device=self.device)
        qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin),
                                      dim=-1).to(dtype=self.dtype,
                                                 device=self.device)
        q_inter_cache = torch.cat((q_inter_cos, q_inter_sin),
                                  dim=-1).to(dtype=self.dtype,
                                             device=self.device)
        return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply dual-chunk rotary embedding.

        Returns ``(query, key)`` where ``query`` is the concatenation of
        five rotated variants (intra, succ, inter, succ_critical,
        inter_critical) along the last dim — so its last dim is five
        times the input query's — and ``key`` is rotated with the
        modulo-``chunk_len`` key cache.
        """
        # Split the head dim into (num_heads, head_size) so rotation can
        # act on the first rotary_dim elements of each head.
        query = query.view(*query.shape[:-1], -1, self.head_size)
        key = key.view(*key.shape[:-1], -1, self.head_size)
        query_rot = query[..., :self.rotary_dim]
        key_rot = key[..., :self.rotary_dim]
        if self.rotary_dim < self.head_size:
            # Partial rotary: the tail of each head passes through unrotated.
            query_pass = query[..., self.rotary_dim:]
            key_pass = key[..., self.rotary_dim:]
        else:
            query_pass = None
            key_pass = None

        positions_with_offsets = (torch.add(positions, offsets)
                                  if offsets is not None else positions)
        key = self._apply_rotary_embedding(
            self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass)
        chunk_len = self.chunk_size - self.local_size
        query = self._apply_rotary_embedding(
            self.cos_sin_q_cache[positions_with_offsets % chunk_len],
            query_rot, query_pass)
        query_succ = self._apply_rotary_embedding(
            self.cos_sin_qc_cache[positions_with_offsets % chunk_len],
            query_rot, query_pass)
        # Inter-chunk queries all use the last (clamped) qc cache entry,
        # broadcast across every position.
        query_inter = self._apply_rotary_embedding(
            self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1),
            query_rot, query_pass)
        query_succ_critical = self._apply_rotary_embedding(
            self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len],
            query_rot, query_pass)
        query_inter_critical = self._apply_rotary_embedding(
            self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len],
            query_rot, query_pass)

        # merge query into one tensor to simplify the interfaces
        query = torch.cat((
            query,
            query_succ,
            query_inter,
            query_succ_critical,
            query_inter_critical,
        ),
                          dim=-1)
        return query, key

    def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
        """Rotate ``hidden_rot`` with the given cos/sin slice and rejoin
        the pass-through part; result is flattened back to a 2-D layout.

        ``cos_sin`` holds cos in the first half of the last dim and sin
        in the second half (as built by _compute_cos_sin_cache).
        """
        cos, sin = cos_sin.chunk(2, dim=-1)
        if self.is_neox_style:
            # NOTE(woosuk): Here we assume that the positions tensor has the
            # shape [batch_size, seq_len].
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
        hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin

        if self.rotary_dim < self.head_size:
            hidden = torch.cat((hidden_rot, hidden_pass), dim=-1)
        else:
            hidden = hidden_rot
        # Flatten (num_heads, head_size) back into one dim; squeeze(0)
        # drops a leading singleton batch dim if present.
        return hidden.flatten(-2).squeeze(0)

    def extra_repr(self) -> str:
        """Summary string shown in the module's repr."""
        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
        s += f", max_position_embeddings={self.max_position_embeddings}"
        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
        s += f", chunk_size={self.chunk_size}, local_size={self.local_size}"
        return s
|
||||
|
||||
|
||||
# Process-wide cache of constructed rotary-embedding modules, keyed by the
# configuration tuple built in get_rope(); lets identical layers share one
# module instead of recomputing cos/sin caches.
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
|
||||
|
||||
|
||||
@@ -1498,6 +1676,7 @@ def get_rope(
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
partial_rotary_factor: float = 1.0,
|
||||
dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
|
||||
) -> RotaryEmbedding:
|
||||
if dtype is None:
|
||||
dtype = torch.get_default_dtype()
|
||||
@@ -1510,14 +1689,35 @@ def get_rope(
|
||||
rope_scaling_args = tuple(rope_scaling_tuple.items())
|
||||
else:
|
||||
rope_scaling_args = None
|
||||
|
||||
if dual_chunk_attention_config is not None:
|
||||
dual_chunk_attention_tuple = {
|
||||
k: tuple(v) if isinstance(v, list) else v
|
||||
for k, v in dual_chunk_attention_config.items()
|
||||
if k != "sparse_attention_config"
|
||||
}
|
||||
dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
|
||||
else:
|
||||
dual_chunk_attention_args = None
|
||||
|
||||
if partial_rotary_factor < 1.0:
|
||||
rotary_dim = int(rotary_dim * partial_rotary_factor)
|
||||
key = (head_size, rotary_dim, max_position, base, is_neox_style,
|
||||
rope_scaling_args, dtype)
|
||||
rope_scaling_args, dual_chunk_attention_args, dtype)
|
||||
if key in _ROPE_DICT:
|
||||
return _ROPE_DICT[key]
|
||||
|
||||
if not rope_scaling:
|
||||
if dual_chunk_attention_config is not None:
|
||||
extra_kwargs = {
|
||||
k: v
|
||||
for k, v in dual_chunk_attention_config.items()
|
||||
if k in ("chunk_size", "local_size")
|
||||
}
|
||||
rotary_emb = DualChunkRotaryEmbedding(head_size, rotary_dim,
|
||||
max_position, base,
|
||||
is_neox_style, dtype,
|
||||
**extra_kwargs)
|
||||
elif not rope_scaling:
|
||||
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
|
||||
is_neox_style, dtype)
|
||||
else:
|
||||
|
||||
@@ -217,6 +217,39 @@ def get_quant_config(model_config: ModelConfig,
|
||||
return quant_cls.from_config(config)
|
||||
|
||||
|
||||
def get_sparse_attention_config(
    model_config: ModelConfig,
    load_config: LoadConfig,
    sparse_attention_config_filename: str = "sparse_attention_config.json",
) -> Dict[str, Any]:
    """Load the model's sparse attention config JSON, if present.

    For remote models, the JSON config files are first fetched from the
    Hub (honoring offline mode and the configured download dir). Returns
    the parsed config dict, or an empty dict when the model ships no
    ``sparse_attention_config.json``.
    """
    model_name_or_path = model_config.model
    if os.path.isdir(model_name_or_path):
        hf_folder = model_name_or_path
    else:
        # Download the config files.
        with get_lock(model_name_or_path, load_config.download_dir):
            hf_folder = snapshot_download(
                model_name_or_path,
                revision=model_config.revision,
                allow_patterns="*.json",
                cache_dir=load_config.download_dir,
                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                tqdm_class=DisabledTqdm,
            )

    config_file = os.path.join(hf_folder, sparse_attention_config_filename)
    if not os.path.exists(config_file):
        # Model does not provide a sparse attention config.
        return {}

    # Load the sparse attention config.
    with open(config_file) as f:
        sparse_config = json.load(f)
    logger.info("Loaded sparse attention config from %s", config_file)

    return sparse_config
|
||||
|
||||
|
||||
def download_weights_from_hf(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str],
|
||||
|
||||
@@ -23,7 +23,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
|
||||
from typing import Iterable, Optional, Set, Tuple, Union
|
||||
from typing import Any, Iterable, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -53,7 +53,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
|
||||
is_pp_missing_parameter,
|
||||
extract_layer_index, is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
@@ -99,17 +99,20 @@ class Qwen2MLP(nn.Module):
|
||||
|
||||
class Qwen2Attention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
max_position: int = 4096 * 32,
|
||||
rope_theta: float = 10000,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
rope_scaling: Optional[Tuple] = None,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
max_position: int = 4096 * 32,
|
||||
rope_theta: float = 10000,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
rope_scaling: Optional[Tuple] = None,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
dual_chunk_attention_config: Optional[dict[str,
|
||||
Any]] = None) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
@@ -131,6 +134,7 @@ class Qwen2Attention(nn.Module):
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.dual_chunk_attention_config = dual_chunk_attention_config
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
@@ -155,15 +159,21 @@ class Qwen2Attention(nn.Module):
|
||||
max_position=max_position,
|
||||
base=self.rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
dual_chunk_attention_config=dual_chunk_attention_config,
|
||||
)
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
attn_type=attn_type)
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
attn_type=attn_type,
|
||||
prefix=f"{prefix}.attn",
|
||||
**{
|
||||
"layer_idx": extract_layer_index(prefix),
|
||||
"dual_chunk_attention_config": dual_chunk_attention_config,
|
||||
} if dual_chunk_attention_config else {})
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -192,6 +202,9 @@ class Qwen2DecoderLayer(nn.Module):
|
||||
# Requires transformers > 4.32.0
|
||||
rope_theta = getattr(config, "rope_theta", 1000000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
dual_chunk_attention_config = getattr(config,
|
||||
"dual_chunk_attention_config",
|
||||
None)
|
||||
|
||||
# By default, Qwen2 uses causal attention as it is a decoder-only model.
|
||||
# You can override the HF config with `is_causal=False` to enable
|
||||
@@ -213,6 +226,7 @@ class Qwen2DecoderLayer(nn.Module):
|
||||
rope_scaling=rope_scaling,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
attn_type=attn_type,
|
||||
dual_chunk_attention_config=dual_chunk_attention_config,
|
||||
)
|
||||
self.mlp = Qwen2MLP(
|
||||
hidden_size=self.hidden_size,
|
||||
|
||||
@@ -175,6 +175,7 @@ class Qwen2MoeAttention(nn.Module):
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@@ -198,6 +199,7 @@ class Qwen2MoeAttention(nn.Module):
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.dual_chunk_attention_config = dual_chunk_attention_config
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
@@ -221,14 +223,20 @@ class Qwen2MoeAttention(nn.Module):
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
dual_chunk_attention_config=dual_chunk_attention_config,
|
||||
)
|
||||
self.attn = Attention(self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn")
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
**{
|
||||
"layer_idx": extract_layer_index(prefix),
|
||||
"dual_chunk_attention_config": dual_chunk_attention_config,
|
||||
} if dual_chunk_attention_config else {})
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -256,6 +264,9 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
rope_scaling = getattr(config, "rope_scaling", None)
|
||||
dual_chunk_attention_config = getattr(config,
|
||||
"dual_chunk_attention_config",
|
||||
None)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings",
|
||||
8192)
|
||||
self.self_attn = Qwen2MoeAttention(
|
||||
@@ -268,6 +279,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
dual_chunk_attention_config=dual_chunk_attention_config,
|
||||
)
|
||||
|
||||
# Note: Qwen/Qwen2-57B-A14B-Instruct does not have
|
||||
|
||||
@@ -222,6 +222,10 @@ class CudaPlatformBase(Platform):
|
||||
elif selected_backend == _Backend.XFORMERS:
|
||||
logger.info("Using XFormers backend.")
|
||||
return "vllm.attention.backends.xformers.XFormersBackend"
|
||||
elif selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN:
|
||||
logger.info("Using DualChunkFlashAttention backend.")
|
||||
return ("vllm.attention.backends.dual_chunk_flash_attn."
|
||||
"DualChunkFlashAttentionBackend")
|
||||
elif selected_backend == _Backend.FLASH_ATTN:
|
||||
pass
|
||||
elif selected_backend:
|
||||
|
||||
@@ -51,6 +51,7 @@ class _Backend(enum.Enum):
|
||||
PALLAS_VLLM_V1 = enum.auto()
|
||||
IPEX = enum.auto()
|
||||
BLOCK_SPARSE_FLASH_ATTN = enum.auto()
|
||||
DUAL_CHUNK_FLASH_ATTN = enum.auto()
|
||||
NO_ATTENTION = enum.auto()
|
||||
|
||||
|
||||
|
||||
@@ -153,6 +153,7 @@ STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
|
||||
STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
|
||||
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
|
||||
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
|
||||
STR_DUAL_CHUNK_FLASH_ATTN_VAL: str = "DUAL_CHUNK_FLASH_ATTN"
|
||||
STR_INVALID_VAL: str = "INVALID"
|
||||
|
||||
GB_bytes = 1_000_000_000
|
||||
|
||||
@@ -204,6 +204,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
self.mrope_input_positions = None # type: ignore
|
||||
self.seq_lens[0] = 0 # type: ignore
|
||||
self.orig_seq_lens[0] = 0 # type: ignore
|
||||
self.prompt_lens[0] = 0 # type: ignore
|
||||
self.query_lens[0] = 0 # type: ignore
|
||||
self.context_lens[0] = 0 # type: ignore
|
||||
self.curr_sliding_window_blocks[0] = 0 # type: ignore
|
||||
@@ -236,6 +237,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
# The original sequence length (before applying sliding window).
|
||||
# This is used to compute slot mapping.
|
||||
orig_seq_lens: Optional[List[int]] = None,
|
||||
# This is used in the dual-chunk flash attention backend.
|
||||
prompt_lens: Optional[List[int]] = None,
|
||||
# The query length.
|
||||
query_lens: Optional[List[int]] = None,
|
||||
# The number of tokens that are already computed.
|
||||
@@ -316,6 +319,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
for seq_id in range(len(self.seq_ids)):
|
||||
self.orig_seq_lens[seq_id] = 0
|
||||
|
||||
if prompt_lens:
|
||||
self.prompt_lens = prompt_lens
|
||||
else:
|
||||
for seq_id in range(len(self.seq_ids)):
|
||||
self.prompt_lens[seq_id] = 0
|
||||
|
||||
if query_lens:
|
||||
self.query_lens = query_lens
|
||||
else:
|
||||
@@ -370,6 +379,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
self.mrope_input_positions = mrope_input_positions or None
|
||||
self.seq_lens = seq_lens or []
|
||||
self.orig_seq_lens = orig_seq_lens or []
|
||||
self.prompt_lens = prompt_lens or []
|
||||
self.query_lens = query_lens or []
|
||||
self.context_lens = context_lens or []
|
||||
self.curr_sliding_window_blocks = \
|
||||
@@ -403,6 +413,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
self.mrope_input_positions = None
|
||||
self.seq_lens = [0] * self.n_seqs
|
||||
self.orig_seq_lens = [0] * self.n_seqs
|
||||
self.prompt_lens = [0] * self.n_seqs
|
||||
self.query_lens = [0] * self.n_seqs
|
||||
self.context_lens = [0] * self.n_seqs
|
||||
self.curr_sliding_window_blocks = [0] * self.n_seqs
|
||||
@@ -552,6 +563,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
|
||||
|
||||
inter_data.seq_lens[seq_idx] = seq_len
|
||||
inter_data.orig_seq_lens[seq_idx] = seq_len
|
||||
inter_data.prompt_lens[seq_idx] = seq_data.get_prompt_len()
|
||||
inter_data.context_lens[seq_idx] = context_len
|
||||
inter_data.input_tokens[seq_idx].extend(tokens)
|
||||
inter_data.inputs_embeds = prompt_embeds
|
||||
|
||||
Reference in New Issue
Block a user