Compare commits

...

12 Commits

Author SHA1 Message Date
Isotr0py
55f1fc1b1b [v1] Add PrefixLM support to TritonAttention backend (#30386)
(cherry picked from commit 74a1ac38b0)
2025-12-17 19:57:52 -08:00
Varun Sundar Rabindranath
17f3988094 [BugFix] Workspace allocation during profile run : DeepEPHighThroughput + DeepGEMM (#30899)
(cherry picked from commit e3fc374a9a)
2025-12-17 19:57:33 -08:00
Nicolò Lucchesi
682c38583c [CI][Bugfix] Fix flaky tests/entrypoints/openai/test_audio.py::test_chat_streaming_audio (#30878)
Signed-off-by: NickLucche <nlucches@redhat.com>
(cherry picked from commit 9ca8cb38fd)
2025-12-17 19:57:15 -08:00
Yan Ma
f124b56786 [XPU] fix broken fp8 online quantization for XPU platform (#30831)
Signed-off-by: Yan Ma <yan.ma@intel.com>
(cherry picked from commit 4f735babb7)
2025-12-17 00:30:39 -08:00
Li, Jiang
d78e128b8b [Bugfix][CPU] Fix CPU backend ROPE dispatch for VL models (#30829)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
Signed-off-by: Li, Jiang <bigpyj64@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
(cherry picked from commit 0cd5353644)
2025-12-17 00:18:02 -08:00
Lucas Wilkinson
761b730dcb [BugFix] Fix memory spike in workspace allocation (#30744)
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
(cherry picked from commit 00a8d7628c)
2025-12-17 00:17:31 -08:00
TJian
f34eca5f01 [ROCm] [Bugfix] Fix torch sdpa hallucination (#30789)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
(cherry picked from commit 2410132bb1)
2025-12-16 17:16:25 -08:00
Wentao Ye
4cd332f3cf [CI] Skip ci failure test (#30804)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
(cherry picked from commit b6ec077e05)
2025-12-16 17:16:08 -08:00
Roger Wang
16484d394c [Core][MM] Optimize encoder cache manager by operating with embeddings only (#30475)
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Sun Kim <sunytokki@gmail.com>
(cherry picked from commit f5f51e5931)
2025-12-16 17:15:49 -08:00
Isotr0py
e397bd6592 [CI/Build] Skip broken ViT backend functionality test tempoarily (#30782)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
(cherry picked from commit 4de08ad698)
2025-12-16 17:15:26 -08:00
Isotr0py
6a88d590bb [Bugfix] Fix broken ViT attention selection for Blackwell device (#30731)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
(cherry picked from commit e94384bbad)
2025-12-16 17:13:54 -08:00
Shanshan Shen
ad8c073131 [CustomOp] Extract ApplyRotaryEmb as CustomOp and unify the dispatch logic (#29873)
Signed-off-by: shen-shanshan <467638484@qq.com>
Co-authored-by: gcanlin <canlinguosdu@gmail.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
(cherry picked from commit 3bd9c49158)
2025-12-16 17:13:23 -08:00
41 changed files with 1217 additions and 547 deletions

View File

@@ -1223,6 +1223,8 @@ steps:
# FIXIT: find out which code initialize cuda before running the test
# before the fix, we need to use spawn to test it
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
# A lot of these tests are on the edge of OOMing
- export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# There is some Tensor Parallelism related processing logic in LoRA that
# requires multi-GPU testing for validation.
- pytest -v -s -x lora/test_chatglm3_tp.py

View File

@@ -523,6 +523,8 @@ CUSTOM_OPS_QUANT_RMS_NORM = ["+quant_fp8,+rms_norm"]
list[tuple[Any, ...]](flat_product(MODELS_GROUP_FP8, CUSTOM_OPS_QUANT_RMS_NORM)),
)
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
# TODO: remove skip after we fix the fusion thoroughly
@pytest.mark.skipif(is_blackwell(), reason="Temporarily disabled on Blackwell")
def test_rms_group_quant(
model_name: str,
model_kwargs: dict[str, Any],
@@ -562,7 +564,7 @@ def test_rms_group_quant(
splitting_ops=splitting_ops,
# Common
mode=CompilationMode.VLLM_COMPILE,
pass_config=PassConfig(eliminate_noops=True, enable_fusion=True),
pass_config=PassConfig(eliminate_noops=True, fuse_norm_quant=True),
# Inductor caches custom passes by default as well via uuid
inductor_compile_config={"force_disable_caches": True},
)

View File

@@ -254,7 +254,9 @@ async def test_single_chat_session_input_audio(
async def test_chat_streaming_audio(
client: openai.AsyncOpenAI, model_name: str, audio_url: str
):
messages = dummy_messages_from_audio_url(audio_url)
messages = dummy_messages_from_audio_url(
audio_url, "What's a short title for this audio?"
)
# test single completion
chat_completion = await client.chat.completions.create(

View File

@@ -0,0 +1,203 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for ApplyRotaryEmb CustomOp dispatch behavior.
This test ensures that RotaryEmbedding classes correctly call the appropriate
ApplyRotaryEmb methods based on the calling context:
1. RotaryEmbedding.forward_native() -> ApplyRotaryEmb.forward_native()
2. RotaryEmbedding.forward_cuda() -> ApplyRotaryEmb.forward() (auto-dispatch)
3. RotaryEmbedding.forward_hip() -> ApplyRotaryEmb.forward() (auto-dispatch)
"""
from dataclasses import dataclass
import pytest
import torch
from vllm.config import (
CompilationConfig,
VllmConfig,
get_cached_compilation_config,
set_current_vllm_config,
)
from vllm.platforms import current_platform
CUDA_DEVICES = ["cuda:0"]
@dataclass
class RotaryEmbeddingTestCase:
    """Test case configuration for RotaryEmbedding dispatch tests.

    Each case names a RotaryEmbedding subclass, the constructor kwargs used
    to build it, the rope method to invoke, the shape of the positions
    tensor to feed it, and which ApplyRotaryEmb method(s) that invocation
    is expected to reach.
    """

    # Human-readable id (also used as the pytest parametrize id).
    name: str
    # RotaryEmbedding subclass under test.
    rope_class: type
    # Constructor kwargs passed to rope_class.
    rope_kwargs: dict
    method_name: str  # forward_native, forward_cuda, forward
    positions_shape: tuple  # (num_tokens,) or (3, num_tokens) or (4, num_tokens)
    expect_forward_native: bool  # Should call ApplyRotaryEmb.forward_native()
    expect_forward: bool  # Should call ApplyRotaryEmb.forward()
def get_test_cases() -> list[RotaryEmbeddingTestCase]:
    """Build the dispatch test matrix for the RotaryEmbedding variants."""
    # Imported lazily so collecting this module does not pull in the model
    # layers on platforms where they are unavailable.
    from vllm.model_executor.layers.rotary_embedding.ernie45_vl_rope import (
        Ernie4_5_VLRotaryEmbedding,
    )
    from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding
    from vllm.model_executor.layers.rotary_embedding.xdrope import XDRotaryEmbedding

    base_kwargs = {
        "head_size": 128,
        "rotary_dim": 128,
        "max_position_embeddings": 4096,
        "base": 10000,
        "is_neox_style": True,
        "dtype": torch.bfloat16,
    }

    cases: list[RotaryEmbeddingTestCase] = []

    # MRotaryEmbedding, native path: 2D (multimodal) positions.
    cases.append(
        RotaryEmbeddingTestCase(
            name="MRotaryEmbedding.forward_native",
            rope_class=MRotaryEmbedding,
            rope_kwargs={**base_kwargs, "mrope_section": [16, 24, 24]},
            method_name="forward_native",
            positions_shape=(3, 32),
            expect_forward_native=True,
            expect_forward=False,
        )
    )
    # MRotaryEmbedding, CUDA path: 1D positions trigger apply_rotary_emb.
    cases.append(
        RotaryEmbeddingTestCase(
            name="MRotaryEmbedding.forward_cuda_1d",
            rope_class=MRotaryEmbedding,
            rope_kwargs={**base_kwargs, "mrope_section": [16, 24, 24]},
            method_name="forward_cuda",
            positions_shape=(32,),
            expect_forward_native=False,
            expect_forward=True,
        )
    )
    # XDRotaryEmbedding: generic forward with 4D P/W/H/T positions.
    cases.append(
        RotaryEmbeddingTestCase(
            name="XDRotaryEmbedding.forward",
            rope_class=XDRotaryEmbedding,
            rope_kwargs={
                **base_kwargs,
                "scaling_alpha": 1.0,
                "xdrope_section": [16, 16, 16, 16],
            },
            method_name="forward",
            positions_shape=(4, 32),
            expect_forward_native=False,
            expect_forward=True,
        )
    )
    # Ernie4_5_VLRotaryEmbedding, native path: 2D (multimodal) positions.
    cases.append(
        RotaryEmbeddingTestCase(
            name="Ernie4_5_VLRotaryEmbedding.forward_native",
            rope_class=Ernie4_5_VLRotaryEmbedding,
            rope_kwargs={**base_kwargs, "mrope_section": [22, 22, 20]},
            method_name="forward_native",
            positions_shape=(3, 32),
            expect_forward_native=True,
            expect_forward=False,
        )
    )
    return cases
def run_dispatch_test(
    test_case: RotaryEmbeddingTestCase,
    device: str,
):
    """Run a dispatch test for a RotaryEmbedding class.

    Builds the rope layer under a VllmConfig that force-enables the
    +apply_rotary_emb custom op, wraps ApplyRotaryEmb.forward_native /
    .forward with call trackers, invokes the rope method named by the
    test case, and asserts that exactly the expected ApplyRotaryEmb
    entry points were hit.
    """
    vllm_config = VllmConfig(
        compilation_config=CompilationConfig(custom_ops=["all", "+apply_rotary_emb"])
    )
    # Drop any cached compilation config so the custom-op enablement above
    # is actually picked up when the layer is constructed.
    get_cached_compilation_config.cache_clear()
    with set_current_vllm_config(vllm_config):
        rope = test_case.rope_class(**test_case.rope_kwargs).to(device=device)
        apply_rotary_emb = rope.apply_rotary_emb
        # Verify custom op is enabled: with the op on, the dispatched
        # _forward_method must differ from the plain native implementation.
        if test_case.expect_forward_native:
            assert (
                apply_rotary_emb._forward_method != apply_rotary_emb.forward_native
            ), "Test setup error: ApplyRotaryEmb custom op should be enabled"
        # Setup call tracking by monkeypatching both entry points with
        # recording wrappers that delegate to the originals.
        call_tracker = {"forward_native_called": False, "forward_called": False}
        original_forward_native = apply_rotary_emb.forward_native
        original_forward = apply_rotary_emb.forward

        def tracked_forward_native(*args, **kwargs):
            call_tracker["forward_native_called"] = True
            return original_forward_native(*args, **kwargs)

        def tracked_forward(*args, **kwargs):
            call_tracker["forward_called"] = True
            return original_forward(*args, **kwargs)

        apply_rotary_emb.forward_native = tracked_forward_native
        apply_rotary_emb.forward = tracked_forward
        try:
            # Last dim of positions_shape is the token count (1D or N-D).
            num_tokens = test_case.positions_shape[-1]
            num_q_heads = 8
            num_kv_heads = 2
            head_size = test_case.rope_kwargs["head_size"]
            max_position = test_case.rope_kwargs["max_position_embeddings"]
            positions = torch.randint(
                0, max_position // 4, test_case.positions_shape, device=device
            )
            query = torch.randn(
                num_tokens, num_q_heads * head_size, dtype=torch.bfloat16, device=device
            )
            key = torch.randn(
                num_tokens,
                num_kv_heads * head_size,
                dtype=torch.bfloat16,
                device=device,
            )
            # Call the method under test; clone q/k since rope may mutate
            # its inputs in place.
            method = getattr(rope, test_case.method_name)
            method(positions, query.clone(), key.clone())
            # Verify expectations recorded by the trackers.
            if test_case.expect_forward_native:
                assert call_tracker["forward_native_called"], (
                    f"{test_case.name} should call ApplyRotaryEmb.forward_native()"
                )
            if not test_case.expect_forward:
                assert not call_tracker["forward_called"], (
                    f"{test_case.name} should NOT call ApplyRotaryEmb.forward(). "
                    "Bug: when +apply_rotary_emb is enabled, forward_native() "
                    "incorrectly dispatches to CUDA/HIP kernels."
                )
            if test_case.expect_forward:
                assert call_tracker["forward_called"], (
                    f"{test_case.name} should call ApplyRotaryEmb.forward()"
                )
        finally:
            # Always restore the original methods so one test case cannot
            # leak its trackers into the next.
            apply_rotary_emb.forward_native = original_forward_native
            apply_rotary_emb.forward = original_forward
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests."
)
@pytest.mark.parametrize("test_case", get_test_cases(), ids=lambda tc: tc.name)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_rotary_embedding_dispatch(
    test_case: RotaryEmbeddingTestCase,
    device: str,
):
    """
    Test that RotaryEmbedding classes dispatch to the correct ApplyRotaryEmb method.

    - forward_native methods should call ApplyRotaryEmb.forward_native()
    - forward_cuda/forward methods should call ApplyRotaryEmb.forward()
    """
    # Thin wrapper: all setup/tracking/assertion logic lives in
    # run_dispatch_test so it can also be invoked directly when debugging.
    run_dispatch_test(test_case, device)

View File

@@ -1,17 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Literal, NamedTuple
import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"
from typing import Any, NamedTuple
import pytest
from huggingface_hub import hf_hub_download
from pytest import MarkDecorator
from transformers import AutoModelForImageTextToText
from tests.quantization.utils import is_quant_method_supported
from vllm.assets.image import ImageAsset
from vllm.multimodal.image import rescale_image_size
from vllm.utils.torch_utils import set_default_torch_num_threads
from ....conftest import PromptImageInput, VllmRunner
from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner
from ...utils import check_logprobs_close
@@ -21,9 +27,10 @@ class GGUFMMTestConfig(NamedTuple):
gguf_backbone: str
gguf_mmproj: str
prompt: list[str]
mm_data: dict[Literal["images"], PromptImageInput]
image_names: list[str] # Store names, load PIL images at runtime
max_model_len: int = 4096
marks: list[MarkDecorator] = []
mm_processor_kwargs: dict[str, Any] = {}
@property
def gguf_model(self):
@@ -31,27 +38,75 @@ class GGUFMMTestConfig(NamedTuple):
return hf_hub_download(self.gguf_repo, filename=self.gguf_backbone)
# Common prompts aligned with test_common.py "gemma3" entry format
_GEMMA3_PROMPTS = IMAGE_ASSETS.prompts(
{
"stop_sign": (
"<bos><start_of_turn>user\n"
"<start_of_image>What's the content in the center of the image?"
"<end_of_turn>\n<start_of_turn>model\n"
),
"cherry_blossom": (
"<bos><start_of_turn>user\n"
"<start_of_image>What is the season?"
"<end_of_turn>\n<start_of_turn>model\n"
),
}
)
# Image asset names - load at runtime to avoid pickle issues with subprocess
_GEMMA3_IMAGE_NAMES = ["stop_sign", "cherry_blossom"]
# Regular multimodal (no pan-and-scan) - uses QAT Q4_0 GGUF
GEMMA3_CONFIG = GGUFMMTestConfig(
original_model="google/gemma-3-4b-it",
gguf_repo="google/gemma-3-4b-it-qat-q4_0-gguf",
gguf_backbone="gemma-3-4b-it-q4_0.gguf",
gguf_mmproj="mmproj-model-f16-4B.gguf",
prompt=["<start_of_image>Describe this image in detail:"],
mm_data={"images": [ImageAsset("stop_sign").pil_image]},
prompt=_GEMMA3_PROMPTS,
image_names=_GEMMA3_IMAGE_NAMES,
max_model_len=4096,
marks=[pytest.mark.core_model],
mm_processor_kwargs={},
)
MODELS_TO_TEST = [GEMMA3_CONFIG]
# Pan-and-scan multimodal - uses unquantized BF16 GGUF
GEMMA3_CONFIG_PAN_AND_SCAN = GGUFMMTestConfig(
original_model="google/gemma-3-4b-it",
gguf_repo="unsloth/gemma-3-4b-it-GGUF",
gguf_backbone="gemma-3-4b-it-BF16.gguf",
gguf_mmproj="mmproj-BF16.gguf",
prompt=_GEMMA3_PROMPTS,
image_names=_GEMMA3_IMAGE_NAMES,
max_model_len=4096,
marks=[pytest.mark.core_model],
mm_processor_kwargs={"do_pan_and_scan": True},
)
MODELS_TO_TEST = [GEMMA3_CONFIG, GEMMA3_CONFIG_PAN_AND_SCAN]
def run_multimodal_gguf_test(
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
model: GGUFMMTestConfig,
dtype: str,
max_tokens: int,
num_logprobs: int,
):
# Run gguf model.
# Load images at runtime (inside subprocess) to avoid pickle issues
images = [ImageAsset(name).pil_image for name in model.image_names]
size_factors = [0.25, 0.5, 1.0]
inputs_per_image = [
(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
)
for image, prompt in zip(images, model.prompt)
]
# NOTE: Run vLLM first to avoid CUDA init issues with multiprocessing fork.
# Run GGUF model via vLLM.
with (
set_default_torch_num_threads(1),
vllm_runner(
@@ -60,35 +115,42 @@ def run_multimodal_gguf_test(
tokenizer_name=model.original_model,
dtype=dtype,
max_model_len=model.max_model_len,
mm_processor_kwargs=model.mm_processor_kwargs,
) as gguf_model,
):
gguf_outputs = gguf_model.generate_greedy_logprobs(
prompts=model.prompt,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
**model.mm_data,
)
gguf_outputs_per_case = [
gguf_model.generate_greedy_logprobs(
prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images,
)
for prompts, images in inputs_per_image
]
# Run unquantized model.
with vllm_runner(
model_name=model.original_model,
enforce_eager=True, # faster tests
# Then run HfRunner for HuggingFace baseline comparison.
with hf_runner(
model.original_model,
dtype=dtype,
max_model_len=model.max_model_len,
) as original_model:
original_outputs = original_model.generate_greedy_logprobs(
prompts=model.prompt,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
**model.mm_data,
)
auto_cls=AutoModelForImageTextToText,
) as hf_model:
hf_outputs_per_case = [
hf_model.generate_greedy_logprobs_limit(
prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images,
)
for prompts, images in inputs_per_image
]
check_logprobs_close(
outputs_0_lst=original_outputs,
outputs_1_lst=gguf_outputs,
name_0="original",
name_1="gguf",
)
for hf_outputs, gguf_outputs in zip(hf_outputs_per_case, gguf_outputs_per_case):
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=gguf_outputs,
name_0="hf",
name_1="gguf",
)
@pytest.mark.skipif(
@@ -105,11 +167,14 @@ def run_multimodal_gguf_test(
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
def test_models(
def test_gemma3_mm_gguf(
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
model: GGUFMMTestConfig,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
run_multimodal_gguf_test(vllm_runner, model, dtype, max_tokens, num_logprobs)
run_multimodal_gguf_test(
hf_runner, vllm_runner, model, dtype, max_tokens, num_logprobs
)

View File

@@ -388,6 +388,7 @@ def run_video_test(config, mm_encoder_attn_backend, video_assets, vllm_runner):
"mm_encoder_attn_backend",
[None] + current_platform.get_supported_vit_attn_backends(),
)
@pytest.mark.skip(reason="Broken test due to memory segmentation fault")
@create_new_process_for_each_test()
def test_vit_backend_functionality(
model_key: str,

View File

@@ -60,12 +60,12 @@ def test_profiling(model_id: str, max_model_len: int):
total_num_patches.item() + num_tiles.item() + 3
) # image start, image, image end
profiled_tokens = profiler.get_mm_max_contiguous_tokens(
profiled_tokens = profiler.get_mm_max_tokens(
max_model_len,
mm_counts=mm_counts,
)
assert total_tokens == profiled_tokens["image"]
assert total_num_patches == profiled_tokens["image"]
assert total_tokens == sum(
placeholder.length
for placeholder in decoder_dummy_data.multi_modal_placeholders["image"]

View File

@@ -9,6 +9,7 @@ from tempfile import NamedTemporaryFile, TemporaryDirectory
import numpy as np
import pytest
import torch
from PIL import Image, ImageChops
from vllm.multimodal.image import convert_image_mode
@@ -410,6 +411,97 @@ def test_argsort_mm_positions(case):
assert modality_idxs == expected_modality_idxs
@pytest.mark.parametrize(
    "is_embed,expected",
    [
        (None, 5),
        (torch.tensor([True, True, True, True, True]), 5),
        (torch.tensor([False, False, False, False, False]), 0),
        (torch.tensor([True, False, True, False, True]), 3),
        (torch.tensor([True]), 1),
    ],
)
def test_placeholder_range_get_num_embeds(is_embed, expected):
    """get_num_embeds counts True mask entries; no mask counts every slot."""
    if is_embed is None:
        length = 5
    else:
        length = len(is_embed)
    placeholder = PlaceholderRange(offset=0, length=length, is_embed=is_embed)
    assert placeholder.get_num_embeds == expected
@pytest.mark.parametrize(
    "is_embed,expected",
    [
        (None, None),
        (
            torch.tensor([False, True, False, True, True]),
            torch.tensor([0, 1, 1, 2, 3]),
        ),
        (torch.tensor([True, True, True]), torch.tensor([1, 2, 3])),
    ],
)
def test_placeholder_range_embeds_cumsum(is_embed, expected):
    """embeds_cumsum is a running count of embedded slots, None without a mask."""
    length = 5 if is_embed is None else len(is_embed)
    placeholder = PlaceholderRange(offset=0, length=length, is_embed=is_embed)
    if expected is None:
        assert placeholder.embeds_cumsum is None
    else:
        assert torch.equal(placeholder.embeds_cumsum, expected)
        # cached_property: repeated access must hand back the same object.
        assert placeholder.embeds_cumsum is placeholder.embeds_cumsum
@pytest.mark.parametrize(
    "is_embed,start_idx,end_idx,expected",
    [
        (None, 2, 4, (2, 4)),
        (
            torch.tensor([False, True, False, True, True]),
            3,
            5,
            (1, 3),
        ),
        (
            torch.tensor([False, True, False, True, True]),
            0,
            2,
            (0, 1),
        ),
        (
            torch.tensor([True, False, True, False]),
            2,
            2,
            (1, 1),
        ),
    ],
)
def test_placeholder_range_get_embeds_indices_in_range(
    is_embed, start_idx, end_idx, expected
):
    """Token-index windows map to embed-index windows via the is_embed mask."""
    if is_embed is None:
        placeholder = PlaceholderRange(offset=0, length=5, is_embed=None)
    else:
        placeholder = PlaceholderRange(
            offset=0, length=len(is_embed), is_embed=is_embed
        )
    assert placeholder.get_embeds_indices_in_range(start_idx, end_idx) == expected
@pytest.mark.parametrize(
    "offset,is_embed,expected",
    [
        (0, None, [(0, 4)]),
        (
            2,
            torch.tensor([False, True, False, True, True]),
            [(3, 3), (5, 6)],
        ),
        (0, torch.tensor([True, True, True, True]), [(0, 3)]),
        (0, torch.tensor([False, False, False, False]), []),
    ],
)
def test_placeholder_range_extract_embeds_range(offset, is_embed, expected):
    """extract_embeds_range yields absolute (start, end) runs of embedded slots."""
    mask_length = len(is_embed) if is_embed is not None else 5
    placeholder = PlaceholderRange(offset=offset, length=mask_length, is_embed=is_embed)
    ranges = placeholder.extract_embeds_range()
    assert ranges == expected
@pytest.mark.asyncio
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
@@ -23,7 +24,7 @@ class MockRequest:
)
self.mm_features.append(feature)
def get_num_encoder_tokens(self, input_id: int) -> int:
def get_num_encoder_embeds(self, input_id: int) -> int:
return self._token_counts[input_id]
@@ -162,8 +163,8 @@ def test_schedule_request_multi_images_respect_space_limit():
num_tokens_to_schedule = 0
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
num_tokens_to_schedule += req.get_num_encoder_tokens(0)
compute_budget -= req.get_num_encoder_tokens(0)
num_tokens_to_schedule += req.get_num_encoder_embeds(0)
compute_budget -= req.get_num_encoder_embeds(0)
assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule)
@@ -174,7 +175,75 @@ def test_schedule_request_multi_images_respect_compute_limit():
compute_budget = 10
num_tokens_to_schedule = 0
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
num_tokens_to_schedule += req.get_num_encoder_tokens(0)
compute_budget -= req.get_num_encoder_tokens(0)
num_tokens_to_schedule += req.get_num_encoder_embeds(0)
compute_budget -= req.get_num_encoder_embeds(0)
assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule)
def test_encoder_cache_with_is_embed_mask():
    """Cache allocation should charge only embedded slots, not total length.

    A 100-token placeholder with only 8 True mask entries must consume 8
    cache slots (12.5x saving over length-based accounting).
    """

    class MockRequestWithMask(MockRequest):
        # Report embed count from the placeholder mask instead of the
        # fixed per-input token counts used by the base MockRequest.
        def get_num_encoder_embeds(self, input_id: int) -> int:
            return self.mm_features[input_id].mm_position.get_num_embeds

    # 100-slot mask with exactly 8 embedded positions.
    is_embed = torch.zeros(100, dtype=torch.bool)
    is_embed[torch.tensor([5, 15, 25, 35, 45, 55, 65, 75])] = True
    request = MockRequestWithMask("r1", ["img1"], [100])
    request.mm_features[0] = MultiModalFeatureSpec(
        data=None,
        modality="image",
        identifier="img1",
        mm_position=PlaceholderRange(offset=0, length=100, is_embed=is_embed),
    )
    manager = EncoderCacheManager(cache_size=100)
    manager.allocate(request, 0)
    # Only the 8 embedded slots are charged: 100 - 8 = 92 remain free.
    assert manager.num_free_slots == 92
    assert "img1" in manager.cached
    old_size = 100
    new_size = request.mm_features[0].mm_position.get_num_embeds
    assert new_size == 8
    # Embedding-based accounting saves 12.5x versus length-based here.
    savings_ratio = old_size / new_size
    assert savings_ratio == 12.5
def test_encoder_cache_mask_based_retrieval():
    """Sanity-check mask arithmetic used for range-based embed retrieval."""

    class MockRequestWithMask(MockRequest):
        # Report embed count from the placeholder mask instead of the
        # fixed per-input token counts used by the base MockRequest.
        def get_num_encoder_embeds(self, input_id: int) -> int:
            return self.mm_features[input_id].mm_position.get_num_embeds

    # 10-slot mask with 5 embedded positions (indices 2, 3, 5, 6, 7).
    is_embed = torch.tensor(
        [False, False, True, True, False, True, True, True, False, False]
    )
    request = MockRequestWithMask("r1", ["img1"], [10])
    request.mm_features[0] = MultiModalFeatureSpec(
        data=None,
        modality="image",
        identifier="img1",
        mm_position=PlaceholderRange(offset=0, length=10, is_embed=is_embed),
    )
    manager = EncoderCacheManager(cache_size=50)
    manager.allocate(request, 0)
    assert request.mm_features[0].mm_position.get_num_embeds == 5
    # Window [2, 8): no embeds precede it; all 5 embeds fall inside it.
    start_idx = 2
    end_idx = 8
    num_embeds_before = is_embed[:start_idx].sum().item()
    num_embeds_in_range = is_embed[start_idx:end_idx].sum().item()
    assert num_embeds_before == 0
    assert num_embeds_in_range == 5
    # Window [0, 5): only the embeds at indices 2 and 3 fall inside it.
    start_idx = 0
    end_idx = 5
    num_embeds_before = is_embed[:start_idx].sum().item() if start_idx > 0 else 0
    num_embeds_in_range = is_embed[start_idx:end_idx].sum().item()
    assert num_embeds_before == 0
    assert num_embeds_in_range == 2

View File

@@ -38,7 +38,7 @@ class MockRequest:
)
self.mm_features.append(feature)
def get_num_encoder_tokens(self, input_id: int) -> int:
def get_num_encoder_embeds(self, input_id: int) -> int:
assert input_id < len(self._token_counts)
return self._token_counts[input_id]

View File

@@ -86,6 +86,9 @@ def kernel_unified_attention_2d(
USE_SOFTCAP: tl.constexpr, # bool
USE_SINKS: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int
USE_MM_PREFIX: tl.constexpr, # bool
MAX_MM_RANGES: tl.constexpr, # int
mm_prefix_range_ptr, # [num_seqs] - prefix length for each sequence
stride_k_cache_0: tl.int64, # int
stride_k_cache_1: tl.int64, # int
stride_k_cache_2: tl.int64, # int
@@ -270,7 +273,38 @@ def kernel_unified_attention_2d(
else:
V = V_load
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
# Compute attention mask: causal by default (key <= query)
query_abs_pos = context_len + query_pos[:, None]
seq_mask = seq_offset[None, :] <= query_abs_pos
# Apply sliding window to base mask BEFORE mm_prefix OR.
# Order must match FlexAttention: (causal AND sliding_window) OR mm_prefix
if SLIDING_WINDOW > 0:
seq_mask = seq_mask & ((query_abs_pos - seq_offset) < SLIDING_WINDOW)
# PrefixLM: extend mask with bidirectional ranges for multimodal tokens.
# Applied AFTER sliding window so mm_prefix ranges override SW restriction.
if USE_MM_PREFIX:
for i in range(MAX_MM_RANGES):
range_start = tl.load(
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2
)
range_end = tl.load(
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2 + 1
)
is_valid = range_start < range_end
q_in_range = (
(query_abs_pos >= range_start)
& (query_abs_pos <= range_end)
& is_valid
)
k_in_range = (
(seq_offset[None, :] >= range_start)
& (seq_offset[None, :] <= range_end)
& is_valid
)
seq_mask |= q_in_range & k_in_range
# S : (BLOCK_M, TILE_SIZE)
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
@@ -284,13 +318,6 @@ def kernel_unified_attention_2d(
query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
)
if SLIDING_WINDOW > 0:
S = tl.where(
(context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW,
S,
float("-inf"),
)
if USE_ALIBI_SLOPES:
S += alibi_slope[:, None] * (seq_offset - context_len)
@@ -398,6 +425,9 @@ def kernel_unified_attention_3d(
num_seqs: tl.int32,
BLOCK_M: tl.constexpr, # int
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int
USE_MM_PREFIX: tl.constexpr, # bool
MAX_MM_RANGES: tl.constexpr, # int
mm_prefix_range_ptr, # [num_seqs] - prefix length for each sequence
):
q_block_global_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)
@@ -559,7 +589,38 @@ def kernel_unified_attention_3d(
else:
V = V_load
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
# Compute attention mask: causal by default (key <= query)
query_abs_pos = context_len + query_pos[:, None]
seq_mask = seq_offset[None, :] <= query_abs_pos
# Apply sliding window to base mask BEFORE mm_prefix OR.
# Order must match FlexAttention: (causal AND sliding_window) OR mm_prefix
if SLIDING_WINDOW > 0:
seq_mask = seq_mask & ((query_abs_pos - seq_offset) < SLIDING_WINDOW)
# PrefixLM: extend mask with bidirectional ranges for multimodal tokens.
# Applied AFTER sliding window so mm_prefix ranges override SW restriction.
if USE_MM_PREFIX:
for i in range(MAX_MM_RANGES):
range_start = tl.load(
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2
)
range_end = tl.load(
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2 + 1
)
is_valid = range_start < range_end
q_in_range = (
(query_abs_pos >= range_start)
& (query_abs_pos <= range_end)
& is_valid
)
k_in_range = (
(seq_offset[None, :] >= range_start)
& (seq_offset[None, :] <= range_end)
& is_valid
)
seq_mask |= q_in_range & k_in_range
# S : (BLOCK_M, TILE_SIZE)
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
@@ -572,13 +633,6 @@ def kernel_unified_attention_3d(
query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
)
if SLIDING_WINDOW > 0:
S = tl.where(
(context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW,
S,
float("-inf"),
)
if USE_ALIBI_SLOPES:
S += alibi_slope[:, None] * (seq_offset - context_len)
@@ -732,6 +786,43 @@ def reduce_segments(
tl.store(output_ptr + output_offset, acc, mask=dim_mask)
def _is_gemma3_attention(head_size: int, sliding_window: int) -> bool:
"""Detect Gemma3 models via unique (head_size, sliding_window) signature.
Gemma3 models are the only ones using sliding_window=1024 with
head_size 128 (27B) or 256 (1B, 4B, 12B). Other SWA models use
different window sizes (Mistral=4096, Phi-3=2047).
"""
return sliding_window == 1024 and head_size in (128, 256)
def _get_tile_size(
    head_size: int,
    sliding_window: int,
    element_size: int,
    is_mm_prefix: bool,
    is_prefill: bool,
) -> int:
    """Pick the KV tile size for the unified attention kernels.

    Multimodal bidirectional (PrefixLM) attention gets the widest tile
    (64). Gemma3-shaped layers use 32 for both prefill and decode to
    better utilize their larger head dimension (128/256). Everything
    else follows the stock vLLM policy: 32 for prefill; for decode,
    16 for >=2-byte dtypes and 32 for fp8 (element_size == 1).
    """
    if is_mm_prefix:
        # Bidirectional multimodal ranges need a larger tile size.
        return 64
    if _is_gemma3_attention(head_size, sliding_window):
        # Gemma3 decode default would otherwise be 16.
        return 32
    if not is_prefill and element_size >= 2:
        return 16
    return 32
def unified_attention(
q,
k,
@@ -759,6 +850,8 @@ def unified_attention(
qq_bias=None,
# Optional tensor for sinks
sinks=None,
# Optional tensor for prefix lengths (PrefixLM support)
mm_prefix_range=None,
):
assert causal, "Only causal attention is supported"
assert q_descale is None, "Q scales not supported"
@@ -766,6 +859,17 @@ def unified_attention(
if sinks is not None:
assert sinks.shape[0] == q.shape[1], "Sinks must be num_query_heads size"
use_mm_prefix = False
max_mm_ranges = 0
if mm_prefix_range is not None:
if mm_prefix_range.ndim == 3:
use_mm_prefix = True
max_mm_ranges = mm_prefix_range.shape[1]
else:
raise ValueError(
f"Unsupported mm_prefix_range shape: {mm_prefix_range.shape}"
)
use_alibi_slopes = alibi_slopes is not None
use_qq_bias = qq_bias is not None
@@ -792,11 +896,23 @@ def unified_attention(
# = floor(q.shape[0] / BLOCK_Q) + num_seqs
total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
# Assigning default tile sizes for prefill and decode.
# Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1)
# and at least 16 for all other data types.
TILE_SIZE_PREFILL = 32
TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32
# Tile sizes for prefill and decode. Gemma3 models use optimized values.
# Note: tile size must be at least 32 for fp8 (element_size == 1).
sliding_window_val = 1 + window_size[0] if window_size[0] >= 0 else 0
TILE_SIZE_PREFILL = _get_tile_size(
head_size,
sliding_window_val,
q.element_size(),
is_mm_prefix=use_mm_prefix,
is_prefill=True,
)
TILE_SIZE_DECODE = _get_tile_size(
head_size,
sliding_window_val,
q.element_size(),
is_mm_prefix=use_mm_prefix,
is_prefill=False,
)
# Launch the 2D kernel if
# 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or
@@ -847,6 +963,9 @@ def unified_attention(
USE_QQ_BIAS=use_qq_bias,
USE_SOFTCAP=(softcap > 0),
USE_SINKS=(sinks is not None),
USE_MM_PREFIX=use_mm_prefix,
MAX_MM_RANGES=max_mm_ranges,
mm_prefix_range_ptr=mm_prefix_range,
SLIDING_WINDOW=(1 + window_size[0]),
stride_k_cache_0=k.stride(0),
stride_k_cache_1=k.stride(1),
@@ -895,6 +1014,9 @@ def unified_attention(
USE_QQ_BIAS=use_qq_bias,
USE_SOFTCAP=(softcap > 0),
USE_SINKS=(sinks is not None),
USE_MM_PREFIX=use_mm_prefix,
MAX_MM_RANGES=max_mm_ranges,
mm_prefix_range_ptr=mm_prefix_range,
SLIDING_WINDOW=(1 + window_size[0]),
stride_k_cache_0=k.stride(0),
stride_k_cache_1=k.stride(1),

View File

@@ -16,6 +16,7 @@ import einops
import torch
import torch.nn.functional as F
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
@@ -89,6 +90,13 @@ def torch_sdpa_wrapper(
v: torch.Tensor,
cu_seqlens: torch.Tensor,
) -> torch.Tensor:
# Never remove the contiguous logic for ROCm
# Without it, hallucinations occur with the backend
if current_platform.is_rocm():
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
outputs = []
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()

View File

@@ -144,7 +144,7 @@ class ECExampleConnector(ECConnectorBase):
Update ECConnector state after encoder cache allocation.
"""
mm_hash = request.mm_features[index].identifier
num_encoder_token = request.get_num_encoder_tokens(index)
num_encoder_token = request.get_num_encoder_embeds(index)
# Insert mm_hash only if this block has not been recorded yet.
self._mm_datas_need_loads[mm_hash] = num_encoder_token

View File

@@ -795,7 +795,10 @@ class FusedMoEModularKernel(torch.nn.Module):
top_k,
global_num_experts,
local_num_experts,
expert_tokens_meta,
# expert_tokens_meta help in allocating optimal/minimal
# amount of workspace. Mark it None, so we allocate for
# the worst-case scenario.
expert_tokens_meta=None,
)
)

View File

@@ -27,6 +27,10 @@ from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8LinearMethod
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
maybe_create_device_identity,
)
from vllm.model_executor.parameter import ModelWeightParameter
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
@@ -305,6 +309,37 @@ class XPUFp8LinearMethod(Fp8LinearMethod):
def __init__(self, quant_config: Fp8Config):
    # Reuse the base Fp8LinearMethod initialization as-is; the
    # XPU-specific behavior lives in the overridden methods below.
    super().__init__(quant_config)
def create_weights(
    self,
    layer: torch.nn.Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    """Register the (not-yet-quantized) weight for an XPU FP8 linear layer.

    The weight is created in ``params_dtype``; quantization happens later
    in ``process_weights_after_loading``.
    """
    # Required by the shared w8a8 helpers used downstream.
    maybe_create_device_identity()

    out_features = sum(output_partition_sizes)
    loader = extra_weight_attrs.get("weight_loader")

    # Bookkeeping consumed by later processing/apply steps.
    layer.logical_widths = output_partition_sizes
    layer.input_size_per_partition = input_size_per_partition
    layer.output_size_per_partition = out_features
    layer.orig_dtype = params_dtype
    layer.weight_block_size = None

    layer.register_parameter(
        "weight",
        ModelWeightParameter(
            data=torch.empty(
                out_features,
                input_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=loader,
        ),
    )
def process_weights_after_loading(self, layer: Module) -> None:
# If checkpoint not serialized fp8, quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized:

View File

@@ -7,7 +7,7 @@ import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.custom_op import CustomOp
from .common import apply_rotary_emb_torch
from .common import ApplyRotaryEmb
@CustomOp.register("rotary_embedding")
@@ -49,6 +49,10 @@ class RotaryEmbeddingBase(CustomOp):
rocm_aiter_ops.is_triton_rotary_embed_enabled()
)
self.apply_rotary_emb = ApplyRotaryEmb(
is_neox_style=self.is_neox_style,
)
def _compute_inv_freq(self, base: float) -> torch.Tensor:
"""Compute the inverse frequency."""
# NOTE(woosuk): To exactly match the HF implementation, we need to
@@ -123,7 +127,12 @@ class RotaryEmbedding(RotaryEmbeddingBase):
query = query.view(num_tokens, -1, head_size)
query_rot = query[..., :rotary_dim]
query_pass = query[..., rotary_dim:]
query_rot = apply_rotary_emb_torch(query_rot, cos, sin, is_neox_style)
query_rot = ApplyRotaryEmb.forward_static(
query_rot,
cos,
sin,
is_neox_style,
)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
# key may be None in some cases, e.g. cross-layer KV sharing
@@ -132,7 +141,12 @@ class RotaryEmbedding(RotaryEmbeddingBase):
key = key.view(num_tokens, -1, head_size)
key_rot = key[..., :rotary_dim]
key_pass = key[..., rotary_dim:]
key_rot = apply_rotary_emb_torch(key_rot, cos, sin, is_neox_style)
key_rot = ApplyRotaryEmb.forward_static(
key_rot,
cos,
sin,
is_neox_style,
)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key

View File

@@ -2,19 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Callable
from functools import cache
from importlib.util import find_spec
import torch
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.model_executor.custom_op import CustomOp
from vllm.utils.torch_utils import direct_register_custom_op
if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
logger = init_logger(__name__)
@@ -32,71 +27,6 @@ def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
return x.flatten(-2)
def apply_rotary_emb_torch(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style: bool,
) -> torch.Tensor:
    """Pure-PyTorch rotary positional embedding.

    Args:
        x: [..., num_heads, head_size]; rotation acts on the last dim.
        cos/sin: [..., head_size // 2]; a heads axis is inserted so they
            broadcast across all heads.
        is_neox_style: half-split (Neox) vs interleaved (GPT-J) pairing.
    """
    # Insert a heads axis so cos/sin broadcast across the head dimension.
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    if is_neox_style:
        first_half, second_half = torch.chunk(x, 2, dim=-1)
        rot_a = first_half * cos - second_half * sin
        rot_b = second_half * cos + first_half * sin
        return torch.cat((rot_a, rot_b), dim=-1)
    evens = x[..., ::2]
    odds = x[..., 1::2]
    rot_a = evens * cos - odds * sin
    rot_b = odds * cos + evens * sin
    return torch.stack((rot_a, rot_b), dim=-1).flatten(-2)
def apply_rotary_emb_dispatch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, is_neox_style: bool
) -> torch.Tensor:
    """
    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
            positional embeddings.
    """
    if not current_platform.is_cuda():
        return apply_rotary_emb_torch(x, cos, sin, is_neox_style)
    # The fused CUDA kernel expects a leading batch dimension and an
    # `interleaved` flag, which is the inverse of neox style.
    return apply_rotary_emb(x.unsqueeze(0), cos, sin, not is_neox_style).squeeze(0)
@cache
def dispatch_rotary_emb_function(
    default: Callable[..., torch.Tensor] | None = None,
) -> Callable[..., torch.Tensor]:
    """Pick the rotary-embedding kernel for the current platform.

    CUDA uses vllm_flash_attn's fused kernel. ROCm prefers flash_attn's
    Triton kernel, but only outside torch.compile — under compilation the
    naive PyTorch implementation is faster. Everything else falls back to
    ``default`` (if given) or the pure-PyTorch implementation.
    """
    if current_platform.is_cuda():
        return apply_rotary_emb

    rocm_eager = current_platform.is_rocm() and not torch.compiler.is_compiling()
    if rocm_eager:
        if find_spec("flash_attn") is not None:
            from flash_attn.ops.triton.rotary import apply_rotary

            return apply_rotary
        logger.warning(
            "flash_attn is not installed. Falling back to PyTorch "
            "implementation for rotary embeddings."
        )

    return default if default is not None else apply_rotary_emb_torch
# yarn functions
# Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim(
@@ -186,3 +116,164 @@ direct_register_custom_op(
mutates_args=["query", "key"], # These tensors are modified in-place
fake_impl=_flashinfer_rotary_embedding_fake,
)
@CustomOp.register("apply_rotary_emb")
class ApplyRotaryEmb(CustomOp):
    """Applies rotary positional embeddings (RoPE) to a tensor.

    Platform-specific overrides dispatch to fused kernels where possible:
    ``forward_cuda`` uses vllm_flash_attn, ``forward_hip`` uses flash_attn's
    Triton kernel when installed, and ``forward_native``/``forward_cpu`` use
    the pure-PyTorch ``forward_static`` implementation.
    """

    def __init__(
        self,
        enforce_enable: bool = False,
        is_neox_style: bool = True,
        enable_fp32_compute: bool = False,
    ) -> None:
        """
        Args:
            enforce_enable: Forwarded to CustomOp to force-enable this op.
            is_neox_style: Neox-style (half-split) vs GPT-J-style
                (interleaved) rotation.
            enable_fp32_compute: Compute in FP32 and cast back to the input
                dtype for higher accuracy.
        """
        super().__init__(enforce_enable)
        self.is_neox_style = is_neox_style
        self.enable_fp32_compute = enable_fp32_compute
        # Optional Triton kernel from flash_attn, used by forward_hip.
        self.apply_rotary_emb_flash_attn = None
        if find_spec("flash_attn") is not None:
            from flash_attn.ops.triton.rotary import apply_rotary

            self.apply_rotary_emb_flash_attn = apply_rotary

    @staticmethod
    def forward_static(
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        is_neox_style: bool = True,
        enable_fp32_compute: bool = False,
    ) -> torch.Tensor:
        """Pure-PyTorch RoPE implementation.

        Args:
            x: [batch_size (optional), seq_len, num_heads, head_size]
            cos: [seq_len, head_size // 2]
            sin: [seq_len, head_size // 2]
            is_neox_style: Whether to use the Neox-style or GPT-J-style.
            enable_fp32_compute: Temporarily convert x, cos, sin to FP32 dtype
                for higher accuracy.
        """
        origin_dtype = x.dtype
        if enable_fp32_compute:
            x = x.float()
        # Broadcast cos/sin over the heads dimension.
        cos = cos.unsqueeze(-2).to(x.dtype)
        sin = sin.unsqueeze(-2).to(x.dtype)
        if is_neox_style:
            x1, x2 = torch.chunk(x, 2, dim=-1)
        else:
            x1 = x[..., ::2]
            x2 = x[..., 1::2]
        o1 = x1 * cos - x2 * sin
        o2 = x2 * cos + x1 * sin
        if is_neox_style:
            output = torch.cat((o1, o2), dim=-1)
        else:
            output = torch.stack((o1, o2), dim=-1).flatten(-2)
        if enable_fp32_compute:
            output = output.to(origin_dtype)
        return output

    def forward_native(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """Default (pure-PyTorch) path."""
        output = self.forward_static(
            x, cos, sin, self.is_neox_style, self.enable_fp32_compute
        )
        return output

    def forward_cuda(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """CUDA path using the fused vllm_flash_attn rotary kernel."""
        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb

        origin_dtype = x.dtype
        if self.enable_fp32_compute:
            x = x.float()
            cos = cos.float()
            sin = sin.float()
        origin_shape = x.shape
        if len(origin_shape) == 3:
            # x: [seq_len, num_heads, head_size] -> add a batch dimension,
            # since the kernel expects [batch, seq_len, nheads, headdim].
            x = x.unsqueeze(0)
        # Arguments of apply_rotary_emb() in vllm_flash_attn:
        #   x: [batch_size, seq_len, nheads, headdim]
        #   cos, sin: [seqlen_rotary, rotary_dim / 2]
        #   interleaved: default is False (Neox-style).
        interleaved = not self.is_neox_style
        output = apply_rotary_emb(x, cos, sin, interleaved)
        if len(origin_shape) == 3:
            output = output.squeeze(0)
        if self.enable_fp32_compute:
            output = output.to(origin_dtype)
        return output

    def forward_hip(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """ROCm path: flash_attn Triton kernel if installed, else native."""
        if self.apply_rotary_emb_flash_attn is not None:
            origin_dtype = x.dtype
            if self.enable_fp32_compute:
                x = x.float()
                cos = cos.float()
                sin = sin.float()
            origin_shape = x.shape
            if len(origin_shape) == 3:
                # x: [seq_len, num_heads, head_size] -> add a batch dimension.
                x = x.unsqueeze(0)
            # Arguments of apply_rotary() in flash_attn:
            #   x: [batch_size, seq_len, nheads, headdim]
            #   cos, sin: [seqlen_rotary, rotary_dim / 2]
            #   interleaved: default is False (Neox-style).
            interleaved = not self.is_neox_style
            output = self.apply_rotary_emb_flash_attn(
                x, cos, sin, interleaved=interleaved
            ).type_as(x)
            if len(origin_shape) == 3:
                output = output.squeeze(0)
            if self.enable_fp32_compute:
                output = output.to(origin_dtype)
        else:
            # Falling back to PyTorch native implementation.
            output = self.forward_native(x, cos, sin)
        return output

    def forward_cpu(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        # TODO (bigPYJ1151): need to enable fused CPU ROPE here
        return self.forward_native(x, cos, sin)

    def extra_repr(self) -> str:
        # BUGFIX: join fields with ", " — previously the two f-strings were
        # concatenated with no separator, producing e.g.
        # "is_neox_style=Trueenable_fp32_compute=False".
        s = f"is_neox_style={self.is_neox_style}"
        s += f", enable_fp32_compute={self.enable_fp32_compute}"
        return s

View File

@@ -4,7 +4,6 @@
import torch
from .common import apply_rotary_emb_dispatch
from .mrope import MRotaryEmbedding
@@ -55,14 +54,22 @@ class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
query_rot = self.apply_rotary_emb.forward_native(
query_rot,
cos,
sin,
)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
key_rot = self.apply_rotary_emb.forward_native(
key_rot,
cos,
sin,
)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key

View File

@@ -8,7 +8,6 @@ import torch
from vllm.triton_utils import tl, triton
from .base import RotaryEmbeddingBase
from .common import apply_rotary_emb_dispatch
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale
@@ -301,14 +300,22 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
query_rot = self.apply_rotary_emb.forward_native(
query_rot,
cos,
sin,
)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
key_rot = self.apply_rotary_emb.forward_native(
key_rot,
cos,
sin,
)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
@@ -347,13 +354,21 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
query_rot = self.apply_rotary_emb(
query_rot,
cos,
sin,
)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
key_rot = self.apply_rotary_emb(
key_rot,
cos,
sin,
)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key

View File

@@ -4,7 +4,6 @@
import numpy as np
import torch
from .common import apply_rotary_emb_dispatch
from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
@@ -36,7 +35,7 @@ class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding):
dtype,
)
def forward(
def forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
@@ -68,14 +67,73 @@ class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding):
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
query_rot = self.apply_rotary_emb.forward_native(
query_rot,
cos,
sin,
)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
key_rot = self.apply_rotary_emb.forward_native(
key_rot,
cos,
sin,
)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
def forward_cuda(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor | None = None,
    offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """CUDA override of forward(); applies XD-RoPE via self.apply_rotary_emb.

    Args:
        positions:
            [4, num_tokens] (P/W/H/T positions with multimodal inputs)
        query: [num_tokens, num_heads * head_size]
        key: [num_tokens, num_kv_heads * head_size]
        offsets: not used in this body — NOTE(review): appears to be kept
            for signature compatibility; confirm callers never pass it.
    """
    assert positions.ndim == 2
    assert key is not None
    num_tokens = positions.shape[-1]
    cos_sin = self.cos_sin_cache[positions]
    cos, sin = cos_sin.chunk(2, dim=-1)
    # For each xdrope section, pick that section's slice from the row of
    # cos/sin that corresponds to it, then stitch the slices back together.
    cos = torch.cat(
        [m[i] for i, m in enumerate(cos.split(self.xdrope_section, dim=-1))], dim=-1
    )
    sin = torch.cat(
        [m[i] for i, m in enumerate(sin.split(self.xdrope_section, dim=-1))], dim=-1
    )
    query_shape = query.shape
    query = query.view(num_tokens, -1, self.head_size)
    # Rotate only the first rotary_dim features; pass the rest through.
    query_rot = query[..., : self.rotary_dim]
    query_pass = query[..., self.rotary_dim :]
    query_rot = self.apply_rotary_emb(
        query_rot,
        cos,
        sin,
    )
    query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
    key_shape = key.shape
    key = key.view(num_tokens, -1, self.head_size)
    key_rot = key[..., : self.rotary_dim]
    key_pass = key[..., self.rotary_dim :]
    key_rot = self.apply_rotary_emb(
        key_rot,
        cos,
        sin,
    )
    key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
    return query, key

View File

@@ -29,6 +29,9 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import (
MultiModalEmbeddings,
@@ -158,32 +161,6 @@ class DotsOCRProcessingInfo(Qwen2VLProcessingInfo):
return processor
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    half = x.shape[-1] // 2
    lower, upper = x[..., :half], x[..., half:]
    return torch.cat((-upper, lower), dim=-1)
def apply_rotary_pos_emb_vision(
    tensor: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
    """Apply rotary embeddings to `tensor` in FP32, casting back afterwards."""
    orig_dtype = tensor.dtype
    t = tensor.float()
    # Tile freqs' cos/sin up to the full rotary dim and add broadcast axes.
    cos = freqs.cos().unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    sin = freqs.sin().unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    # Inlined rotate_half: swap the two halves and negate the upper one.
    half = t.shape[-1] // 2
    rotated = torch.cat((-t[..., half:], t[..., :half]), dim=-1)
    return ((t * cos) + (rotated * sin)).to(orig_dtype)
class VisionRotaryEmbedding(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0) -> None:
super().__init__()
@@ -298,6 +275,11 @@ class DotsVisionAttention(nn.Module):
prefix=f"{prefix}.attn",
)
self.apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
def forward(
self,
hidden_states: torch.Tensor,
@@ -318,7 +300,11 @@ class DotsVisionAttention(nn.Module):
if rotary_pos_emb is not None:
qk_concat = torch.cat([q, k], dim=0)
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
qk_rotated = self.apply_rotary_emb(
qk_concat,
rotary_pos_emb.cos(),
rotary_pos_emb.sin(),
)
q, k = torch.chunk(qk_rotated, 2, dim=0)
context_layer = self.attn(

View File

@@ -33,7 +33,7 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops import rearrange
from transformers import BatchFeature
from vllm.attention.backends.registry import AttentionBackendEnum
@@ -53,6 +53,9 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
@@ -69,7 +72,6 @@ from vllm.multimodal.processing import (
PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -89,52 +91,6 @@ logger = init_logger(__name__)
# === Vision Transformer === #
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate-half transform along the last dimension of x."""
    if not interleaved:
        lower, upper = x.chunk(2, dim=-1)
        return torch.cat((-upper, lower), dim=-1)
    even, odd = x[..., ::2], x[..., 1::2]
    # Pair (-odd, even) back up and flatten the pair axis — equivalent to
    # the einops rearrange "... d two -> ... (d two)" of the stacked pairs.
    return torch.stack((-odd, even), dim=-1).flatten(-2)
def apply_rotary_emb_torch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
    """Pure-PyTorch rotary embedding over the first ``ro_dim`` features.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Features beyond ``ro_dim`` are passed through unchanged.
    """
    # cos/sin carry half the rotary dim each.
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    # Expand cos/sin to the full rotary dim and insert a heads axis;
    # "(2 d)" tiles block-wise, "(d 2)" interleaves element-wise.
    cos = repeat(
        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    )
    sin = repeat(
        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    )
    return torch.cat(
        [
            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
            x[..., ro_dim:],
        ],
        dim=-1,
    )
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Rotate `t` by `freqs` in FP32, then cast back to t's dtype."""
    rotary_fn = apply_rotary_emb_torch
    if current_platform.is_cuda():
        # Prefer the fused CUDA kernel when available.
        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb as rotary_fn
    return rotary_fn(t.float(), freqs.cos(), freqs.sin()).type_as(t)
def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
"""All-gather the input tensor interleavely across model parallel group."""
import torch.distributed as dist
@@ -200,6 +156,11 @@ class Ernie4_5_VisionAttention(nn.Module):
prefix=f"{prefix}.attn",
)
self.apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
@@ -244,7 +205,11 @@ class Ernie4_5_VisionAttention(nn.Module):
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
if rotary_pos_emb is not None:
qk_concat = torch.cat([q, k], dim=0)
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
qk_rotated = self.apply_rotary_emb(
qk_concat,
rotary_pos_emb.cos(),
rotary_pos_emb.sin(),
)
q, k = torch.chunk(qk_rotated, 2, dim=0)
output = self.attn(

View File

@@ -19,7 +19,6 @@ from collections.abc import Iterable
from itertools import islice
import torch
import torch.nn.functional as F
from torch import nn
from transformers import Gemma3TextConfig
@@ -223,77 +222,9 @@ class Gemma3Attention(nn.Module):
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
if not kwargs.get("has_images", False):
# Fast path for text-only inputs. The performance for the text-only
# inputs are not affected by the naive attention below.
output, _ = self.o_proj(attn_output)
return output
# NOTE(woosuk): Gemma3 uses bidirectional attention between image tokens
# that correspond to the same image while using causal attention
# otherwise. Current attention backends cannot handle this pattern, so
# we temporarily use a naive attention implementation with mask tensors.
# We intentionally keep the attention backend as-is and only override
# `attn_output` with the naive implementation's output. This minimizes
# changes to existing model runners and attention backends. The call to
# `self.attn(q, k, v)` is only used to populate the KV cache - its
# output is discarded and overwritten below. While this duplicates
# computation, it maintains compatibility.
# TODO(woosuk): Optimize by implementing custom attention kernels.
attn_output = self.naive_attn_with_masks(q, k, v, out=attn_output, **kwargs)
output, _ = self.o_proj(attn_output)
return output
def naive_attn_with_masks(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    out: torch.Tensor,
    **kwargs,
) -> torch.Tensor:
    """Reference attention with explicit per-sequence masks.

    Runs scaled-dot-product attention one sequence at a time over the
    packed batch, using the mask tensors supplied via ``kwargs``, and
    writes the result into ``out`` in place (``out`` is also returned).

    Expected kwargs:
        seq_lens: per-sequence token counts in the packed batch.
        local_attn_masks / global_attn_masks: one mask per sequence; which
            set is used depends on ``self.is_sliding``.
    """
    # NOTE(woosuk): As described in the comment above, this code is not
    # meant to be performant. It is only meant to be correct.
    q = q.view(-1, self.num_heads, self.head_dim)
    # Expand the key and value to handle GQA.
    num_queries_per_kv = self.num_heads // self.num_kv_heads
    k = k.view(-1, self.num_kv_heads, self.head_dim)
    k = k.repeat_interleave(num_queries_per_kv, dim=-2)
    v = v.view(-1, self.num_kv_heads, self.head_dim)
    v = v.repeat_interleave(num_queries_per_kv, dim=-2)
    # Sliding-window layers use the local masks; full-attention layers
    # use the global masks.
    if self.is_sliding:
        attn_masks = kwargs["local_attn_masks"]
    else:
        attn_masks = kwargs["global_attn_masks"]
    seq_lens = kwargs["seq_lens"]
    start_idx = 0
    for seq_len, attn_mask in zip(seq_lens, attn_masks):
        end_idx = start_idx + seq_len
        query = q[start_idx:end_idx].unsqueeze(0)
        key = k[start_idx:end_idx].unsqueeze(0)
        value = v[start_idx:end_idx].unsqueeze(0)
        # Transpose to [1, heads, seq, head_dim] as expected by SDPA.
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)
        output = F.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask,
            self.scaling,
        )
        # Back to [seq, heads * head_dim] and write into the output slice.
        output = output.transpose(1, 2).flatten(-2, -1)
        out[start_idx:end_idx] = output
        start_idx = end_idx
    return out
class Gemma3DecoderLayer(nn.Module):
def __init__(

View File

@@ -65,6 +65,9 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.rotary_embedding.common import (
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -95,7 +98,7 @@ from .interfaces import (
SupportsMultiModal,
SupportsPP,
)
from .qwen2_vl import _create_qwen2vl_field_factory, apply_rotary_pos_emb_vision
from .qwen2_vl import _create_qwen2vl_field_factory
from .utils import (
AutoWeightsLoader,
WeightsMapper,
@@ -304,6 +307,8 @@ class Glm4vVisionAttention(nn.Module):
multimodal_config=multimodal_config,
)
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
@@ -339,8 +344,10 @@ class Glm4vVisionAttention(nn.Module):
if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
# [2 * b, s, heads, head_dim]
qk_concat = torch.cat([q, k], dim=0)
qk_rotated = apply_rotary_pos_emb_vision(
qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
qk_rotated = self.apply_rotary_emb(
qk_concat,
rotary_pos_emb_cos,
rotary_pos_emb_sin,
)
q, k = torch.chunk(qk_rotated, 2, dim=0)

View File

@@ -30,6 +30,9 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
@@ -59,7 +62,6 @@ from vllm.multimodal.processing import (
PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -341,20 +343,14 @@ def apply_rotary_pos_emb_flashatt(
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
elif current_platform.is_rocm():
from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb
else:
# For other platforms, use PyTorch fallback
from vllm.model_executor.layers.rotary_embedding.common import (
apply_rotary_emb_torch,
)
apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
apply_rotary_emb = partial(apply_rotary_emb_torch, is_neox_style=True)
q_embed = apply_rotary_emb(q, cos, sin)
k_embed = apply_rotary_emb(k, cos, sin)
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
return q_embed, k_embed

View File

@@ -22,7 +22,7 @@ from typing import Annotated, Literal
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat
from einops import rearrange
from transformers import BatchFeature, PretrainedConfig
from transformers.activations import GELUActivation
from transformers.modeling_outputs import (
@@ -47,7 +47,7 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
dispatch_rotary_emb_function,
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
@@ -130,47 +130,6 @@ def smart_resize(
return h_bar, w_bar
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate-half transform along the last dimension of x."""
    if not interleaved:
        lower, upper = x.chunk(2, dim=-1)
        return torch.cat((-upper, lower), dim=-1)
    even, odd = x[..., ::2], x[..., 1::2]
    # Re-interleave (-odd, even) pairs — equivalent to the einops
    # rearrange "... d two -> ... (d two)" of the stacked pairs.
    return torch.stack((-odd, even), dim=-1).flatten(-2)
def apply_rotary_emb_torch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
    """Pure-PyTorch rotary embedding over the first ``ro_dim`` features.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Features beyond ``ro_dim`` are passed through unchanged.
    """
    # cos/sin carry half the rotary dim each.
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    # Expand cos/sin to the full rotary dim and insert a heads axis;
    # "(2 d)" tiles block-wise, "(d 2)" interleaves element-wise.
    cos = repeat(
        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    )
    sin = repeat(
        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
    )
    return torch.cat(
        [
            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
            x[..., ro_dim:],
        ],
        dim=-1,
    )
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Rotate `t` by `freqs` in FP32 via the platform-dispatched kernel."""
    rotate = dispatch_rotary_emb_function(default=apply_rotary_emb_torch)
    return rotate(t.float(), freqs.cos(), freqs.sin()).type_as(t)
class PaddleOCRVLProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config()
@@ -609,6 +568,10 @@ class SiglipAttention(nn.Module):
multimodal_config=multimodal_config,
prefix=f"{prefix}.attn",
)
self.apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
seq_len, bs, _ = qkv.shape
@@ -651,7 +614,11 @@ class SiglipAttention(nn.Module):
if rotary_pos_emb is not None:
qk_concat = torch.cat([q, k], dim=0)
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
qk_rotated = self.apply_rotary_emb(
qk_concat,
rotary_pos_emb.cos(),
rotary_pos_emb.sin(),
)
q, k = torch.chunk(qk_rotated, 2, dim=0)
context_layer = self.attn(

View File

@@ -60,6 +60,9 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.rotary_embedding.common import (
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.vision import should_torch_compile_mm_vit
@@ -95,7 +98,6 @@ from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
from .qwen2_vl import (
Qwen2VLMultiModalProcessor,
Qwen2VLProcessingInfo,
apply_rotary_pos_emb_vision,
)
from .utils import (
AutoWeightsLoader,
@@ -353,6 +355,8 @@ class Qwen2_5_VisionAttention(nn.Module):
multimodal_config=multimodal_config,
)
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
def forward(
self,
x: torch.Tensor,
@@ -378,8 +382,10 @@ class Qwen2_5_VisionAttention(nn.Module):
qk_reshaped = einops.rearrange(
qk, "b s two head head_dim -> (two b) s head head_dim", two=2
)
qk_rotated = apply_rotary_pos_emb_vision(
qk_reshaped, cos=rotary_pos_emb_cos, sin=rotary_pos_emb_sin
qk_rotated = self.apply_rotary_emb(
qk_reshaped,
rotary_pos_emb_cos,
rotary_pos_emb_sin,
)
qk_rotated = qk_rotated.view(
2,

View File

@@ -59,8 +59,7 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.rotary_embedding.common import (
apply_rotary_emb_torch,
dispatch_rotary_emb_function,
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -280,16 +279,6 @@ class Qwen2VisionMLP(nn.Module):
return x
def apply_rotary_pos_emb_vision(
    t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """Apply neox-style rotary embeddings via the platform-dispatched kernel."""
    rotate = dispatch_rotary_emb_function(
        default=partial(apply_rotary_emb_torch, is_neox_style=True)
    )
    return rotate(t, cos, sin).type_as(t)
class Qwen2VisionAttention(nn.Module):
def __init__(
self,
@@ -341,6 +330,8 @@ class Qwen2VisionAttention(nn.Module):
multimodal_config=multimodal_config,
)
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
@@ -387,8 +378,10 @@ class Qwen2VisionAttention(nn.Module):
# [2 * b, s, heads, head_dim]
qk_concat = torch.cat([q, k], dim=0)
qk_rotated = apply_rotary_pos_emb_vision(
qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
qk_rotated = self.apply_rotary_emb(
qk_concat,
rotary_pos_emb_cos,
rotary_pos_emb_sin,
)
q, k = torch.chunk(qk_rotated, 2, dim=0)

View File

@@ -713,17 +713,13 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
mm_counts: Mapping[str, int],
) -> int:
target_width, target_height = self.get_image_size_with_most_features()
video_soft_tokens = self.get_num_video_tokens(
num_video_soft_tokens = self.get_num_video_tokens(
image_width=target_width,
image_height=target_height,
num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
image_processor=None,
)
# NOTE: By default in Qwen3-VL, one video token is converted to
# "<{timestamp} seconds>" (on average 9.5 tokens) + vision_start_token + video_token + vision_end_token # noqa: E501
formatted_video_soft_tokens = video_soft_tokens * 12.5
return int(formatted_video_soft_tokens)
return num_video_soft_tokens
def _calculate_timestamps(
self, indices: list[int] | torch.Tensor, video_fps: float, merge_size: int

View File

@@ -6,7 +6,6 @@ within a vision language model."""
from collections.abc import Iterable
import torch
from einops import rearrange, repeat
from torch import nn
from torch.nn import functional as F
from transformers import Siglip2VisionConfig
@@ -26,6 +25,9 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import current_platform
@@ -146,40 +148,6 @@ class Siglip2VisionEmbeddings(nn.Module):
return patch_embeds
# copy from flash_attn/layers/rotary.py
def rotate_half(x, interleaved=False):
if not interleaved:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
else:
x1, x2 = x[..., ::2], x[..., 1::2]
return rearrange(
torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
)
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
"""
ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
)
sin = repeat(
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
)
return torch.cat(
[
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
x[..., ro_dim:],
],
dim=-1,
)
def apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
@@ -189,14 +157,20 @@ def apply_rotary_pos_emb(
) -> tuple[torch.Tensor, torch.Tensor]:
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
if is_flash_attn_backend and current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
apply_rotary_emb_func = apply_rotary_emb
apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
if is_flash_attn_backend and not current_platform.is_cuda():
apply_rotary_emb_func = apply_rotary_emb.forward_cuda
else:
apply_rotary_emb_func = apply_rotary_emb_torch
q_embed = apply_rotary_emb_func(q.float(), cos.float(), sin.float()).type_as(q)
k_embed = apply_rotary_emb_func(k.float(), cos.float(), sin.float()).type_as(k)
apply_rotary_emb_func = apply_rotary_emb.forward_native
q_embed = apply_rotary_emb_func(q, cos, sin)
k_embed = apply_rotary_emb_func(k, cos, sin)
return q_embed, k_embed

View File

@@ -11,7 +11,7 @@ import torch
from transformers import PretrainedConfig
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config import VllmConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@@ -88,16 +88,10 @@ def get_vit_attn_backend(
"""
Get the available attention backend for Vision Transformer.
"""
attn_backend = attn_backend_override
selected_backend = get_current_vllm_config().attention_config.backend
if attn_backend is None:
attn_backend = selected_backend
return current_platform.get_vit_attn_backend(
head_size,
dtype,
backend=attn_backend,
backend=attn_backend_override,
)

View File

@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from collections import UserDict, defaultdict
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from functools import partial
from functools import cached_property, partial
from itertools import accumulate
from typing import (
TYPE_CHECKING,
@@ -169,11 +169,42 @@ class PlaceholderRange:
between `offset` and `offset + length` to assign embeddings to.
"""
def get_num_embeds(self) -> int:
@cached_property
def embeds_cumsum(self) -> torch.Tensor | None:
if self.is_embed is None:
return None
return self.is_embed.cumsum(dim=0)
@cached_property
def get_num_embeds(self) -> int:
if self.embeds_cumsum is None:
return self.length
return int(self.is_embed.sum().item())
return int(self.embeds_cumsum[-1])
def get_embeds_indices_in_range(
self, start_idx: int, end_idx: int
) -> tuple[int, int]:
"""
Returns the starting and ending indices of the embeddings of encoder outputs
in the range of [start_idx, end_idx) in the placeholders.
For example, given:
PlaceholderRange(offset=2, length=5, is_embed=[False, True, False, True, True])
If start_idx=3 and end_idx=5, the output is (1, 3) because we want to get
the second and the third embeddings from the encoder output.
"""
if self.embeds_cumsum is None:
return start_idx, end_idx
embeds_start_idx = (
int(self.embeds_cumsum[start_idx - 1]) if start_idx > 0 else 0
)
embeds_end_idx = int(self.embeds_cumsum[end_idx - 1])
return embeds_start_idx, embeds_end_idx
def extract_embeds_range(self) -> list[tuple[int, int]]:
"""Extract the start and end indices of the embedded region in prompt.
@@ -188,7 +219,7 @@ class PlaceholderRange:
Returns full placeholder range if `is_embed` is `None`.
"""
if self.is_embed is None:
return [(self.offset, self.offset + self.length)]
return [(self.offset, self.offset + self.length - 1)]
mask_i = self.is_embed.int()
starts = torch.nonzero(

View File

@@ -274,15 +274,11 @@ class MultiModalProfiler(Generic[_I]):
def _get_mm_num_tokens(
self,
mm_inputs: MultiModalInputs,
mm_embeddings_only: bool = True,
) -> Mapping[str, int]:
placeholders_by_modality = mm_inputs["mm_placeholders"]
return {
modality: sum(
item.get_num_embeds() if mm_embeddings_only else item.length
for item in placeholders
)
modality: sum(item.get_num_embeds for item in placeholders)
for modality, placeholders in placeholders_by_modality.items()
}
@@ -328,12 +324,15 @@ class MultiModalProfiler(Generic[_I]):
multi_modal_placeholders=mm_inputs["mm_placeholders"],
)
def _get_mm_max_tokens(
def get_mm_max_tokens(
self,
seq_len: int,
mm_counts: Mapping[str, int] | None = None,
mm_embeddings_only: bool = True,
) -> Mapping[str, int]:
"""
Returns the maximum number of embeddings per item of each modality, excluding
any break/text tokens in-between multimodal embeddings/encoder outputs.
"""
if mm_counts is None:
mm_counts = self.get_mm_limits()
@@ -349,21 +348,4 @@ class MultiModalProfiler(Generic[_I]):
}
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only)
def get_mm_max_contiguous_tokens(
self,
seq_len: int,
mm_counts: Mapping[str, int] | None = None,
) -> Mapping[str, int]:
"""
Returns the maximum length of the multimodal (image placeholders+text)
tokens, including any break/text tokens in-between image embeddings.
`<im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>`
Returns 9, even when the number of image embeddings is 6.
This is important to take into account when profiling and
initializing the encoder cache size.
"""
return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False)
return self._get_mm_num_tokens(mm_inputs)

View File

@@ -164,7 +164,7 @@ class MultiModalRegistry:
profiler.get_mm_limits() if profiler_limits is None else profiler_limits
)
return profiler.get_mm_max_contiguous_tokens(
return profiler.get_mm_max_tokens(
seq_len,
{modality: 1 for modality, limit in profiler_limits.items() if limit > 0},
)

View File

@@ -76,6 +76,39 @@ class TritonAttentionMetadata:
# Optional aot scheduling
scheduler_metadata: torch.Tensor | None = None
prefix_scheduler_metadata: torch.Tensor | None = None
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
@property
def mm_prefix_range_tensor(self) -> torch.Tensor | None:
"""Convert mm_prefix_range dict to padded tensor for Triton kernel.
Returns shape: (num_seqs, max_ranges, 2) with 0-padding for empty ranges.
Empty ranges have start==end==0, which kernel skips via is_valid check.
"""
# TODO(Isotr0py): Move to model runner's attention metadata
# preparation to avoid duplicate computation.
if self.mm_prefix_range is None:
return None
num_seqs = self.seq_lens.shape[0]
device = self.seq_lens.device
# Collect ranges, using [(0,0)] for empty sequences to ensure uniform dims
range_lists = [
self.mm_prefix_range.get(i, [(0, 0)]) or [(0, 0)] for i in range(num_seqs)
]
# Return None if all ranges are trivial (only (0,0) placeholders)
if all(r == [(0, 0)] for r in range_lists):
return None
# Create 2D tensors with shape (num_ranges, 2) for each sequence
range_tensors = [
torch.tensor(r, dtype=torch.int32, device=device).view(-1, 2)
for r in range_lists
]
return torch.nested.nested_tensor(range_tensors).to_padded_tensor(0)
class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
@@ -268,6 +301,10 @@ class TritonAttentionBackend(AttentionBackend):
def supports_head_size(cls, head_size: int) -> bool:
return head_size >= 32
@classmethod
def supports_mm_prefix(cls) -> bool:
return True
@classmethod
def supports_sink(cls) -> bool:
return True
@@ -427,6 +464,7 @@ class TritonAttentionImpl(AttentionImpl):
softmax_segm_expsum = attn_metadata.softmax_segm_expsum
descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
mm_prefix_range_tensor = attn_metadata.mm_prefix_range_tensor
unified_attention(
q=query[:num_actual_tokens],
@@ -453,6 +491,7 @@ class TritonAttentionImpl(AttentionImpl):
softmax_segm_expsum=softmax_segm_expsum,
sinks=self.sinks,
output_scale=output_scale,
mm_prefix_range=mm_prefix_range_tensor,
)
return output

View File

@@ -39,20 +39,26 @@ class EncoderCacheManager:
space for new embeddings.
Oldest cached embeddings with no request referenced will be first evicted.
NOTE: The EncoderCacheManager operates on the level of multimodal embeddings
instead of encoder tokens (i.e. all tokens that represent the multimodal data
in the input sequence). This means all break/text tokens in-between multimodal
embeddings are not considered with respect to the cache size and the number
of free slots.
Args:
cache_size: Limit the size of the cache, measured by the number of
tokens from the input sequence.
encoder embeddings from the input sequence.
Attributes:
cache_size: Total cache capacity in encoder tokens.
num_free_slots: Current available cache capacity in encoder tokens.
cache_size: Total cache capacity in encoder embeddings.
num_free_slots: Current available cache capacity in encoder embeddings.
num_freeable_slots: Capacity that can be immediately reclaimed by
evicting entries with zero references (in encoder tokens).
evicting entries with zero references (in encoder embeddings).
cached: Mapping from mm_hash to a set of request IDs that currently
reference the cached entry. If the set is empty, the entry exists
but is not referenced by any request and is eligible for
reclamation.
freeable: List of tuples (mm_hash, num_tokens) representing entries
freeable: List of tuples (mm_hash, num_encoder_embeds) representing entries
whose no current running request is needed and that can be freed to
make space when needed.
freed: List of mm_hash strings that were actually evicted since the
@@ -67,7 +73,7 @@ class EncoderCacheManager:
# mm_hash of mm_data => ids of requests that reference the mm_data
self.cached: dict[str, set[str]] = {}
# mm_hash of mm_data => num_encoder_tokens of the mm_data
# mm_hash of mm_data => num_encoder_embeds of the mm_data
self.freeable: OrderedDict[str, int] = OrderedDict()
self.freed: list[str] = []
@@ -93,8 +99,8 @@ class EncoderCacheManager:
# Cached but currently not referenced by any request
if not self.cached[mm_hash]:
num_tokens = self.freeable.pop(mm_hash)
self.num_freeable_slots -= num_tokens
num_encoder_embeds = self.freeable.pop(mm_hash)
self.num_freeable_slots -= num_encoder_embeds
self.cached[mm_hash].add(request.request_id)
return True
@@ -104,7 +110,7 @@ class EncoderCacheManager:
request: Request,
input_id: int,
encoder_compute_budget: int,
num_tokens_to_schedule: int,
num_embeds_to_schedule: int,
) -> bool:
"""Check if there's sufficient cache space for a multimodal input.
If there is, return True and update EncoderCacheManager state.
@@ -121,9 +127,9 @@ class EncoderCacheManager:
Args:
request: The request containing the multimodal input.
input_id: Index of the multimodal input within the request.
encoder_compute_budget: Number of encoder tokens allowed to be
encoder_compute_budget: Number of encoder embeddings allowed to be
computed when this method is invoked.
num_tokens_to_schedule: Number of tokens already scheduled to be
num_embeds_to_schedule: Number of encoder embeddings already scheduled to be
allocated with cache space when this method is invoked.
Returns:
@@ -134,30 +140,30 @@ class EncoderCacheManager:
Note: This method does not allocate physical memory for the encoder
output but only the state of EncoderCacheManager.
"""
num_tokens = request.get_num_encoder_tokens(input_id)
num_embeds = request.get_num_encoder_embeds(input_id)
# Not enough compute budget
if num_tokens > encoder_compute_budget:
if num_embeds > encoder_compute_budget:
return False
num_tokens += num_tokens_to_schedule
num_embeds += num_embeds_to_schedule
# Enough free slots
if num_tokens <= self.num_free_slots:
if num_embeds <= self.num_free_slots:
return True
# Not enough reclaimable slots
if num_tokens > self.num_freeable_slots:
if num_embeds > self.num_freeable_slots:
return False
# Not enough free slots but enough reclaimable slots
# NOTE: Eviction takes place here, but physical memory is not freed
# until model runner is notified by the scheduler output.
while num_tokens > self.num_free_slots:
mm_hash, num_free_token = self.freeable.popitem(last=False)
while num_embeds > self.num_free_slots:
mm_hash, num_free_embeds = self.freeable.popitem(last=False)
del self.cached[mm_hash]
self.freed.append(mm_hash)
self.num_free_slots += num_free_token
self.num_free_slots += num_free_embeds
return True
def allocate(self, request: Request, input_id: int) -> None:
@@ -176,16 +182,16 @@ class EncoderCacheManager:
if mm_hash not in self.cached:
self.cached[mm_hash] = set()
num_encoder_tokens = request.get_num_encoder_tokens(input_id)
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
# NOTE: Encoder cache should always have enough space for encoder inputs
# that are scheduled since eviction takes place at can_allocate().
assert self.num_free_slots >= num_encoder_tokens
assert self.num_freeable_slots >= num_encoder_tokens
assert self.num_free_slots >= num_encoder_embeds
assert self.num_freeable_slots >= num_encoder_embeds
self.cached[mm_hash].add(request_id)
self.num_free_slots -= num_encoder_tokens
self.num_freeable_slots -= num_encoder_tokens
self.num_free_slots -= num_encoder_embeds
self.num_freeable_slots -= num_encoder_embeds
def get_cached_input_ids(self, request: Request) -> set[int]:
"""Get all cached multimodal input IDs for a request.
@@ -206,7 +212,7 @@ class EncoderCacheManager:
When the reference set for the corresponding `mm_hash` becomes empty,
the entry is appended to `freeable` and `num_freeable_slots` is
increased by the number of encoder tokens for that input.
increased by the number of encoder embeddings for that input.
The entry is NOT physically freed until capacity is needed (e.g., by
`can_allocate`).
@@ -218,9 +224,9 @@ class EncoderCacheManager:
return
self.cached[mm_hash].discard(req_id)
if not self.cached[mm_hash]:
num_tokens = request.get_num_encoder_tokens(input_id)
self.freeable[mm_hash] = num_tokens
self.num_freeable_slots += num_tokens
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
self.freeable[mm_hash] = num_encoder_embeds
self.num_freeable_slots += num_encoder_embeds
def free(self, request: Request) -> None:
"""Free all encoder input cache reference held by *request*.
@@ -361,20 +367,20 @@ class EncoderDecoderCacheManager(EncoderCacheManager):
request: Request,
input_id: int,
encoder_compute_budget: int,
num_tokens_to_schedule: int,
num_embeds_to_schedule: int,
) -> bool:
num_tokens = request.get_num_encoder_tokens(input_id)
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
# Not enough compute budget
if num_tokens > encoder_compute_budget:
if num_encoder_embeds > encoder_compute_budget:
return False
num_tokens += num_tokens_to_schedule
num_encoder_embeds += num_embeds_to_schedule
# Enough free slots
return num_tokens <= self.num_free_slots
return num_encoder_embeds <= self.num_free_slots
def allocate(self, request: Request, input_id: int) -> None:
num_encoder_tokens = request.get_num_encoder_tokens(input_id)
self.num_free_slots -= num_encoder_tokens
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
self.num_free_slots -= num_encoder_embeds
mm_hash = request.mm_features[input_id].identifier
self.freed.append(mm_hash)
@@ -392,5 +398,5 @@ class EncoderDecoderCacheManager(EncoderCacheManager):
return freed
def free_encoder_input(self, request: Request, input_id: int) -> None:
num_tokens = request.get_num_encoder_tokens(input_id)
self.num_free_slots += num_tokens
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
self.num_free_slots += num_encoder_embeds

View File

@@ -349,11 +349,11 @@ class Scheduler(SchedulerInterface):
if preempted_encoder_inputs:
# Restore encoder compute budget if the preempted
# request had encoder inputs scheduled in this step.
num_tokens_to_restore = sum(
preempted_req.get_num_encoder_tokens(i)
num_embeds_to_restore = sum(
preempted_req.get_num_encoder_embeds(i)
for i in preempted_encoder_inputs
)
encoder_compute_budget += num_tokens_to_restore
encoder_compute_budget += num_embeds_to_restore
req_index -= 1
else:
preempted_req = self.running.pop()
@@ -911,10 +911,11 @@ class Scheduler(SchedulerInterface):
# multiple encoder inputs per request), we need to create temporary
# trackers for accounting at the encoder input level.
mm_hashes_to_schedule = set()
num_tokens_to_schedule = 0
num_embeds_to_schedule = 0
for i, mm_feature in enumerate(mm_features):
start_pos = mm_feature.mm_position.offset
num_encoder_tokens = mm_feature.mm_position.length
num_encoder_embeds = mm_feature.mm_position.get_num_embeds
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
@@ -970,9 +971,8 @@ class Scheduler(SchedulerInterface):
):
num_new_tokens = start_pos - num_computed_tokens
break
if not self.encoder_cache_manager.can_allocate(
request, i, encoder_compute_budget, num_tokens_to_schedule
request, i, encoder_compute_budget, num_embeds_to_schedule
):
# The encoder cache is full or the encoder budget is exhausted.
# NOTE(woosuk): We assume that the encoder input tokens should
@@ -992,14 +992,31 @@ class Scheduler(SchedulerInterface):
num_new_tokens = 0
break
# Calculate the number of embeddings to schedule in the current range
# of scheduled encoder placholder tokens.
start_idx_rel = max(0, num_computed_tokens - start_pos)
end_idx_rel = min(
num_encoder_tokens, num_computed_tokens + num_new_tokens - start_pos
)
curr_embeds_start, curr_embeds_end = (
mm_feature.mm_position.get_embeds_indices_in_range(
start_idx_rel,
end_idx_rel,
)
)
# There's no embeddings in the current range of encoder placeholder tokens
# so we can skip the encoder input.
if curr_embeds_end - curr_embeds_start == 0:
continue
if self.ec_connector is not None and remote_cache_has_item[i]:
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
external_load_encoder_input.append(i)
num_tokens_to_schedule += num_encoder_tokens
num_embeds_to_schedule += num_encoder_embeds
continue
num_tokens_to_schedule += num_encoder_tokens
encoder_compute_budget -= num_encoder_tokens
num_embeds_to_schedule += num_encoder_embeds
encoder_compute_budget -= num_encoder_embeds
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
encoder_inputs_to_schedule.append(i)

View File

@@ -209,10 +209,10 @@ class Request:
def get_finished_reason(self) -> FinishReason | None:
return RequestStatus.get_finished_reason(self.status)
def get_num_encoder_tokens(self, input_id: int) -> int:
def get_num_encoder_embeds(self, input_id: int) -> int:
assert input_id < len(self.mm_features)
num_tokens = self.mm_features[input_id].mm_position.length
return num_tokens
num_embeds = self.mm_features[input_id].mm_position.get_num_embeds
return num_embeds
def record_event(
self,

View File

@@ -169,9 +169,7 @@ from .utils import (
MultiModalBudget,
add_kv_sharing_layers_to_kv_cache_groups,
bind_kv_cache,
gather_mm_placeholders,
sanity_check_mm_encoder_outputs,
scatter_mm_placeholders,
)
if TYPE_CHECKING:
@@ -2187,10 +2185,7 @@ class GPUModelRunner(
# Cache the encoder outputs by mm_hash
for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
self.encoder_cache[mm_hash] = scatter_mm_placeholders(
output,
is_embed=pos_info.is_embed,
)
self.encoder_cache[mm_hash] = output
logger.debug("Finish execute for mm hash %s", mm_hash)
self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)
@@ -2241,6 +2236,13 @@ class GPUModelRunner(
num_encoder_tokens,
)
assert start_idx < end_idx
curr_embeds_start, curr_embeds_end = (
pos_info.get_embeds_indices_in_range(start_idx, end_idx)
)
# If there are no embeddings in the current range, we skip
# gathering the embeddings.
if curr_embeds_start == curr_embeds_end:
continue
mm_hash = mm_feature.identifier
encoder_output = self.encoder_cache.get(mm_hash, None)
@@ -2248,16 +2250,14 @@ class GPUModelRunner(
if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx]
mm_embeds_item = encoder_output[curr_embeds_start:curr_embeds_end]
else:
mm_embeds_item = encoder_output[start_idx:end_idx]
req_start_pos = req_start_idx + start_pos - num_computed_tokens
is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
True if is_embed is None else is_embed
)
mm_embeds_item = gather_mm_placeholders(
encoder_output[start_idx:end_idx],
is_embed=is_embed,
)
mm_embeds_req.append(mm_embeds_item)
if self.is_multimodal_pruning_enabled and self.uses_mrope:
@@ -4467,31 +4467,8 @@ class GPUModelRunner(
dummy_encoder_outputs,
expected_num_items=max_mm_items_per_batch,
)
# NOTE: This happens when encoder cache needs to store
# the embeddings that encoder outputs are scattered onto.
# In this case we create dummy embeddings of size
# (max_tokens_for_modality, hidden_size) and scatter
# encoder output into it.
encoder_output_shape = dummy_encoder_outputs[0].shape
max_mm_tokens_per_item = mm_budget.max_tokens_by_modality[
dummy_modality
]
if encoder_output_shape[0] < max_mm_tokens_per_item:
encoder_hidden_size = encoder_output_shape[-1]
expanded_outputs = []
for output in dummy_encoder_outputs:
expanded = output.new_zeros(
(max_mm_tokens_per_item, encoder_hidden_size)
)
num_tokens = output.shape[0]
expanded[:num_tokens].copy_(output)
expanded_outputs.append(expanded)
dummy_encoder_outputs = expanded_outputs
# Cache the dummy encoder outputs.
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
for i, output in enumerate(dummy_encoder_outputs):
self.encoder_cache[f"tmp_{i}"] = output
# Add `is_profile` here to pre-allocate communication buffers
hidden_states, last_hidden_states = self._dummy_run(

View File

@@ -4,10 +4,12 @@ from collections import defaultdict
from dataclasses import dataclass, field
import torch
from typing_extensions import deprecated
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.cache import processor_only_cache_from_config
@@ -17,6 +19,8 @@ from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
logger = init_logger(__name__)
class MultiModalBudget:
"""Helper class to calculate budget information for multi-modal models."""
@@ -198,6 +202,7 @@ def sanity_check_mm_encoder_outputs(
)
@deprecated("`scatter_mm_placeholders` is deprecated and will be removed in v0.15.0.")
def scatter_mm_placeholders(
embeds: torch.Tensor,
is_embed: torch.Tensor | None,
@@ -226,6 +231,7 @@ def scatter_mm_placeholders(
return placeholders
@deprecated("`gather_mm_placeholders` is deprecated and will be removed in v0.15.0.")
def gather_mm_placeholders(
placeholders: torch.Tensor,
is_embed: torch.Tensor | None,

View File

@@ -145,12 +145,20 @@ class WorkspaceManager:
for ubatch_id in range(self._num_ubatches):
current_workspace = self._current_workspaces[ubatch_id]
if current_workspace is None:
if (
current_workspace is None
or self._workspace_size_bytes(current_workspace) < required_bytes
):
# Delete old tensor before allocating new one to avoid
# memory spike from resize_(). resize_() allocates new
# memory before freeing old, which can cause OOM.
# Must clear the list reference first since local var
# is just a copy of the reference.
self._current_workspaces[ubatch_id] = None
del current_workspace
self._current_workspaces[ubatch_id] = torch.empty(
(required_bytes,), dtype=torch.uint8, device=self._device
)
elif self._workspace_size_bytes(current_workspace) < required_bytes:
current_workspace.resize_(required_bytes)
if envs.VLLM_DEBUG_WORKSPACE:
logger.info(