Implements dual-chunk-flash-attn backend for dual chunk attention with sparse attention support (#11844)

Tao He
2025-05-13 10:52:47 +08:00
committed by GitHub
parent f6518b2b48
commit 60f7624334
17 changed files with 2444 additions and 32 deletions

View File

@@ -150,6 +150,101 @@ def merge_attn_states(output: torch.Tensor,
prefix_lse, suffix_output, suffix_lse)
def convert_vertical_slash_indexes(
q_seqlens: torch.Tensor, # [BATCH, ]
kv_seqlens: torch.Tensor, # [BATCH, ]
vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V]
slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S]
context_size: int,
block_size_M: int,
block_size_N: int,
causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size = slash_indexes.size(0)
num_heads = slash_indexes.size(1)
nnz_slash = slash_indexes.size(2)
nnz_vertical = vertical_indexes.size(2)
num_rows = (context_size + block_size_M - 1) // block_size_M
block_count = torch.zeros(batch_size,
num_heads,
num_rows,
dtype=q_seqlens.dtype,
device=q_seqlens.device)
block_offset = torch.zeros(batch_size,
num_heads,
num_rows,
nnz_slash,
dtype=q_seqlens.dtype,
device=q_seqlens.device)
column_count = torch.zeros(batch_size,
num_heads,
num_rows,
dtype=q_seqlens.dtype,
device=q_seqlens.device)
column_index = torch.zeros(batch_size,
num_heads,
num_rows,
nnz_vertical,
dtype=q_seqlens.dtype,
device=q_seqlens.device)
torch.ops._C.convert_vertical_slash_indexes(
block_count, block_offset, column_count, column_index, q_seqlens,
kv_seqlens, vertical_indexes, slash_indexes, context_size,
block_size_M, block_size_N, causal)
return block_count, block_offset, column_count, column_index
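
A rough shape sketch of the wrapper above, with toy sizes. It assumes the wrapper lives in vllm._custom_ops and that vLLM's compiled _C extension plus a CUDA device are available; the random index values are purely illustrative (real callers pass index tensors produced by the sparse-pattern search):

import torch
from vllm import _custom_ops as ops  # assumed location of the wrapper above

batch, heads, nnz_v, nnz_s = 1, 2, 8, 8
context_size, block_m, block_n = 256, 64, 64
q_seqlens = torch.tensor([context_size], dtype=torch.int32, device="cuda")
kv_seqlens = torch.tensor([context_size], dtype=torch.int32, device="cuda")
vertical = torch.randint(0, context_size, (batch, heads, nnz_v),
                         dtype=torch.int32, device="cuda")
slash = torch.randint(0, context_size, (batch, heads, nnz_s),
                      dtype=torch.int32, device="cuda")
block_count, block_offset, column_count, column_index = \
    ops.convert_vertical_slash_indexes(q_seqlens, kv_seqlens, vertical, slash,
                                       context_size, block_m, block_n)
num_rows = (context_size + block_m - 1) // block_m  # 4 row tiles of block_m queries
assert block_count.shape == (batch, heads, num_rows)
assert block_offset.shape == (batch, heads, num_rows, nnz_s)
assert column_count.shape == (batch, heads, num_rows)
assert column_index.shape == (batch, heads, num_rows, nnz_v)
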
def convert_vertical_slash_indexes_mergehead(
q_seqlens: torch.Tensor, # [BATCH, ]
kv_seqlens: torch.Tensor, # [BATCH, ]
vertical_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_V]
slash_indexes: torch.Tensor, # [BATCH, N_HEADS, NNZ_S]
# [N_HEADS]: each head may use a different number of indices
vertical_indices_count: torch.Tensor,
slash_indices_count: torch.Tensor,
context_size: int,
block_size_M: int,
block_size_N: int,
causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size = slash_indexes.size(0)
num_heads = slash_indexes.size(1)
nnz_slash = slash_indexes.size(2)
nnz_vertical = vertical_indexes.size(2)
num_rows = (context_size + block_size_M - 1) // block_size_M
block_count = torch.empty(batch_size,
num_heads,
num_rows,
dtype=q_seqlens.dtype,
device=q_seqlens.device)
block_offset = torch.empty(batch_size,
num_heads,
num_rows,
nnz_slash,
dtype=q_seqlens.dtype,
device=q_seqlens.device)
column_count = torch.empty(batch_size,
num_heads,
num_rows,
dtype=q_seqlens.dtype,
device=q_seqlens.device)
column_index = torch.empty(batch_size,
num_heads,
num_rows,
nnz_vertical,
dtype=q_seqlens.dtype,
device=q_seqlens.device)
torch.ops._C.convert_vertical_slash_indexes_mergehead(
block_count, block_offset, column_count, column_index, q_seqlens,
kv_seqlens, vertical_indexes, slash_indexes, vertical_indices_count,
slash_indices_count, context_size, block_size_M, block_size_N, causal)
return block_count, block_offset, column_count, column_index
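
The mergehead variant differs from the one above only in the two extra per-head count tensors, which let each attention head keep a different number of vertical and slash indexes. Continuing the toy setup from the previous sketch (values illustrative):

# e.g. head 0 keeps 4 vertical / 6 slash indexes, head 1 keeps all 8 of each
vertical_indices_count = torch.tensor([4, 8], dtype=torch.int32, device="cuda")
slash_indices_count = torch.tensor([6, 8], dtype=torch.int32, device="cuda")
outputs = ops.convert_vertical_slash_indexes_mergehead(
    q_seqlens, kv_seqlens, vertical, slash,
    vertical_indices_count, slash_indices_count,
    context_size, block_m, block_n)
# same four tensors and shapes as the non-mergehead variant
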
# pos encoding ops
def rotary_embedding(
positions: torch.Tensor,

File diff suppressed because it is too large.

View File

@@ -929,6 +929,23 @@ class ModelConfig:
"Number of experts in the model must be greater than 0 "
"when expert parallelism is enabled.")
def verify_dual_chunk_attention_config(
self,
load_config: "LoadConfig",
) -> None:
if hasattr(self.hf_config, "dual_chunk_attention_config"):
# Try loading the sparse attention config
from vllm.model_executor.model_loader.weight_utils import (
get_sparse_attention_config)
sparse_attn_config = get_sparse_attention_config(self, load_config)
if sparse_attn_config:
self.hf_config.dual_chunk_attention_config[
"sparse_attention_config"] = sparse_attn_config
if "sparse_attention_enabled" not in \
self.hf_config.dual_chunk_attention_config:
self.hf_config.dual_chunk_attention_config[
"sparse_attention_enabled"] = True
def verify_async_output_proc(self, parallel_config, speculative_config,
device_config) -> None:
if not self.use_async_output_proc:
@@ -4187,6 +4204,8 @@ class VllmConfig:
self.speculative_config,
self.device_config)
self.model_config.verify_with_parallel_config(self.parallel_config)
self.model_config.verify_dual_chunk_attention_config(
self.load_config)
if self.cache_config is not None:
self.cache_config.verify_with_parallel_config(self.parallel_config)

View File

@@ -37,8 +37,8 @@ from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
from vllm.transformers_utils.utils import check_gguf_file
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (FlexibleArgumentParser, GiB_bytes, is_in_doc_build,
is_in_ray_actor)
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
GiB_bytes, is_in_doc_build, is_in_ray_actor)
# yapf: enable
@@ -983,6 +983,17 @@ class EngineArgs:
assert self.enable_chunked_prefill is not None
if envs.VLLM_ATTENTION_BACKEND in [STR_DUAL_CHUNK_FLASH_ATTN_VAL]:
assert self.enforce_eager, (
"Cuda graph is not supported with DualChunkFlashAttention. "
"To run the model in eager mode, set 'enforce_eager=True' "
"or use '--enforce-eager' in the CLI.")
assert current_platform.is_cuda(), (
"DualChunkFlashAttention is only supported on CUDA platform.")
assert not use_v1, (
"DualChunkFlashAttention is not supported on V1 engine. "
"To run the model in V0 engine, try set 'VLLM_USE_V1=0'")
cache_config = CacheConfig(
block_size=self.block_size,
gpu_memory_utilization=self.gpu_memory_utilization,

View File

@@ -1486,6 +1486,184 @@ class MRotaryEmbedding(RotaryEmbedding):
return updates
@CustomOp.register("dual_chunk_rotary_embedding")
class DualChunkRotaryEmbedding(CustomOp):
"""Rotary positional embedding for Dual Chunk Attention."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
chunk_size: int,
local_size: int,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.chunk_size = chunk_size
self.local_size = local_size
self.dtype = dtype
self.device = torch.device(f"cuda:{torch.cuda.current_device()}")
(q_cache, qc_cache, k_cache, qc_no_clamp_cache,
q_inter_cache) = self._compute_cos_sin_cache()
self.register_buffer("cos_sin_q_cache", q_cache, persistent=False)
self.register_buffer("cos_sin_qc_cache", qc_cache, persistent=False)
self.register_buffer("cos_sin_k_cache", k_cache, persistent=False)
self.register_buffer("cos_sin_qc_no_clamp_cache",
qc_no_clamp_cache,
persistent=False)
self.register_buffer("cos_sin_q_inter_cache",
q_inter_cache,
persistent=False)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency."""
# NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
# However, we use `torch.arange(..., dtype=torch.float)` instead to
# avoid numerical issues with large base values (e.g., 10000000).
# This may cause a slight numerical difference between the HF
# implementation and ours.
# NOTE(woosuk): To exactly match the HF implementation, we need to
# use CPU to compute the cache and then move it to GPU. However, we
# create the cache on GPU for faster initialization. This may cause
# a slight numerical difference between the HF implementation and ours.
inv_freq = 1.0 / (base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
return inv_freq
def _compute_cos_sin_cache(self) -> Tuple[torch.Tensor, ...]:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
chunk_len = self.chunk_size - self.local_size
q_t = torch.arange(chunk_len, dtype=torch.float)
qc_t = (torch.arange(chunk_len, dtype=torch.float) +
chunk_len).clamp(max=self.chunk_size)
k_t = torch.arange(self.max_position_embeddings,
dtype=torch.float) % chunk_len
# count from chunk_len, no clamp(self.chunk_size) restriction
qc_no_clamp_t = torch.arange(chunk_len, dtype=torch.float) + chunk_len
# count from self.chunk_size for q_inter's rope
q_inter_t = torch.arange(chunk_len,
dtype=torch.float) + self.chunk_size
q_freqs = torch.outer(q_t, inv_freq)
qc_freqs = torch.outer(qc_t, inv_freq)
k_freqs = torch.outer(k_t, inv_freq)
qc_no_clamp_freqs = torch.outer(qc_no_clamp_t, inv_freq)
q_inter_freqs = torch.outer(q_inter_t, inv_freq)
q_cos = q_freqs.cos()
q_sin = q_freqs.sin()
qc_cos = qc_freqs.cos()
qc_sin = qc_freqs.sin()
k_cos = k_freqs.cos()
k_sin = k_freqs.sin()
qc_no_clamp_cos = qc_no_clamp_freqs.cos()
qc_no_clamp_sin = qc_no_clamp_freqs.sin()
q_inter_cos = q_inter_freqs.cos()
q_inter_sin = q_inter_freqs.sin()
q_cache = torch.cat((q_cos, q_sin), dim=-1).to(dtype=self.dtype,
device=self.device)
qc_cache = torch.cat((qc_cos, qc_sin), dim=-1).to(dtype=self.dtype,
device=self.device)
k_cache = torch.cat((k_cos, k_sin), dim=-1).to(dtype=self.dtype,
device=self.device)
qc_no_clamp_cache = torch.cat((qc_no_clamp_cos, qc_no_clamp_sin),
dim=-1).to(dtype=self.dtype,
device=self.device)
q_inter_cache = torch.cat((q_inter_cos, q_inter_sin),
dim=-1).to(dtype=self.dtype,
device=self.device)
return q_cache, qc_cache, k_cache, qc_no_clamp_cache, q_inter_cache
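
A tiny worked example of the five position schedules built above, using toy values chunk_size=8, local_size=2 and max_position_embeddings=12 (so chunk_len = 6):

import torch
chunk_size, local_size, max_pos = 8, 2, 12
chunk_len = chunk_size - local_size                                  # 6
q_t = torch.arange(chunk_len)                                        # [0, 1, 2, 3, 4, 5]
qc_t = (torch.arange(chunk_len) + chunk_len).clamp(max=chunk_size)   # [6, 7, 8, 8, 8, 8]
k_t = torch.arange(max_pos) % chunk_len                              # [0..5, 0..5]: keys wrap every chunk_len
qc_no_clamp_t = torch.arange(chunk_len) + chunk_len                  # [6, 7, 8, 9, 10, 11]
q_inter_t = torch.arange(chunk_len) + chunk_size                     # [8, 9, 10, 11, 12, 13]
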
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
key_rot = key[..., :self.rotary_dim]
if self.rotary_dim < self.head_size:
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]
else:
query_pass = None
key_pass = None
positions_with_offsets = (torch.add(positions, offsets)
if offsets is not None else positions)
key = self._apply_rotary_embedding(
self.cos_sin_k_cache[positions_with_offsets], key_rot, key_pass)
chunk_len = self.chunk_size - self.local_size
query = self._apply_rotary_embedding(
self.cos_sin_q_cache[positions_with_offsets % chunk_len],
query_rot, query_pass)
query_succ = self._apply_rotary_embedding(
self.cos_sin_qc_cache[positions_with_offsets % chunk_len],
query_rot, query_pass)
query_inter = self._apply_rotary_embedding(
self.cos_sin_qc_cache[chunk_len - 1].repeat(positions.shape[0], 1),
query_rot, query_pass)
query_succ_critical = self._apply_rotary_embedding(
self.cos_sin_qc_no_clamp_cache[positions_with_offsets % chunk_len],
query_rot, query_pass)
query_inter_critical = self._apply_rotary_embedding(
self.cos_sin_q_inter_cache[positions_with_offsets % chunk_len],
query_rot, query_pass)
# merge query into one tensor to simplify the interfaces
query = torch.cat((
query,
query_succ,
query_inter,
query_succ_critical,
query_inter_critical,
),
dim=-1)
return query, key
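
A shape-only sketch of what forward returns; dual_chunk_rope, positions, query and key are assumed to exist with the usual vLLM layouts:

out_q, out_k = dual_chunk_rope(positions, query, key)
# the intra-chunk, successive, inter-chunk and two "critical" query variants
# are packed along the last dimension, so the query grows by a factor of five;
# the dual-chunk attention backend is expected to split them apart again
assert out_q.shape[-1] == 5 * query.shape[-1]
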
def _apply_rotary_embedding(self, cos_sin, hidden_rot, hidden_pass):
cos, sin = cos_sin.chunk(2, dim=-1)
if self.is_neox_style:
# NOTE(woosuk): Here we assume that the positions tensor has the
# shape [batch_size, seq_len].
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
hidden_rot = hidden_rot * cos + rotate_fn(hidden_rot) * sin
if self.rotary_dim < self.head_size:
hidden = torch.cat((hidden_rot, hidden_pass), dim=-1)
else:
hidden = hidden_rot
return hidden.flatten(-2).squeeze(0)
def extra_repr(self) -> str:
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
s += f", max_position_embeddings={self.max_position_embeddings}"
s += f", base={self.base}, is_neox_style={self.is_neox_style}"
s += f", chunk_size={self.chunk_size}, local_size={self.local_size}"
return s
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
@@ -1498,6 +1676,7 @@ def get_rope(
rope_scaling: Optional[Dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None,
partial_rotary_factor: float = 1.0,
dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
) -> RotaryEmbedding:
if dtype is None:
dtype = torch.get_default_dtype()
@@ -1510,14 +1689,35 @@ def get_rope(
rope_scaling_args = tuple(rope_scaling_tuple.items())
else:
rope_scaling_args = None
if dual_chunk_attention_config is not None:
dual_chunk_attention_tuple = {
k: tuple(v) if isinstance(v, list) else v
for k, v in dual_chunk_attention_config.items()
if k != "sparse_attention_config"
}
dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
else:
dual_chunk_attention_args = None
if partial_rotary_factor < 1.0:
rotary_dim = int(rotary_dim * partial_rotary_factor)
key = (head_size, rotary_dim, max_position, base, is_neox_style,
rope_scaling_args, dtype)
rope_scaling_args, dual_chunk_attention_args, dtype)
if key in _ROPE_DICT:
return _ROPE_DICT[key]
if not rope_scaling:
if dual_chunk_attention_config is not None:
extra_kwargs = {
k: v
for k, v in dual_chunk_attention_config.items()
if k in ("chunk_size", "local_size")
}
rotary_emb = DualChunkRotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style, dtype,
**extra_kwargs)
elif not rope_scaling:
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, dtype)
else:

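A hedged sketch of how a model layer would request this rotary embedding through get_rope; all numeric values are illustrative:

from vllm.model_executor.layers.rotary_embedding import get_rope

rotary_emb = get_rope(
    head_size=128,
    rotary_dim=128,
    max_position=262144,          # illustrative
    base=10000000,                # illustrative
    is_neox_style=True,
    dual_chunk_attention_config={
        "chunk_size": 262144,     # illustrative
        "local_size": 8192,       # illustrative
    },
)
# -> a DualChunkRotaryEmbedding instance, since dual_chunk_attention_config is set
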
View File

@@ -217,6 +217,39 @@ def get_quant_config(model_config: ModelConfig,
return quant_cls.from_config(config)
def get_sparse_attention_config(
model_config: ModelConfig,
load_config: LoadConfig,
sparse_attention_config_filename: str = "sparse_attention_config.json",
) -> Dict[str, Any]:
model_name_or_path = model_config.model
is_local = os.path.isdir(model_name_or_path)
if not is_local:
# Download the config files.
with get_lock(model_name_or_path, load_config.download_dir):
hf_folder = snapshot_download(
model_name_or_path,
revision=model_config.revision,
allow_patterns="*.json",
cache_dir=load_config.download_dir,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
tqdm_class=DisabledTqdm,
)
else:
hf_folder = model_name_or_path
config_file = os.path.join(hf_folder, sparse_attention_config_filename)
if not os.path.exists(config_file):
return {}
# Load the sparse attention config.
with open(config_file) as f:
config = json.load(f)
logger.info("Loaded sparse attention config from %s", config_file)
return config
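
In effect, for a local checkpoint the helper boils down to the following lookup (paths and file contents are hypothetical):

import json
import os

model_dir = "/path/to/local/model"  # hypothetical
config_file = os.path.join(model_dir, "sparse_attention_config.json")
sparse_attn_config = {}
if os.path.exists(config_file):
    with open(config_file) as f:
        sparse_attn_config = json.load(f)  # whatever the checkpoint ships in that file
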
def download_weights_from_hf(
model_name_or_path: str,
cache_dir: Optional[str],

View File

@@ -23,7 +23,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Set, Tuple, Union
from typing import Any, Iterable, Optional, Set, Tuple, Union
import torch
from torch import nn
@@ -53,7 +53,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
is_pp_missing_parameter,
extract_layer_index, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
@@ -99,17 +99,20 @@ class Qwen2MLP(nn.Module):
class Qwen2Attention(nn.Module):
def __init__(self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[Tuple] = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER) -> None:
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
rope_scaling: Optional[Tuple] = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
dual_chunk_attention_config: Optional[dict[str,
Any]] = None) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
@@ -131,6 +134,7 @@ class Qwen2Attention(nn.Module):
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.dual_chunk_attention_config = dual_chunk_attention_config
self.qkv_proj = QKVParallelLinear(
hidden_size,
@@ -155,15 +159,21 @@ class Qwen2Attention(nn.Module):
max_position=max_position,
base=self.rope_theta,
rope_scaling=rope_scaling,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
attn_type=attn_type)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
attn_type=attn_type,
prefix=f"{prefix}.attn",
**{
"layer_idx": extract_layer_index(prefix),
"dual_chunk_attention_config": dual_chunk_attention_config,
} if dual_chunk_attention_config else {})
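
The trailing "**{...} if dual_chunk_attention_config else {}" spread forwards the extra keyword arguments to Attention only when dual chunk attention is configured; a standalone illustration of that idiom with hypothetical names:

def make_attention(num_heads, head_dim, scale, **extra_impl_args):
    # stand-in for the real Attention constructor; just echoes the extras it received
    return ("attention", extra_impl_args)

dual_chunk_attention_config = {"chunk_size": 262144}  # illustrative; None disables the extras
layer_idx = 3
attn = make_attention(
    8, 128, 128 ** -0.5,
    **{
        "layer_idx": layer_idx,
        "dual_chunk_attention_config": dual_chunk_attention_config,
    } if dual_chunk_attention_config else {})
# attn == ("attention", {"layer_idx": 3, "dual_chunk_attention_config": {...}})
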
def forward(
self,
@@ -192,6 +202,9 @@ class Qwen2DecoderLayer(nn.Module):
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 1000000)
rope_scaling = getattr(config, "rope_scaling", None)
dual_chunk_attention_config = getattr(config,
"dual_chunk_attention_config",
None)
# By default, Qwen2 uses causal attention as it is a decoder-only model.
# You can override the HF config with `is_causal=False` to enable
@@ -213,6 +226,7 @@ class Qwen2DecoderLayer(nn.Module):
rope_scaling=rope_scaling,
prefix=f"{prefix}.self_attn",
attn_type=attn_type,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.mlp = Qwen2MLP(
hidden_size=self.hidden_size,

View File

@@ -175,6 +175,7 @@ class Qwen2MoeAttention(nn.Module):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
dual_chunk_attention_config: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
@@ -198,6 +199,7 @@ class Qwen2MoeAttention(nn.Module):
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.dual_chunk_attention_config = dual_chunk_attention_config
self.qkv_proj = QKVParallelLinear(
hidden_size,
@@ -221,14 +223,20 @@ class Qwen2MoeAttention(nn.Module):
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
dual_chunk_attention_config=dual_chunk_attention_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn")
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
**{
"layer_idx": extract_layer_index(prefix),
"dual_chunk_attention_config": dual_chunk_attention_config,
} if dual_chunk_attention_config else {})
def forward(
self,
@@ -256,6 +264,9 @@ class Qwen2MoeDecoderLayer(nn.Module):
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
dual_chunk_attention_config = getattr(config,
"dual_chunk_attention_config",
None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = Qwen2MoeAttention(
@@ -268,6 +279,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
dual_chunk_attention_config=dual_chunk_attention_config,
)
# Note: Qwen/Qwen2-57B-A14B-Instruct does not have

View File

@@ -222,6 +222,10 @@ class CudaPlatformBase(Platform):
elif selected_backend == _Backend.XFORMERS:
logger.info("Using XFormers backend.")
return "vllm.attention.backends.xformers.XFormersBackend"
elif selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN:
logger.info("Using DualChunkFlashAttention backend.")
return ("vllm.attention.backends.dual_chunk_flash_attn."
"DualChunkFlashAttentionBackend")
elif selected_backend == _Backend.FLASH_ATTN:
pass
elif selected_backend:

View File

@@ -51,6 +51,7 @@ class _Backend(enum.Enum):
PALLAS_VLLM_V1 = enum.auto()
IPEX = enum.auto()
BLOCK_SPARSE_FLASH_ATTN = enum.auto()
DUAL_CHUNK_FLASH_ATTN = enum.auto()
NO_ATTENTION = enum.auto()

View File

@@ -153,6 +153,7 @@ STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA"
STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH"
STR_XFORMERS_ATTN_VAL: str = "XFORMERS"
STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_DUAL_CHUNK_FLASH_ATTN_VAL: str = "DUAL_CHUNK_FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"
GB_bytes = 1_000_000_000

View File

@@ -204,6 +204,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.mrope_input_positions = None # type: ignore
self.seq_lens[0] = 0 # type: ignore
self.orig_seq_lens[0] = 0 # type: ignore
self.prompt_lens[0] = 0 # type: ignore
self.query_lens[0] = 0 # type: ignore
self.context_lens[0] = 0 # type: ignore
self.curr_sliding_window_blocks[0] = 0 # type: ignore
@@ -236,6 +237,8 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
# The original sequence length (before applying sliding window).
# This is used to compute slot mapping.
orig_seq_lens: Optional[List[int]] = None,
# This is used in the dual-chunk flash attention backend.
prompt_lens: Optional[List[int]] = None,
# The query length.
query_lens: Optional[List[int]] = None,
# The number of tokens that are already computed.
@@ -316,6 +319,12 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
for seq_id in range(len(self.seq_ids)):
self.orig_seq_lens[seq_id] = 0
if prompt_lens:
self.prompt_lens = prompt_lens
else:
for seq_id in range(len(self.seq_ids)):
self.prompt_lens[seq_id] = 0
if query_lens:
self.query_lens = query_lens
else:
@@ -370,6 +379,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.mrope_input_positions = mrope_input_positions or None
self.seq_lens = seq_lens or []
self.orig_seq_lens = orig_seq_lens or []
self.prompt_lens = prompt_lens or []
self.query_lens = query_lens or []
self.context_lens = context_lens or []
self.curr_sliding_window_blocks = \
@@ -403,6 +413,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
self.mrope_input_positions = None
self.seq_lens = [0] * self.n_seqs
self.orig_seq_lens = [0] * self.n_seqs
self.prompt_lens = [0] * self.n_seqs
self.query_lens = [0] * self.n_seqs
self.context_lens = [0] * self.n_seqs
self.curr_sliding_window_blocks = [0] * self.n_seqs
@@ -552,6 +563,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
inter_data.seq_lens[seq_idx] = seq_len
inter_data.orig_seq_lens[seq_idx] = seq_len
inter_data.prompt_lens[seq_idx] = seq_data.get_prompt_len()
inter_data.context_lens[seq_idx] = context_len
inter_data.input_tokens[seq_idx].extend(tokens)
inter_data.inputs_embeds = prompt_embeds