Compare commits

..

8 Commits

Author SHA1 Message Date
Harry Mellor
72506c9834 Check for truthy rope_parameters not the existence of it (#30983)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
(cherry picked from commit 19c583398a)
2025-12-18 14:07:04 -08:00
Isotr0py
b2eb84de77 [Bugfix] Remove tile_size=64 for mm_prefix triton attention (#30973)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
(cherry picked from commit d2dc5dfc6e)
2025-12-18 14:06:49 -08:00
sarathc-cerebras
ac43367ced adds jais 2 support (#30188)
Signed-off-by: sarathc-cerebras <sarath.chandran@cerebras.net>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
(cherry picked from commit 28d15ab56b)
2025-12-18 14:06:33 -08:00
Yifan Qiao
30fe765e9f [Fix][FlexAttention] return max logical block index to handle reused blocks (#30915)
Signed-off-by: Yifan Qiao <yifanqiao@berkeley.edu>
(cherry picked from commit 11a89cf95c)
2025-12-18 14:06:17 -08:00
Lucas Wilkinson
2c0ee0fde8 [BugFix] Partial revert of #29558 (DeepEP HT + PIECEWISE CG support) (#30910)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Lucas Wilkinson <LucasWilkinson@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
(cherry picked from commit 30bb19a760)
2025-12-17 23:56:41 -08:00
Isotr0py
55f1fc1b1b [v1] Add PrefixLM support to TritonAttention backend (#30386)
(cherry picked from commit 74a1ac38b0)
2025-12-17 19:57:52 -08:00
Varun Sundar Rabindranath
17f3988094 [BugFix] Workspace allocation during profile run : DeepEPHighThroughput + DeepGEMM (#30899)
(cherry picked from commit e3fc374a9a)
2025-12-17 19:57:33 -08:00
Nicolò Lucchesi
682c38583c [CI][Bugfix] Fix flaky tests/entrypoints/openai/test_audio.py::test_chat_streaming_audio (#30878)
Signed-off-by: NickLucche <nlucches@redhat.com>
(cherry picked from commit 9ca8cb38fd)
2025-12-17 19:57:15 -08:00
15 changed files with 875 additions and 208 deletions

View File

@@ -406,6 +406,7 @@ th {
| `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ |
| `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ |
| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ |
| `Jais2ForCausalLM` | Jais2 | `inceptionai/Jais-2-8B-Chat`, `inceptionai/Jais-2-70B-Chat`, etc. | | ✅︎ |
| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ |
| `KimiLinearForCausalLM` | Kimi-Linear-48B-A3B-Base, Kimi-Linear-48B-A3B-Instruct | `moonshotai/Kimi-Linear-48B-A3B-Base`, `moonshotai/Kimi-Linear-48B-A3B-Instruct` | | ✅︎ |
| `Lfm2ForCausalLM` | LFM2 | `LiquidAI/LFM2-1.2B`, `LiquidAI/LFM2-700M`, `LiquidAI/LFM2-350M`, etc. | ✅︎ | ✅︎ |

View File

@@ -233,24 +233,6 @@ def test_splitting_ops_dynamic():
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
def test_moe_splitting_ops_deepep_ht_piecewise():
# Non-inductor, non-attn-fusion case: DeepEP HT with dp>1
# should add MoE ops to splitting_ops on top of attention ops.
config = VllmConfig(
parallel_config=ParallelConfig(
all2all_backend="deepep_high_throughput",
data_parallel_size=8,
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
),
)
splitting_ops = config.compilation_config.splitting_ops
assert splitting_ops is not None
assert "vllm::moe_forward" in splitting_ops
assert "vllm::moe_forward_shared" in splitting_ops
def test_moe_splitting_ops_deepep_ht_inductor_partition():
# Inductor partition case: user-provided splitting_ops should be
# preserved and MoE ops should be appended for DeepEP HT with dp>1.
@@ -277,26 +259,6 @@ def test_moe_splitting_ops_deepep_ht_inductor_partition():
]
def test_moe_splitting_ops_deepep_ht_attn_fusion_no_inductor():
# Pure attn-fusion case without inductor partition: even with
# DeepEP HT and dp>1, we should not re-enable piecewise compilation
# or add MoE ops into splitting_ops.
config = VllmConfig(
parallel_config=ParallelConfig(
all2all_backend="deepep_high_throughput",
data_parallel_size=8,
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
pass_config={"fuse_attn_quant": True, "eliminate_noops": True},
custom_ops=["+quant_fp8"],
cudagraph_mode=CUDAGraphMode.PIECEWISE,
),
)
assert config.compilation_config.splitting_ops == []
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
def test_should_split():
import torch

View File

@@ -254,7 +254,9 @@ async def test_single_chat_session_input_audio(
async def test_chat_streaming_audio(
client: openai.AsyncOpenAI, model_name: str, audio_url: str
):
messages = dummy_messages_from_audio_url(audio_url)
messages = dummy_messages_from_audio_url(
audio_url, "What's a short title for this audio?"
)
# test single completion
chat_completion = await client.chat.completions.create(

View File

@@ -15,7 +15,10 @@ from tests.v1.attention.utils import (
create_standard_kv_cache_spec,
create_vllm_config,
)
from vllm.v1.attention.backends.flex_attention import FlexAttentionMetadataBuilder
from vllm.v1.attention.backends.flex_attention import (
FlexAttentionMetadataBuilder,
physical_to_logical_mapping,
)
from ..models.utils import check_embeddings_close, check_logprobs_close
@@ -212,5 +215,31 @@ def test_block_mask_direct_vs_slow_path():
)
def test_physical_to_logical_mapping_handles_reused_blocks():
"""Regression test: reused physical blocks map to the latest logical block.
For sliding-window / hybrid attention layers, physical KV-cache blocks can be
reused over time. The inverse mapping must therefore select the latest
logical block index for a physical block id.
"""
# Padding should not make physical block 0 look live.
block_table = torch.tensor([[6, 0, 0, 0]], dtype=torch.int32)
seq_lens = torch.tensor([1 * 16], dtype=torch.int32) # only 1 block valid
out = physical_to_logical_mapping(
block_table=block_table, seq_lens=seq_lens, block_size=16, total_blocks=10
)
assert out[0, 0].item() == -1
assert out[0, 6].item() == 0
# If a physical block id appears multiple times (block reuse), mapping should
# point to the latest logical block index.
block_table2 = torch.tensor([[2, 2, 5]], dtype=torch.int32)
seq_lens2 = torch.tensor([3 * 16], dtype=torch.int32)
out2 = physical_to_logical_mapping(
block_table=block_table2, seq_lens=seq_lens2, block_size=16, total_blocks=8
)
assert out2[0, 2].item() == 1
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -1,17 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Literal, NamedTuple
import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"
from typing import Any, NamedTuple
import pytest
from huggingface_hub import hf_hub_download
from pytest import MarkDecorator
from transformers import AutoModelForImageTextToText
from tests.quantization.utils import is_quant_method_supported
from vllm.assets.image import ImageAsset
from vllm.multimodal.image import rescale_image_size
from vllm.utils.torch_utils import set_default_torch_num_threads
from ....conftest import PromptImageInput, VllmRunner
from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner
from ...utils import check_logprobs_close
@@ -21,9 +27,10 @@ class GGUFMMTestConfig(NamedTuple):
gguf_backbone: str
gguf_mmproj: str
prompt: list[str]
mm_data: dict[Literal["images"], PromptImageInput]
image_names: list[str] # Store names, load PIL images at runtime
max_model_len: int = 4096
marks: list[MarkDecorator] = []
mm_processor_kwargs: dict[str, Any] = {}
@property
def gguf_model(self):
@@ -31,27 +38,75 @@ class GGUFMMTestConfig(NamedTuple):
return hf_hub_download(self.gguf_repo, filename=self.gguf_backbone)
# Common prompts aligned with test_common.py "gemma3" entry format
_GEMMA3_PROMPTS = IMAGE_ASSETS.prompts(
{
"stop_sign": (
"<bos><start_of_turn>user\n"
"<start_of_image>What's the content in the center of the image?"
"<end_of_turn>\n<start_of_turn>model\n"
),
"cherry_blossom": (
"<bos><start_of_turn>user\n"
"<start_of_image>What is the season?"
"<end_of_turn>\n<start_of_turn>model\n"
),
}
)
# Image asset names - load at runtime to avoid pickle issues with subprocess
_GEMMA3_IMAGE_NAMES = ["stop_sign", "cherry_blossom"]
# Regular multimodal (no pan-and-scan) - uses QAT Q4_0 GGUF
GEMMA3_CONFIG = GGUFMMTestConfig(
original_model="google/gemma-3-4b-it",
gguf_repo="google/gemma-3-4b-it-qat-q4_0-gguf",
gguf_backbone="gemma-3-4b-it-q4_0.gguf",
gguf_mmproj="mmproj-model-f16-4B.gguf",
prompt=["<start_of_image>Describe this image in detail:"],
mm_data={"images": [ImageAsset("stop_sign").pil_image]},
prompt=_GEMMA3_PROMPTS,
image_names=_GEMMA3_IMAGE_NAMES,
max_model_len=4096,
marks=[pytest.mark.core_model],
mm_processor_kwargs={},
)
MODELS_TO_TEST = [GEMMA3_CONFIG]
# Pan-and-scan multimodal - uses unquantized BF16 GGUF
GEMMA3_CONFIG_PAN_AND_SCAN = GGUFMMTestConfig(
original_model="google/gemma-3-4b-it",
gguf_repo="unsloth/gemma-3-4b-it-GGUF",
gguf_backbone="gemma-3-4b-it-BF16.gguf",
gguf_mmproj="mmproj-BF16.gguf",
prompt=_GEMMA3_PROMPTS,
image_names=_GEMMA3_IMAGE_NAMES,
max_model_len=4096,
marks=[pytest.mark.core_model],
mm_processor_kwargs={"do_pan_and_scan": True},
)
MODELS_TO_TEST = [GEMMA3_CONFIG, GEMMA3_CONFIG_PAN_AND_SCAN]
def run_multimodal_gguf_test(
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
model: GGUFMMTestConfig,
dtype: str,
max_tokens: int,
num_logprobs: int,
):
# Run gguf model.
# Load images at runtime (inside subprocess) to avoid pickle issues
images = [ImageAsset(name).pil_image for name in model.image_names]
size_factors = [0.25, 0.5, 1.0]
inputs_per_image = [
(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
)
for image, prompt in zip(images, model.prompt)
]
# NOTE: Run vLLM first to avoid CUDA init issues with multiprocessing fork.
# Run GGUF model via vLLM.
with (
set_default_torch_num_threads(1),
vllm_runner(
@@ -60,35 +115,42 @@ def run_multimodal_gguf_test(
tokenizer_name=model.original_model,
dtype=dtype,
max_model_len=model.max_model_len,
mm_processor_kwargs=model.mm_processor_kwargs,
) as gguf_model,
):
gguf_outputs = gguf_model.generate_greedy_logprobs(
prompts=model.prompt,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
**model.mm_data,
)
gguf_outputs_per_case = [
gguf_model.generate_greedy_logprobs(
prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images,
)
for prompts, images in inputs_per_image
]
# Run unquantized model.
with vllm_runner(
model_name=model.original_model,
enforce_eager=True, # faster tests
# Then run HfRunner for HuggingFace baseline comparison.
with hf_runner(
model.original_model,
dtype=dtype,
max_model_len=model.max_model_len,
) as original_model:
original_outputs = original_model.generate_greedy_logprobs(
prompts=model.prompt,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
**model.mm_data,
)
auto_cls=AutoModelForImageTextToText,
) as hf_model:
hf_outputs_per_case = [
hf_model.generate_greedy_logprobs_limit(
prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images,
)
for prompts, images in inputs_per_image
]
check_logprobs_close(
outputs_0_lst=original_outputs,
outputs_1_lst=gguf_outputs,
name_0="original",
name_1="gguf",
)
for hf_outputs, gguf_outputs in zip(hf_outputs_per_case, gguf_outputs_per_case):
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=gguf_outputs,
name_0="hf",
name_1="gguf",
)
@pytest.mark.skipif(
@@ -105,11 +167,14 @@ def run_multimodal_gguf_test(
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
def test_models(
def test_gemma3_mm_gguf(
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
model: GGUFMMTestConfig,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
run_multimodal_gguf_test(vllm_runner, model, dtype, max_tokens, num_logprobs)
run_multimodal_gguf_test(
hf_runner, vllm_runner, model, dtype, max_tokens, num_logprobs
)

View File

@@ -295,6 +295,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"internlm/internlm3-8b-instruct", trust_remote_code=True
),
"JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"),
"Jais2ForCausalLM": _HfExamplesInfo(
"inceptionai/Jais-2-8B-Chat", min_transformers_version="4.58"
),
"JambaForCausalLM": _HfExamplesInfo(
"ai21labs/AI21-Jamba-1.5-Mini",
extras={

View File

@@ -86,6 +86,9 @@ def kernel_unified_attention_2d(
USE_SOFTCAP: tl.constexpr, # bool
USE_SINKS: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int
USE_MM_PREFIX: tl.constexpr, # bool
MAX_MM_RANGES: tl.constexpr, # int
    mm_prefix_range_ptr, # [num_seqs, MAX_MM_RANGES, 2] - (start, end) bidirectional ranges per sequence
stride_k_cache_0: tl.int64, # int
stride_k_cache_1: tl.int64, # int
stride_k_cache_2: tl.int64, # int
@@ -270,7 +273,38 @@ def kernel_unified_attention_2d(
else:
V = V_load
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
# Compute attention mask: causal by default (key <= query)
query_abs_pos = context_len + query_pos[:, None]
seq_mask = seq_offset[None, :] <= query_abs_pos
# Apply sliding window to base mask BEFORE mm_prefix OR.
# Order must match FlexAttention: (causal AND sliding_window) OR mm_prefix
if SLIDING_WINDOW > 0:
seq_mask = seq_mask & ((query_abs_pos - seq_offset) < SLIDING_WINDOW)
# PrefixLM: extend mask with bidirectional ranges for multimodal tokens.
# Applied AFTER sliding window so mm_prefix ranges override SW restriction.
if USE_MM_PREFIX:
for i in range(MAX_MM_RANGES):
range_start = tl.load(
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2
)
range_end = tl.load(
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2 + 1
)
is_valid = range_start < range_end
q_in_range = (
(query_abs_pos >= range_start)
& (query_abs_pos <= range_end)
& is_valid
)
k_in_range = (
(seq_offset[None, :] >= range_start)
& (seq_offset[None, :] <= range_end)
& is_valid
)
seq_mask |= q_in_range & k_in_range
# S : (BLOCK_M, TILE_SIZE)
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
@@ -284,13 +318,6 @@ def kernel_unified_attention_2d(
query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
)
if SLIDING_WINDOW > 0:
S = tl.where(
(context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW,
S,
float("-inf"),
)
if USE_ALIBI_SLOPES:
S += alibi_slope[:, None] * (seq_offset - context_len)
@@ -398,6 +425,9 @@ def kernel_unified_attention_3d(
num_seqs: tl.int32,
BLOCK_M: tl.constexpr, # int
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int
USE_MM_PREFIX: tl.constexpr, # bool
MAX_MM_RANGES: tl.constexpr, # int
    mm_prefix_range_ptr, # [num_seqs, MAX_MM_RANGES, 2] - (start, end) bidirectional ranges per sequence
):
q_block_global_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)
@@ -559,7 +589,38 @@ def kernel_unified_attention_3d(
else:
V = V_load
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
# Compute attention mask: causal by default (key <= query)
query_abs_pos = context_len + query_pos[:, None]
seq_mask = seq_offset[None, :] <= query_abs_pos
# Apply sliding window to base mask BEFORE mm_prefix OR.
# Order must match FlexAttention: (causal AND sliding_window) OR mm_prefix
if SLIDING_WINDOW > 0:
seq_mask = seq_mask & ((query_abs_pos - seq_offset) < SLIDING_WINDOW)
# PrefixLM: extend mask with bidirectional ranges for multimodal tokens.
# Applied AFTER sliding window so mm_prefix ranges override SW restriction.
if USE_MM_PREFIX:
for i in range(MAX_MM_RANGES):
range_start = tl.load(
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2
)
range_end = tl.load(
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2 + 1
)
is_valid = range_start < range_end
q_in_range = (
(query_abs_pos >= range_start)
& (query_abs_pos <= range_end)
& is_valid
)
k_in_range = (
(seq_offset[None, :] >= range_start)
& (seq_offset[None, :] <= range_end)
& is_valid
)
seq_mask |= q_in_range & k_in_range
# S : (BLOCK_M, TILE_SIZE)
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
@@ -572,13 +633,6 @@ def kernel_unified_attention_3d(
query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
)
if SLIDING_WINDOW > 0:
S = tl.where(
(context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW,
S,
float("-inf"),
)
if USE_ALIBI_SLOPES:
S += alibi_slope[:, None] * (seq_offset - context_len)
@@ -732,6 +786,38 @@ def reduce_segments(
tl.store(output_ptr + output_offset, acc, mask=dim_mask)
def _is_gemma3_attention(head_size: int, sliding_window: int) -> bool:
"""Detect Gemma3 models via unique (head_size, sliding_window) signature.
Gemma3 models are the only ones using sliding_window=1024 with
head_size 128 (27B) or 256 (1B, 4B, 12B). Other SWA models use
different window sizes (Mistral=4096, Phi-3=2047).
"""
return sliding_window == 1024 and head_size in (128, 256)
def _get_tile_size(
head_size: int,
sliding_window: int,
element_size: int,
is_prefill: bool,
) -> int:
"""Select tile size with Gemma3-specific optimization.
For Gemma3, use 32 for both prefill and decode to better utilize
the larger head dimension (128/256). For other models, use
the default vLLM behavior.
"""
if _is_gemma3_attention(head_size, sliding_window):
# Gemma3: use 32 for decode (default is 16)
return 32
# Default behavior
if is_prefill:
return 32
return 16 if element_size >= 2 else 32
def unified_attention(
q,
k,
@@ -759,6 +845,8 @@ def unified_attention(
qq_bias=None,
# Optional tensor for sinks
sinks=None,
    # Optional tensor of multimodal (start, end) prefix ranges (PrefixLM support)
mm_prefix_range=None,
):
assert causal, "Only causal attention is supported"
assert q_descale is None, "Q scales not supported"
@@ -766,6 +854,17 @@ def unified_attention(
if sinks is not None:
assert sinks.shape[0] == q.shape[1], "Sinks must be num_query_heads size"
use_mm_prefix = False
max_mm_ranges = 0
if mm_prefix_range is not None:
if mm_prefix_range.ndim == 3:
use_mm_prefix = True
max_mm_ranges = mm_prefix_range.shape[1]
else:
raise ValueError(
f"Unsupported mm_prefix_range shape: {mm_prefix_range.shape}"
)
use_alibi_slopes = alibi_slopes is not None
use_qq_bias = qq_bias is not None
@@ -792,11 +891,21 @@ def unified_attention(
# = floor(q.shape[0] / BLOCK_Q) + num_seqs
total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
# Assigning default tile sizes for prefill and decode.
# Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1)
# and at least 16 for all other data types.
TILE_SIZE_PREFILL = 32
TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32
# Tile sizes for prefill and decode. Gemma3 models use optimized values.
# Note: tile size must be at least 32 for fp8 (element_size == 1).
sliding_window_val = 1 + window_size[0] if window_size[0] >= 0 else 0
TILE_SIZE_PREFILL = _get_tile_size(
head_size,
sliding_window_val,
q.element_size(),
is_prefill=True,
)
TILE_SIZE_DECODE = _get_tile_size(
head_size,
sliding_window_val,
q.element_size(),
is_prefill=False,
)
# Launch the 2D kernel if
# 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or
@@ -847,6 +956,9 @@ def unified_attention(
USE_QQ_BIAS=use_qq_bias,
USE_SOFTCAP=(softcap > 0),
USE_SINKS=(sinks is not None),
USE_MM_PREFIX=use_mm_prefix,
MAX_MM_RANGES=max_mm_ranges,
mm_prefix_range_ptr=mm_prefix_range,
SLIDING_WINDOW=(1 + window_size[0]),
stride_k_cache_0=k.stride(0),
stride_k_cache_1=k.stride(1),
@@ -895,6 +1007,9 @@ def unified_attention(
USE_QQ_BIAS=use_qq_bias,
USE_SOFTCAP=(softcap > 0),
USE_SINKS=(sinks is not None),
USE_MM_PREFIX=use_mm_prefix,
MAX_MM_RANGES=max_mm_ranges,
mm_prefix_range_ptr=mm_prefix_range,
SLIDING_WINDOW=(1 + window_size[0]),
stride_k_cache_0=k.stride(0),
stride_k_cache_1=k.stride(1),
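A minimal standalone sketch (plain PyTorch, toy sizes and a made-up range) of the mask ordering the kernel comments describe, (causal AND sliding_window) OR mm_prefix, where a bidirectional multimodal range overrides the sliding-window restriction:

import torch

seq_len, window, mm_range = 8, 3, (2, 5)  # toy sequence length, window, and (start, end) range
q_pos = torch.arange(seq_len)[:, None]
k_pos = torch.arange(seq_len)[None, :]

causal = k_pos <= q_pos
sliding = (q_pos - k_pos) < window
base = causal & sliding  # causal AND sliding window
mm = (
    (q_pos >= mm_range[0]) & (q_pos <= mm_range[1])
    & (k_pos >= mm_range[0]) & (k_pos <= mm_range[1])
)  # bidirectional attention inside the range
mask = base | mm  # mm range overrides the window restriction
print(mask.int())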

View File

@@ -915,8 +915,6 @@ class CompilationConfig:
"mode is CompilationMode.VLLM_COMPILE"
)
added_default_splitting_ops = False
if self.pass_config.fuse_attn_quant and not self.use_inductor_graph_partition:
self.set_splitting_ops_for_attn_fusion()
else:
@@ -930,7 +928,6 @@ class CompilationConfig:
# for details. Make a copy to avoid mutating the class-level
# list via reference.
self.splitting_ops = list(self._attention_ops)
added_default_splitting_ops = True
elif len(self.splitting_ops) == 0:
if (
self.cudagraph_mode == CUDAGraphMode.PIECEWISE
@@ -958,44 +955,25 @@ class CompilationConfig:
self.cudagraph_mode = CUDAGraphMode.FULL
self.splitting_ops = []
# split MoE ops for cudagraph
moe_ops = [
"vllm::moe_forward",
"vllm::moe_forward_shared",
]
        # Disable CUDA graphs for DeepEP high-throughput since it's not CUDA-graph compatible
backend = all2all_backend or envs.VLLM_ALL2ALL_BACKEND
dp_size = data_parallel_size if data_parallel_size is not None else 1
need_moe_splitting = (
if (
backend == "deepep_high_throughput"
and dp_size > 1
# pure attn-fusion without inductor partition deliberately disables
# piecewise graphs and MoE splitting.
and not (
self.pass_config.fuse_attn_quant
and not self.use_inductor_graph_partition
and self.cudagraph_mode != CUDAGraphMode.NONE
):
            # TODO: Piecewise CUDA graphs might be enabled
            # once the torch.compile cache key issue is fixed
# See https://github.com/vllm-project/vllm/pull/25093
logger.info(
"DeepEP: Disabling CUDA Graphs since DeepEP high-throughput kernels "
"are optimized for prefill and are incompatible with CUDA Graphs. "
"In order to use CUDA Graphs for decode-optimized workloads, "
"use --all2all-backend with another option, such as "
"deepep_low_latency, pplx, or allgather_reducescatter."
)
)
if need_moe_splitting and self.cudagraph_mode != CUDAGraphMode.NONE:
# if we just initialized default splitting_ops for this config,
# automatically append the MoE ops
if added_default_splitting_ops:
for op in moe_ops:
if op not in self.splitting_ops:
self.splitting_ops.append(op)
# make sure MoE ops are split out
if not any(op in self.splitting_ops for op in moe_ops):
self.cudagraph_mode = CUDAGraphMode.NONE
logger.warning_once(
"DeepEP high throughput backend with data_parallel_size > 1 "
"requires splitting MoE ops from cudagraphs. Please ensure "
"'vllm::moe_forward' or 'vllm::moe_forward_shared' are "
"present in CompilationConfig.splitting_ops."
)
elif self.cudagraph_mode.has_full_cudagraphs():
# fall back to piecewise when MoE splitting is required.
self.cudagraph_mode = CUDAGraphMode.PIECEWISE
self.cudagraph_mode = CUDAGraphMode.NONE
def set_splitting_ops_for_attn_fusion(self):
assert self.pass_config.fuse_attn_quant

View File

@@ -795,7 +795,10 @@ class FusedMoEModularKernel(torch.nn.Module):
top_k,
global_num_experts,
local_num_experts,
expert_tokens_meta,
                # expert_tokens_meta helps allocate the optimal/minimal
                # amount of workspace. Mark it None so we allocate for
                # the worst-case scenario.
expert_tokens_meta=None,
)
)

View File

@@ -19,7 +19,6 @@ from collections.abc import Iterable
from itertools import islice
import torch
import torch.nn.functional as F
from torch import nn
from transformers import Gemma3TextConfig
@@ -223,77 +222,9 @@ class Gemma3Attention(nn.Module):
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
if not kwargs.get("has_images", False):
# Fast path for text-only inputs. The performance for the text-only
# inputs are not affected by the naive attention below.
output, _ = self.o_proj(attn_output)
return output
# NOTE(woosuk): Gemma3 uses bidirectional attention between image tokens
# that correspond to the same image while using causal attention
# otherwise. Current attention backends cannot handle this pattern, so
# we temporarily use a naive attention implementation with mask tensors.
# We intentionally keep the attention backend as-is and only override
# `attn_output` with the naive implementation's output. This minimizes
# changes to existing model runners and attention backends. The call to
# `self.attn(q, k, v)` is only used to populate the KV cache - its
# output is discarded and overwritten below. While this duplicates
# computation, it maintains compatibility.
# TODO(woosuk): Optimize by implementing custom attention kernels.
attn_output = self.naive_attn_with_masks(q, k, v, out=attn_output, **kwargs)
output, _ = self.o_proj(attn_output)
return output
def naive_attn_with_masks(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
out: torch.Tensor,
**kwargs,
) -> torch.Tensor:
# NOTE(woosuk): As described in the comment above, this code is not
# meant to be performant. It is only meant to be correct.
q = q.view(-1, self.num_heads, self.head_dim)
# Expand the key and value to handle GQA.
num_queries_per_kv = self.num_heads // self.num_kv_heads
k = k.view(-1, self.num_kv_heads, self.head_dim)
k = k.repeat_interleave(num_queries_per_kv, dim=-2)
v = v.view(-1, self.num_kv_heads, self.head_dim)
v = v.repeat_interleave(num_queries_per_kv, dim=-2)
if self.is_sliding:
attn_masks = kwargs["local_attn_masks"]
else:
attn_masks = kwargs["global_attn_masks"]
seq_lens = kwargs["seq_lens"]
start_idx = 0
for seq_len, attn_mask in zip(seq_lens, attn_masks):
end_idx = start_idx + seq_len
query = q[start_idx:end_idx].unsqueeze(0)
key = k[start_idx:end_idx].unsqueeze(0)
value = v[start_idx:end_idx].unsqueeze(0)
# Transpose.
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
output = F.scaled_dot_product_attention(
query,
key,
value,
attn_mask,
self.scaling,
)
output = output.transpose(1, 2).flatten(-2, -1)
out[start_idx:end_idx] = output
start_idx = end_idx
return out
class Gemma3DecoderLayer(nn.Module):
def __init__(

View File

@@ -0,0 +1,529 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Jais2 model compatible with HuggingFace weights."""
from collections.abc import Iterable
import torch
from torch import nn
from transformers import Jais2Config
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
get_pp_group,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.activation import ReLUSquaredActivation
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
)
class Jais2MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
prefix: str = "",
) -> None:
super().__init__()
self.up_proj = ColumnParallelLinear(
input_size=hidden_size,
output_size=intermediate_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.up_proj",
)
self.down_proj = RowParallelLinear(
input_size=intermediate_size,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj",
)
self.act_fn = ReLUSquaredActivation()
def forward(self, x):
x, _ = self.up_proj(x)
x = self.act_fn(x)
x, _ = self.down_proj(x)
return x
class Jais2Attention(nn.Module):
def __init__(
self,
config: Jais2Config,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position_embeddings: int = 8192,
quant_config: QuantizationConfig | None = None,
bias: bool = False,
cache_config: CacheConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
layer_idx = extract_layer_index(prefix)
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
self.head_dim = getattr(
config, "head_dim", self.hidden_size // self.total_num_heads
)
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
is_neox_style = True
if quant_config is not None and quant_config.get_name() == "gguf":
is_neox_style = False
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
rope_parameters=getattr(config, "rope_parameters", None),
is_neox_style=is_neox_style,
)
if hasattr(config, "interleaved_sliding_window"):
interleaved_sliding_window = config.interleaved_sliding_window
if isinstance(interleaved_sliding_window, int):
sliding_window = interleaved_sliding_window
elif isinstance(interleaved_sliding_window, list):
sw_idx = layer_idx % len(interleaved_sliding_window)
sliding_window = interleaved_sliding_window[sw_idx]
else:
raise ValueError(
f"{type(interleaved_sliding_window)} is not supported."
)
else:
sliding_window = None
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
per_layer_sliding_window=sliding_window,
prefix=f"{prefix}.attn",
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
class Jais2DecoderLayer(nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
config: Jais2Config,
prefix: str = "",
) -> None:
super().__init__()
config = config or vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = self.get_quant_config(vllm_config)
self.hidden_size = config.hidden_size
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False
)
self.self_attn = Jais2Attention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=getattr(
config, "num_key_value_heads", config.num_attention_heads
),
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
)
self.mlp = Jais2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp",
)
self.input_layernorm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps
)
self.post_attention_layernorm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = (
self.input_layernorm(hidden_states + residual),
hidden_states + residual,
)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
)
# Fully Connected
hidden_states, residual = (
self.post_attention_layernorm(hidden_states + residual),
hidden_states + residual,
)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
"""Get quantization config for this layer. Override in subclasses."""
return vllm_config.quant_config
@support_torch_compile
class Jais2Model(nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str = "",
layer_type: type[nn.Module] = Jais2DecoderLayer,
):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.quant_config = quant_config
self.padding_idx = config.pad_token_id
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
if lora_config
else 0
)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
if get_pp_group().is_first_rank or (
config.tie_word_embeddings and get_pp_group().is_last_rank
):
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config,
)
else:
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: layer_type(
config=config,
vllm_config=vllm_config,
prefix=prefix,
),
prefix=f"{prefix}.layers",
)
if get_pp_group().is_last_rank:
self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
else:
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual}
)
hidden_states, _ = self.norm(hidden_states + residual), residual
return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
# Loading kv cache scales for compressed-tensors quantization
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = loaded_weight[0]
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
if "scale" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class Jais2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
}
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
self.model = self._init_model(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=(
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config
else lora_config.lora_vocab_padding_size
),
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size, logit_scale
)
else:
self.lm_head = PPMissingLayer()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors
)
def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
return Jais2Model(vllm_config=vllm_config, prefix=prefix)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
model_output = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds
)
return model_output
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(
self,
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
)
return loader.load_weights(weights)

View File

@@ -127,6 +127,7 @@ _TEXT_GENERATION_MODELS = {
"InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
"InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
"Jais2ForCausalLM": ("jais2", "Jais2ForCausalLM"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
"KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"), # noqa: E501
"Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),

View File

@@ -318,15 +318,15 @@ def patch_rope_parameters(config: PretrainedConfig) -> None:
# Transformers v4 installed, legacy config fields may be present
if (rope_scaling := getattr(config, "rope_scaling", None)) is not None:
config.rope_parameters = rope_scaling
if (
rope_theta is not None or partial_rotary_factor is not None
) and not getattr(config, "rope_parameters", None):
config.rope_parameters = {"rope_type": "default"}
if rope_theta is not None:
if not hasattr(config, "rope_parameters"):
config.rope_parameters = {"rope_type": "default"}
config.rope_parameters["rope_theta"] = rope_theta
if partial_rotary_factor is not None:
if not hasattr(config, "rope_parameters"):
config.rope_parameters = {"rope_type": "default"}
config.rope_parameters["partial_rotary_factor"] = partial_rotary_factor
elif rope_theta is not None or hasattr(config, "rope_parameters"):
elif rope_theta is not None or getattr(config, "rope_parameters", None):
# Transformers v5 installed
# Patch these fields in case they used non-standard names
if rope_theta is not None:
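A minimal standalone sketch (using a hypothetical Cfg class, not the real PretrainedConfig) of why the truthy check matters: a config can carry rope_parameters=None, which hasattr() treats as present but getattr(..., None) treats as missing, so only the truthy check installs the default:

class Cfg:
    rope_parameters = None  # attribute exists but is falsy

cfg = Cfg()
print(hasattr(cfg, "rope_parameters"))              # True  -> existence check would skip the default
print(bool(getattr(cfg, "rope_parameters", None)))  # False -> truthy check installs the default

if not getattr(cfg, "rope_parameters", None):
    cfg.rope_parameters = {"rope_type": "default"}
print(cfg.rope_parameters)                          # {'rope_type': 'default'}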

View File

@@ -160,7 +160,7 @@ def physical_to_logical_mapping(
└───────────────────────────────────────────┘
If multiple logical blocks map to the same physical block,
this function returns the first (minimum) logical block index.
this function returns the latest (maximum) logical block index.
If a physical block is not mapped to by any logical block,
its value in the result will be -1.
@@ -183,6 +183,15 @@ def physical_to_logical_mapping(
To prevent this, we use seq_lens and block_size to mask out unused
entries, ensuring only valid block references are processed.
IMPORTANT: Reused physical blocks (sliding-window / hybrid attention)
────────────────────────────────────────────────────────────────────
For some attention types, physical cache blocks can be reused over time.
This can cause the same physical block id to appear multiple times in a row
of `block_table` at different logical block indices. In that case, only the
latest logical block index corresponds to the current contents of that
physical block. Therefore, the inverse mapping must pick the maximum logical
block index for each physical block id.
Args:
block_table: Tensor of shape [max_reqs, max_num_blocks]
mapping logical blocks to physical locations. May contain
@@ -217,8 +226,8 @@ def physical_to_logical_mapping(
mask, torch.arange(max_num_blocks, device=device)[None, :], 0
)
physical_to_logical.scatter_(
-1, valid_block_table.to(torch.int64), valid_logical_indices
physical_to_logical.scatter_reduce_(
-1, valid_block_table.to(torch.int64), valid_logical_indices, reduce="amax"
)
# NB - Seems like block 0 is always empty so we reset it manually
physical_to_logical[:, 0] = -1
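A minimal standalone sketch (plain PyTorch, made-up block table) of the inverse-mapping change: when the same physical block id appears at several logical positions, scatter_reduce_ with reduce="amax" keeps the latest (maximum) logical index, which is what the updated docstring promises:

import torch

block_table = torch.tensor([[2, 2, 5]], dtype=torch.int64)  # physical block id per logical slot
total_blocks = 8
logical_idx = torch.arange(block_table.shape[1]).expand_as(block_table)

physical_to_logical = torch.full((1, total_blocks), -1, dtype=torch.int64)
physical_to_logical.scatter_reduce_(-1, block_table, logical_idx, reduce="amax")
print(physical_to_logical)  # block 2 -> logical 1 (latest), block 5 -> logical 2, others -> -1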

View File

@@ -76,6 +76,39 @@ class TritonAttentionMetadata:
# Optional aot scheduling
scheduler_metadata: torch.Tensor | None = None
prefix_scheduler_metadata: torch.Tensor | None = None
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
@property
def mm_prefix_range_tensor(self) -> torch.Tensor | None:
"""Convert mm_prefix_range dict to padded tensor for Triton kernel.
Returns shape: (num_seqs, max_ranges, 2) with 0-padding for empty ranges.
        Empty ranges have start==end==0, which the kernel skips via the is_valid check.
"""
# TODO(Isotr0py): Move to model runner's attention metadata
# preparation to avoid duplicate computation.
if self.mm_prefix_range is None:
return None
num_seqs = self.seq_lens.shape[0]
device = self.seq_lens.device
# Collect ranges, using [(0,0)] for empty sequences to ensure uniform dims
range_lists = [
self.mm_prefix_range.get(i, [(0, 0)]) or [(0, 0)] for i in range(num_seqs)
]
# Return None if all ranges are trivial (only (0,0) placeholders)
if all(r == [(0, 0)] for r in range_lists):
return None
# Create 2D tensors with shape (num_ranges, 2) for each sequence
range_tensors = [
torch.tensor(r, dtype=torch.int32, device=device).view(-1, 2)
for r in range_lists
]
return torch.nested.nested_tensor(range_tensors).to_padded_tensor(0)
class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
@@ -268,6 +301,10 @@ class TritonAttentionBackend(AttentionBackend):
def supports_head_size(cls, head_size: int) -> bool:
return head_size >= 32
@classmethod
def supports_mm_prefix(cls) -> bool:
return True
@classmethod
def supports_sink(cls) -> bool:
return True
@@ -427,6 +464,7 @@ class TritonAttentionImpl(AttentionImpl):
softmax_segm_expsum = attn_metadata.softmax_segm_expsum
descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
mm_prefix_range_tensor = attn_metadata.mm_prefix_range_tensor
unified_attention(
q=query[:num_actual_tokens],
@@ -453,6 +491,7 @@ class TritonAttentionImpl(AttentionImpl):
softmax_segm_expsum=softmax_segm_expsum,
sinks=self.sinks,
output_scale=output_scale,
mm_prefix_range=mm_prefix_range_tensor,
)
return output
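A minimal standalone sketch (plain PyTorch, made-up ranges) of the dict-to-padded-tensor conversion that mm_prefix_range_tensor performs, showing the (num_seqs, max_ranges, 2) shape with 0-padding for sequences that have fewer ranges:

import torch

mm_prefix_range = {0: [(4, 10)], 2: [(1, 3), (7, 9)]}  # seq 1 has no ranges
num_seqs = 3

range_lists = [mm_prefix_range.get(i, [(0, 0)]) or [(0, 0)] for i in range(num_seqs)]
range_tensors = [torch.tensor(r, dtype=torch.int32).view(-1, 2) for r in range_lists]
padded = torch.nested.nested_tensor(range_tensors).to_padded_tensor(0)
print(padded.shape)  # torch.Size([3, 2, 2]): (num_seqs, max_ranges, 2)
print(padded[1])     # all zeros -> skipped by the kernel's is_valid (start < end) check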