Compare commits

...

6 Commits

Author SHA1 Message Date
TJian
f34eca5f01 [ROCm] [Bugfix] Fix torch sdpa hallucination (#30789)
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
(cherry picked from commit 2410132bb1)
2025-12-16 17:16:25 -08:00
Wentao Ye
4cd332f3cf [CI] Skip ci failure test (#30804)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
(cherry picked from commit b6ec077e05)
2025-12-16 17:16:08 -08:00
Roger Wang
16484d394c [Core][MM] Optimize encoder cache manager by operating with embeddings only (#30475)
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Sun Kim <sunytokki@gmail.com>
(cherry picked from commit f5f51e5931)
2025-12-16 17:15:49 -08:00
Isotr0py
e397bd6592 [CI/Build] Skip broken ViT backend functionality test tempoarily (#30782)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
(cherry picked from commit 4de08ad698)
2025-12-16 17:15:26 -08:00
Isotr0py
6a88d590bb [Bugfix] Fix broken ViT attention selection for Blackwell device (#30731)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
(cherry picked from commit e94384bbad)
2025-12-16 17:13:54 -08:00
Shanshan Shen
ad8c073131 [CustomOp] Extract ApplyRotaryEmb as CustomOp and unify the dispatch logic (#29873)
Signed-off-by: shen-shanshan <467638484@qq.com>
Co-authored-by: gcanlin <canlinguosdu@gmail.com>
Co-authored-by: TJian <tunjian.tan@embeddedllm.com>
(cherry picked from commit 3bd9c49158)
2025-12-16 17:13:23 -08:00
32 changed files with 873 additions and 419 deletions

View File

@@ -523,6 +523,8 @@ CUSTOM_OPS_QUANT_RMS_NORM = ["+quant_fp8,+rms_norm"]
list[tuple[Any, ...]](flat_product(MODELS_GROUP_FP8, CUSTOM_OPS_QUANT_RMS_NORM)), list[tuple[Any, ...]](flat_product(MODELS_GROUP_FP8, CUSTOM_OPS_QUANT_RMS_NORM)),
) )
@pytest.mark.parametrize("inductor_graph_partition", [True, False]) @pytest.mark.parametrize("inductor_graph_partition", [True, False])
# TODO: remove skip after we fix the fusion thoroughly
@pytest.mark.skipif(is_blackwell(), reason="Temporarily disabled on Blackwell")
def test_rms_group_quant( def test_rms_group_quant(
model_name: str, model_name: str,
model_kwargs: dict[str, Any], model_kwargs: dict[str, Any],
@@ -562,7 +564,7 @@ def test_rms_group_quant(
splitting_ops=splitting_ops, splitting_ops=splitting_ops,
# Common # Common
mode=CompilationMode.VLLM_COMPILE, mode=CompilationMode.VLLM_COMPILE,
pass_config=PassConfig(eliminate_noops=True, enable_fusion=True), pass_config=PassConfig(eliminate_noops=True, fuse_norm_quant=True),
# Inductor caches custom passes by default as well via uuid # Inductor caches custom passes by default as well via uuid
inductor_compile_config={"force_disable_caches": True}, inductor_compile_config={"force_disable_caches": True},
) )

View File

@@ -0,0 +1,203 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for ApplyRotaryEmb CustomOp dispatch behavior.
This test ensures that RotaryEmbedding classes correctly call the appropriate
ApplyRotaryEmb methods based on the calling context:
1. RotaryEmbedding.forward_native() -> ApplyRotaryEmb.forward_native()
2. RotaryEmbedding.forward_cuda() -> ApplyRotaryEmb.forward() (auto-dispatch)
3. RotaryEmbedding.forward_hip() -> ApplyRotaryEmb.forward() (auto-dispatch)
"""
from dataclasses import dataclass
import pytest
import torch
from vllm.config import (
CompilationConfig,
VllmConfig,
get_cached_compilation_config,
set_current_vllm_config,
)
from vllm.platforms import current_platform
CUDA_DEVICES = ["cuda:0"]
@dataclass
class RotaryEmbeddingTestCase:
    """Describes one scenario of the RotaryEmbedding dispatch test.

    Attributes:
        name: Human-readable identifier, also used as the pytest id.
        rope_class: RotaryEmbedding class to instantiate.
        rope_kwargs: Keyword arguments passed to ``rope_class``.
        method_name: Method invoked on the instance
            (forward_native, forward_cuda, forward).
        positions_shape: Shape of the positions tensor:
            (num_tokens,) or (3, num_tokens) or (4, num_tokens).
        expect_forward_native: Whether ApplyRotaryEmb.forward_native()
            should be called.
        expect_forward: Whether ApplyRotaryEmb.forward() should be called.
    """

    name: str
    rope_class: type
    rope_kwargs: dict
    method_name: str
    positions_shape: tuple
    expect_forward_native: bool
    expect_forward: bool
def get_test_cases() -> list[RotaryEmbeddingTestCase]:
    """Build the list of dispatch scenarios, one per RotaryEmbedding variant.

    The rotary embedding classes are imported inside the function, so
    importing this module does not pull them in.
    """
    from vllm.model_executor.layers.rotary_embedding.ernie45_vl_rope import (
        Ernie4_5_VLRotaryEmbedding,
    )
    from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding
    from vllm.model_executor.layers.rotary_embedding.xdrope import XDRotaryEmbedding

    # Constructor arguments shared by every scenario.
    common_kwargs = {
        "head_size": 128,
        "rotary_dim": 128,
        "max_position_embeddings": 4096,
        "base": 10000,
        "is_neox_style": True,
        "dtype": torch.bfloat16,
    }

    # MRotaryEmbedding scenarios
    mrope_native = RotaryEmbeddingTestCase(
        name="MRotaryEmbedding.forward_native",
        rope_class=MRotaryEmbedding,
        rope_kwargs={**common_kwargs, "mrope_section": [16, 24, 24]},
        method_name="forward_native",
        positions_shape=(3, 32),  # 2D for multimodal
        expect_forward_native=True,
        expect_forward=False,
    )
    mrope_cuda_1d = RotaryEmbeddingTestCase(
        name="MRotaryEmbedding.forward_cuda_1d",
        rope_class=MRotaryEmbedding,
        rope_kwargs={**common_kwargs, "mrope_section": [16, 24, 24]},
        method_name="forward_cuda",
        positions_shape=(32,),  # 1D triggers apply_rotary_emb path
        expect_forward_native=False,
        expect_forward=True,
    )

    # XDRotaryEmbedding scenario
    xdrope_forward = RotaryEmbeddingTestCase(
        name="XDRotaryEmbedding.forward",
        rope_class=XDRotaryEmbedding,
        rope_kwargs={
            **common_kwargs,
            "scaling_alpha": 1.0,
            "xdrope_section": [16, 16, 16, 16],
        },
        method_name="forward",
        positions_shape=(4, 32),  # 4D for P/W/H/T
        expect_forward_native=False,
        expect_forward=True,
    )

    # Ernie4_5_VLRotaryEmbedding scenario
    ernie_native = RotaryEmbeddingTestCase(
        name="Ernie4_5_VLRotaryEmbedding.forward_native",
        rope_class=Ernie4_5_VLRotaryEmbedding,
        rope_kwargs={**common_kwargs, "mrope_section": [22, 22, 20]},
        method_name="forward_native",
        positions_shape=(3, 32),  # 2D for multimodal
        expect_forward_native=True,
        expect_forward=False,
    )

    return [mrope_native, mrope_cuda_1d, xdrope_forward, ernie_native]
def run_dispatch_test(
    test_case: RotaryEmbeddingTestCase,
    device: str,
):
    """Execute one dispatch scenario and check which ApplyRotaryEmb path ran.

    The rope layer is built under a config that enables the
    ``apply_rotary_emb`` custom op. ``forward_native`` and ``forward`` of its
    ApplyRotaryEmb instance are temporarily wrapped to record calls, the
    method under test is invoked, and the recorded calls are compared with
    the expectations stored in ``test_case``.
    """
    compilation_config = CompilationConfig(custom_ops=["all", "+apply_rotary_emb"])
    vllm_config = VllmConfig(compilation_config=compilation_config)
    get_cached_compilation_config.cache_clear()

    with set_current_vllm_config(vllm_config):
        rope = test_case.rope_class(**test_case.rope_kwargs).to(device=device)
        apply_rotary_emb = rope.apply_rotary_emb

        # Verify custom op is enabled
        if test_case.expect_forward_native:
            assert (
                apply_rotary_emb._forward_method != apply_rotary_emb.forward_native
            ), "Test setup error: ApplyRotaryEmb custom op should be enabled"

        # Call bookkeeping: one flag per wrapped method.
        calls = {"forward_native_called": False, "forward_called": False}
        original_forward_native = apply_rotary_emb.forward_native
        original_forward = apply_rotary_emb.forward

        def _recording(flag, wrapped):
            """Return a pass-through wrapper that sets ``calls[flag]``."""

            def _wrapper(*args, **kwargs):
                calls[flag] = True
                return wrapped(*args, **kwargs)

            return _wrapper

        apply_rotary_emb.forward_native = _recording(
            "forward_native_called", original_forward_native
        )
        apply_rotary_emb.forward = _recording("forward_called", original_forward)

        try:
            num_tokens = test_case.positions_shape[-1]
            num_q_heads = 8
            num_kv_heads = 2
            head_size = test_case.rope_kwargs["head_size"]
            max_position = test_case.rope_kwargs["max_position_embeddings"]

            positions = torch.randint(
                0, max_position // 4, test_case.positions_shape, device=device
            )
            query = torch.randn(
                num_tokens, num_q_heads * head_size, dtype=torch.bfloat16, device=device
            )
            key = torch.randn(
                num_tokens,
                num_kv_heads * head_size,
                dtype=torch.bfloat16,
                device=device,
            )

            # Invoke the method under test on copies of the inputs.
            getattr(rope, test_case.method_name)(positions, query.clone(), key.clone())

            # Compare recorded calls against the expectations.
            if test_case.expect_forward_native:
                assert calls["forward_native_called"], (
                    f"{test_case.name} should call ApplyRotaryEmb.forward_native()"
                )
            if not test_case.expect_forward:
                assert not calls["forward_called"], (
                    f"{test_case.name} should NOT call ApplyRotaryEmb.forward(). "
                    "Bug: when +apply_rotary_emb is enabled, forward_native() "
                    "incorrectly dispatches to CUDA/HIP kernels."
                )
            if test_case.expect_forward:
                assert calls["forward_called"], (
                    f"{test_case.name} should call ApplyRotaryEmb.forward()"
                )
        finally:
            # Always undo the monkey-patching, even when an assertion fails.
            apply_rotary_emb.forward_native = original_forward_native
            apply_rotary_emb.forward = original_forward
@pytest.mark.skipif(
    not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests."
)
@pytest.mark.parametrize("test_case", get_test_cases(), ids=lambda case: case.name)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_rotary_embedding_dispatch(
    test_case: RotaryEmbeddingTestCase,
    device: str,
):
    """
    Check that each RotaryEmbedding class reaches the right ApplyRotaryEmb method.

    - forward_native methods are expected to call ApplyRotaryEmb.forward_native()
    - forward_cuda/forward methods are expected to call ApplyRotaryEmb.forward()
    """
    run_dispatch_test(test_case=test_case, device=device)

View File

@@ -388,6 +388,7 @@ def run_video_test(config, mm_encoder_attn_backend, video_assets, vllm_runner):
"mm_encoder_attn_backend", "mm_encoder_attn_backend",
[None] + current_platform.get_supported_vit_attn_backends(), [None] + current_platform.get_supported_vit_attn_backends(),
) )
@pytest.mark.skip(reason="Broken test due to memory segmentation fault")
@create_new_process_for_each_test() @create_new_process_for_each_test()
def test_vit_backend_functionality( def test_vit_backend_functionality(
model_key: str, model_key: str,

View File

@@ -60,12 +60,12 @@ def test_profiling(model_id: str, max_model_len: int):
total_num_patches.item() + num_tiles.item() + 3 total_num_patches.item() + num_tiles.item() + 3
) # image start, image, image end ) # image start, image, image end
profiled_tokens = profiler.get_mm_max_contiguous_tokens( profiled_tokens = profiler.get_mm_max_tokens(
max_model_len, max_model_len,
mm_counts=mm_counts, mm_counts=mm_counts,
) )
assert total_tokens == profiled_tokens["image"] assert total_num_patches == profiled_tokens["image"]
assert total_tokens == sum( assert total_tokens == sum(
placeholder.length placeholder.length
for placeholder in decoder_dummy_data.multi_modal_placeholders["image"] for placeholder in decoder_dummy_data.multi_modal_placeholders["image"]

View File

@@ -9,6 +9,7 @@ from tempfile import NamedTemporaryFile, TemporaryDirectory
import numpy as np import numpy as np
import pytest import pytest
import torch
from PIL import Image, ImageChops from PIL import Image, ImageChops
from vllm.multimodal.image import convert_image_mode from vllm.multimodal.image import convert_image_mode
@@ -410,6 +411,97 @@ def test_argsort_mm_positions(case):
assert modality_idxs == expected_modality_idxs assert modality_idxs == expected_modality_idxs
@pytest.mark.parametrize(
    "is_embed,expected",
    [
        (None, 5),
        (torch.tensor([True, True, True, True, True]), 5),
        (torch.tensor([False, False, False, False, False]), 0),
        (torch.tensor([True, False, True, False, True]), 3),
        (torch.tensor([True]), 1),
    ],
)
def test_placeholder_range_get_num_embeds(is_embed, expected):
    """get_num_embeds matches the expected count for each mask (or no mask)."""
    # Without a mask the range length defaults to 5.
    length = 5 if is_embed is None else len(is_embed)
    placeholder = PlaceholderRange(offset=0, length=length, is_embed=is_embed)
    assert placeholder.get_num_embeds == expected
@pytest.mark.parametrize(
    "is_embed,expected",
    [
        (None, None),
        (
            torch.tensor([False, True, False, True, True]),
            torch.tensor([0, 1, 1, 2, 3]),
        ),
        (torch.tensor([True, True, True]), torch.tensor([1, 2, 3])),
    ],
)
def test_placeholder_range_embeds_cumsum(is_embed, expected):
    """embeds_cumsum is None without a mask, else the expected running count."""
    length = 5 if is_embed is None else len(is_embed)
    placeholder = PlaceholderRange(offset=0, length=length, is_embed=is_embed)

    if expected is None:
        assert placeholder.embeds_cumsum is None
    else:
        assert torch.equal(placeholder.embeds_cumsum, expected)
        # cached_property should return the same object on repeated access
        assert placeholder.embeds_cumsum is placeholder.embeds_cumsum
@pytest.mark.parametrize(
    "is_embed,start_idx,end_idx,expected",
    [
        (None, 2, 4, (2, 4)),
        (
            torch.tensor([False, True, False, True, True]),
            3,
            5,
            (1, 3),
        ),
        (
            torch.tensor([False, True, False, True, True]),
            0,
            2,
            (0, 1),
        ),
        (
            torch.tensor([True, False, True, False]),
            2,
            2,
            (1, 1),
        ),
    ],
)
def test_placeholder_range_get_embeds_indices_in_range(
    is_embed, start_idx, end_idx, expected
):
    """get_embeds_indices_in_range returns the expected pair for each case."""
    length = 5 if is_embed is None else len(is_embed)
    placeholder = PlaceholderRange(offset=0, length=length, is_embed=is_embed)
    actual = placeholder.get_embeds_indices_in_range(start_idx, end_idx)
    assert actual == expected
@pytest.mark.parametrize(
    "offset,is_embed,expected",
    [
        (0, None, [(0, 4)]),
        (
            2,
            torch.tensor([False, True, False, True, True]),
            [(3, 3), (5, 6)],
        ),
        (0, torch.tensor([True, True, True, True]), [(0, 3)]),
        (0, torch.tensor([False, False, False, False]), []),
    ],
)
def test_placeholder_range_extract_embeds_range(offset, is_embed, expected):
    """extract_embeds_range returns the expected list of pairs for each case."""
    length = 5 if is_embed is None else len(is_embed)
    placeholder = PlaceholderRange(offset=offset, length=length, is_embed=is_embed)
    actual = placeholder.extract_embeds_range()
    assert actual == expected
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS) @pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
@pytest.mark.parametrize("num_frames", [-1, 32, 1800]) @pytest.mark.parametrize("num_frames", [-1, 32, 1800])

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest import pytest
import torch
from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
@@ -23,7 +24,7 @@ class MockRequest:
) )
self.mm_features.append(feature) self.mm_features.append(feature)
def get_num_encoder_tokens(self, input_id: int) -> int: def get_num_encoder_embeds(self, input_id: int) -> int:
return self._token_counts[input_id] return self._token_counts[input_id]
@@ -162,8 +163,8 @@ def test_schedule_request_multi_images_respect_space_limit():
num_tokens_to_schedule = 0 num_tokens_to_schedule = 0
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule) assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
num_tokens_to_schedule += req.get_num_encoder_tokens(0) num_tokens_to_schedule += req.get_num_encoder_embeds(0)
compute_budget -= req.get_num_encoder_tokens(0) compute_budget -= req.get_num_encoder_embeds(0)
assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule) assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule)
@@ -174,7 +175,75 @@ def test_schedule_request_multi_images_respect_compute_limit():
compute_budget = 10 compute_budget = 10
num_tokens_to_schedule = 0 num_tokens_to_schedule = 0
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule) assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
num_tokens_to_schedule += req.get_num_encoder_tokens(0) num_tokens_to_schedule += req.get_num_encoder_embeds(0)
compute_budget -= req.get_num_encoder_tokens(0) compute_budget -= req.get_num_encoder_embeds(0)
assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule) assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule)
def test_encoder_cache_with_is_embed_mask():
    """Allocation consumes one slot per embedding in the mask, not per token."""

    class MockRequestWithMask(MockRequest):
        def get_num_encoder_embeds(self, input_id: int) -> int:
            return self.mm_features[input_id].mm_position.get_num_embeds

    # 100-token placeholder with 8 embedding positions: 5, 15, ..., 75.
    mask = torch.zeros(100, dtype=torch.bool)
    mask[5:76:10] = True

    req = MockRequestWithMask("r1", ["img1"], [100])
    masked_position = PlaceholderRange(offset=0, length=100, is_embed=mask)
    req.mm_features[0] = MultiModalFeatureSpec(
        data=None,
        modality="image",
        identifier="img1",
        mm_position=masked_position,
    )

    cache_manager = EncoderCacheManager(cache_size=100)
    cache_manager.allocate(req, 0)

    # 100 slots minus the 8 embeddings.
    assert cache_manager.num_free_slots == 92
    assert "img1" in cache_manager.cached

    tokens_in_placeholder = 100
    embeds_in_placeholder = req.mm_features[0].mm_position.get_num_embeds
    assert embeds_in_placeholder == 8
    assert tokens_in_placeholder / embeds_in_placeholder == 12.5
def test_encoder_cache_mask_based_retrieval():
    """Embedding counts before and within a token window follow the mask."""

    class MockRequestWithMask(MockRequest):
        def get_num_encoder_embeds(self, input_id: int) -> int:
            return self.mm_features[input_id].mm_position.get_num_embeds

    mask = torch.tensor(
        [False, False, True, True, False, True, True, True, False, False]
    )

    req = MockRequestWithMask("r1", ["img1"], [10])
    req.mm_features[0] = MultiModalFeatureSpec(
        data=None,
        modality="image",
        identifier="img1",
        mm_position=PlaceholderRange(offset=0, length=10, is_embed=mask),
    )

    cache_manager = EncoderCacheManager(cache_size=50)
    cache_manager.allocate(req, 0)

    assert req.mm_features[0].mm_position.get_num_embeds == 5

    # Window [2, 8): no embeddings before it, all five inside.
    window_start, window_end = 2, 8
    embeds_before = mask[:window_start].sum().item()
    embeds_inside = mask[window_start:window_end].sum().item()
    assert embeds_before == 0
    assert embeds_inside == 5

    # Window [0, 5): nothing before it, two embeddings inside.
    window_start, window_end = 0, 5
    if window_start > 0:
        embeds_before = mask[:window_start].sum().item()
    else:
        embeds_before = 0
    embeds_inside = mask[window_start:window_end].sum().item()
    assert embeds_before == 0
    assert embeds_inside == 2

View File

@@ -38,7 +38,7 @@ class MockRequest:
) )
self.mm_features.append(feature) self.mm_features.append(feature)
def get_num_encoder_tokens(self, input_id: int) -> int: def get_num_encoder_embeds(self, input_id: int) -> int:
assert input_id < len(self._token_counts) assert input_id < len(self._token_counts)
return self._token_counts[input_id] return self._token_counts[input_id]

View File

@@ -16,6 +16,7 @@ import einops
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
@@ -89,6 +90,13 @@ def torch_sdpa_wrapper(
v: torch.Tensor, v: torch.Tensor,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
# Never remove the contiguous logic for ROCm
# Without it, hallucinations occur with the backend
if current_platform.is_rocm():
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
outputs = [] outputs = []
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()

View File

@@ -144,7 +144,7 @@ class ECExampleConnector(ECConnectorBase):
Update ECConnector state after encoder cache allocation. Update ECConnector state after encoder cache allocation.
""" """
mm_hash = request.mm_features[index].identifier mm_hash = request.mm_features[index].identifier
num_encoder_token = request.get_num_encoder_tokens(index) num_encoder_token = request.get_num_encoder_embeds(index)
# Insert mm_hash only if this block has not been recorded yet. # Insert mm_hash only if this block has not been recorded yet.
self._mm_datas_need_loads[mm_hash] = num_encoder_token self._mm_datas_need_loads[mm_hash] = num_encoder_token

View File

@@ -7,7 +7,7 @@ import torch
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from .common import apply_rotary_emb_torch from .common import ApplyRotaryEmb
@CustomOp.register("rotary_embedding") @CustomOp.register("rotary_embedding")
@@ -49,6 +49,10 @@ class RotaryEmbeddingBase(CustomOp):
rocm_aiter_ops.is_triton_rotary_embed_enabled() rocm_aiter_ops.is_triton_rotary_embed_enabled()
) )
self.apply_rotary_emb = ApplyRotaryEmb(
is_neox_style=self.is_neox_style,
)
def _compute_inv_freq(self, base: float) -> torch.Tensor: def _compute_inv_freq(self, base: float) -> torch.Tensor:
"""Compute the inverse frequency.""" """Compute the inverse frequency."""
# NOTE(woosuk): To exactly match the HF implementation, we need to # NOTE(woosuk): To exactly match the HF implementation, we need to
@@ -123,7 +127,12 @@ class RotaryEmbedding(RotaryEmbeddingBase):
query = query.view(num_tokens, -1, head_size) query = query.view(num_tokens, -1, head_size)
query_rot = query[..., :rotary_dim] query_rot = query[..., :rotary_dim]
query_pass = query[..., rotary_dim:] query_pass = query[..., rotary_dim:]
query_rot = apply_rotary_emb_torch(query_rot, cos, sin, is_neox_style) query_rot = ApplyRotaryEmb.forward_static(
query_rot,
cos,
sin,
is_neox_style,
)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
# key may be None in some cases, e.g. cross-layer KV sharing # key may be None in some cases, e.g. cross-layer KV sharing
@@ -132,7 +141,12 @@ class RotaryEmbedding(RotaryEmbeddingBase):
key = key.view(num_tokens, -1, head_size) key = key.view(num_tokens, -1, head_size)
key_rot = key[..., :rotary_dim] key_rot = key[..., :rotary_dim]
key_pass = key[..., rotary_dim:] key_pass = key[..., rotary_dim:]
key_rot = apply_rotary_emb_torch(key_rot, cos, sin, is_neox_style) key_rot = ApplyRotaryEmb.forward_static(
key_rot,
cos,
sin,
is_neox_style,
)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key return query, key

View File

@@ -2,19 +2,14 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math import math
from collections.abc import Callable
from functools import cache
from importlib.util import find_spec from importlib.util import find_spec
import torch import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.model_executor.custom_op import CustomOp
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -32,71 +27,6 @@ def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
return x.flatten(-2) return x.flatten(-2)
def apply_rotary_emb_torch(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style: bool,
) -> torch.Tensor:
    """Apply rotary embedding to ``x`` using plain PyTorch ops.

    ``cos`` and ``sin`` get an extra axis inserted at position -2 and are cast
    to ``x.dtype`` before use. With ``is_neox_style`` the last dimension of
    ``x`` is split into two halves; otherwise even and odd elements are paired.
    """
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)

    if is_neox_style:
        first, second = torch.chunk(x, 2, dim=-1)
    else:
        first, second = x[..., ::2], x[..., 1::2]

    rotated_first = first * cos - second * sin
    rotated_second = second * cos + first * sin

    if not is_neox_style:
        # Re-interleave the rotated pairs along the last dimension.
        return torch.stack((rotated_first, rotated_second), dim=-1).flatten(-2)
    return torch.cat((rotated_first, rotated_second), dim=-1)
def apply_rotary_emb_dispatch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, is_neox_style: bool
) -> torch.Tensor:
    """
    Apply rotary embedding with the platform-appropriate implementation.

    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
            positional embeddings.
    """
    if not current_platform.is_cuda():
        return apply_rotary_emb_torch(x, cos, sin, is_neox_style)

    # The CUDA path passes a leading axis and the negated style flag.
    batched = x.unsqueeze(0)
    rotated = apply_rotary_emb(batched, cos, sin, not is_neox_style)
    return rotated.squeeze(0)
@cache
def dispatch_rotary_emb_function(
    default: Callable[..., torch.Tensor] | None = None,
) -> Callable[..., torch.Tensor]:
    """Pick the rotary embedding callable for the current platform.

    The result is memoized per ``default`` argument via ``functools.cache``.
    """
    if current_platform.is_cuda():
        return apply_rotary_emb

    # On ROCm without torch.compile, prefer the flash_attn package's rotary
    # function; under torch.compile the naive PyTorch implementation is faster.
    if current_platform.is_rocm() and not torch.compiler.is_compiling():
        if find_spec("flash_attn") is None:
            logger.warning(
                "flash_attn is not installed. Falling back to PyTorch "
                "implementation for rotary embeddings."
            )
        else:
            from flash_attn.ops.triton.rotary import apply_rotary

            return apply_rotary

    return default if default is not None else apply_rotary_emb_torch
# yarn functions # yarn functions
# Inverse dim formula to find dim based on number of rotations # Inverse dim formula to find dim based on number of rotations
def yarn_find_correction_dim( def yarn_find_correction_dim(
@@ -186,3 +116,155 @@ direct_register_custom_op(
mutates_args=["query", "key"], # These tensors are modified in-place mutates_args=["query", "key"], # These tensors are modified in-place
fake_impl=_flashinfer_rotary_embedding_fake, fake_impl=_flashinfer_rotary_embedding_fake,
) )
@CustomOp.register("apply_rotary_emb")
class ApplyRotaryEmb(CustomOp):
    """Custom op that applies rotary embedding given cos/sin tables.

    ``forward_native`` is the pure PyTorch path, while ``forward_cuda`` and
    ``forward_hip`` call flash-attention rotary kernels. Which one a plain
    call uses is decided by the ``CustomOp`` base class (not shown here).
    """

    def __init__(
        self,
        enforce_enable: bool = False,
        is_neox_style: bool = True,
        enable_fp32_compute: bool = False,
    ) -> None:
        """
        Args:
            enforce_enable: Forwarded unchanged to ``CustomOp.__init__``.
            is_neox_style: Whether to use the Neox-style or GPT-J-style.
            enable_fp32_compute: Temporarily convert x, cos, sin to FP32 dtype
                for higher accuracy.
        """
        super().__init__(enforce_enable)
        self.is_neox_style = is_neox_style
        self.enable_fp32_compute = enable_fp32_compute

        # Optional flash_attn Triton kernel, used only by forward_hip();
        # stays None when the package is not installed.
        self.apply_rotary_emb_flash_attn = None
        if find_spec("flash_attn") is not None:
            from flash_attn.ops.triton.rotary import apply_rotary

            self.apply_rotary_emb_flash_attn = apply_rotary

    @staticmethod
    def forward_static(
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        is_neox_style: bool = True,
        enable_fp32_compute: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            x: [batch_size (optional), seq_len, num_heads, head_size]
            cos: [seq_len, head_size // 2]
            sin: [seq_len, head_size // 2]
            is_neox_style: Whether to use the Neox-style or GPT-J-style.
            enable_fp32_compute: Temporarily convert x, cos, sin to FP32 dtype
                for higher accuracy.

        Returns:
            The rotated tensor, in the dtype ``x`` had on entry.
        """
        origin_dtype = x.dtype
        if enable_fp32_compute:
            x = x.float()

        # cos/sin follow x's (possibly upcast) dtype and gain a heads axis.
        cos = cos.unsqueeze(-2).to(x.dtype)
        sin = sin.unsqueeze(-2).to(x.dtype)

        if is_neox_style:
            x1, x2 = torch.chunk(x, 2, dim=-1)
        else:
            x1 = x[..., ::2]
            x2 = x[..., 1::2]

        o1 = x1 * cos - x2 * sin
        o2 = x2 * cos + x1 * sin

        if is_neox_style:
            output = torch.cat((o1, o2), dim=-1)
        else:
            output = torch.stack((o1, o2), dim=-1).flatten(-2)

        if enable_fp32_compute:
            output = output.to(origin_dtype)
        return output

    def forward_native(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """Run the PyTorch implementation with this instance's settings."""
        output = self.forward_static(
            x, cos, sin, self.is_neox_style, self.enable_fp32_compute
        )
        return output

    def forward_cuda(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """Apply rotary embedding via ``vllm_flash_attn``'s apply_rotary_emb."""
        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb

        origin_dtype = x.dtype
        if self.enable_fp32_compute:
            x = x.float()
            cos = cos.float()
            sin = sin.float()

        origin_shape = x.shape
        if len(origin_shape) == 3:
            # x: [seq_len, num_heads, head_size]
            x = x.unsqueeze(0)

        # Arguments of apply_rotary_emb() in vllm_flash_attn:
        #     x: [batch_size, seq_len, nheads, headdim]
        #     cos, sin: [seqlen_rotary, rotary_dim / 2]
        #     interleaved: default as False (Neox-style).
        #     ...
        interleaved = not self.is_neox_style
        output = apply_rotary_emb(x, cos, sin, interleaved)

        if len(origin_shape) == 3:
            output = output.squeeze(0)
        if self.enable_fp32_compute:
            output = output.to(origin_dtype)
        return output

    def forward_hip(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        """Apply rotary embedding via flash_attn's Triton kernel when present.

        Falls back to ``forward_native`` when flash_attn is not installed.
        """
        if self.apply_rotary_emb_flash_attn is not None:
            origin_dtype = x.dtype
            if self.enable_fp32_compute:
                x = x.float()
                cos = cos.float()
                sin = sin.float()

            origin_shape = x.shape
            if len(origin_shape) == 3:
                # x: [seq_len, num_heads, head_size]
                x = x.unsqueeze(0)

            # Arguments of apply_rotary() in flash_attn:
            #     x: [batch_size, seq_len, nheads, headdim]
            #     cos, sin: [seqlen_rotary, rotary_dim / 2]
            #     interleaved: default as False (Neox-style).
            #     ...
            interleaved = not self.is_neox_style
            output = self.apply_rotary_emb_flash_attn(
                x, cos, sin, interleaved=interleaved
            ).type_as(x)

            if len(origin_shape) == 3:
                output = output.squeeze(0)
            if self.enable_fp32_compute:
                output = output.to(origin_dtype)
        else:
            # Falling back to PyTorch native implementation.
            output = self.forward_native(x, cos, sin)
        return output

    def extra_repr(self) -> str:
        """Return the instance settings as a comma-separated string."""
        # The two fields were previously concatenated with no separator.
        return (
            f"is_neox_style={self.is_neox_style}, "
            f"enable_fp32_compute={self.enable_fp32_compute}"
        )

View File

@@ -4,7 +4,6 @@
import torch import torch
from .common import apply_rotary_emb_dispatch
from .mrope import MRotaryEmbedding from .mrope import MRotaryEmbedding
@@ -55,14 +54,22 @@ class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
query = query.view(num_tokens, -1, self.head_size) query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim] query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :] query_pass = query[..., self.rotary_dim :]
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style) query_rot = self.apply_rotary_emb.forward_native(
query_rot,
cos,
sin,
)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size) key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim] key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :] key_pass = key[..., self.rotary_dim :]
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style) key_rot = self.apply_rotary_emb.forward_native(
key_rot,
cos,
sin,
)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key return query, key

View File

@@ -8,7 +8,6 @@ import torch
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from .base import RotaryEmbeddingBase from .base import RotaryEmbeddingBase
from .common import apply_rotary_emb_dispatch
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale
@@ -301,14 +300,22 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
query = query.view(num_tokens, -1, self.head_size) query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim] query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :] query_pass = query[..., self.rotary_dim :]
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style) query_rot = self.apply_rotary_emb.forward_native(
query_rot,
cos,
sin,
)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size) key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim] key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :] key_pass = key[..., self.rotary_dim :]
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style) key_rot = self.apply_rotary_emb.forward_native(
key_rot,
cos,
sin,
)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key return query, key
@@ -347,13 +354,21 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
query = query.view(num_tokens, -1, self.head_size) query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim] query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :] query_pass = query[..., self.rotary_dim :]
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style) query_rot = self.apply_rotary_emb(
query_rot,
cos,
sin,
)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key = key.view(num_tokens, -1, self.head_size) key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim] key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :] key_pass = key[..., self.rotary_dim :]
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style) key_rot = self.apply_rotary_emb(
key_rot,
cos,
sin,
)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key return query, key

View File

@@ -4,7 +4,6 @@
import numpy as np import numpy as np
import torch import torch
from .common import apply_rotary_emb_dispatch
from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
@@ -36,7 +35,7 @@ class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding):
dtype, dtype,
) )
def forward( def forward_native(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
@@ -68,14 +67,73 @@ class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding):
query = query.view(num_tokens, -1, self.head_size) query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim] query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :] query_pass = query[..., self.rotary_dim :]
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style) query_rot = self.apply_rotary_emb.forward_native(
query_rot,
cos,
sin,
)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size) key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim] key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :] key_pass = key[..., self.rotary_dim :]
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style) key_rot = self.apply_rotary_emb.forward_native(
key_rot,
cos,
sin,
)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
def forward_cuda(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""PyTorch-native implementation equivalent to forward().
Args:
positions:
[4, num_tokens] (P/W/H/T positions with multimodal inputs)
query: [num_tokens, num_heads * head_size]
key: [num_tokens, num_kv_heads * head_size]
"""
assert positions.ndim == 2
assert key is not None
num_tokens = positions.shape[-1]
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
cos = torch.cat(
[m[i] for i, m in enumerate(cos.split(self.xdrope_section, dim=-1))], dim=-1
)
sin = torch.cat(
[m[i] for i, m in enumerate(sin.split(self.xdrope_section, dim=-1))], dim=-1
)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., : self.rotary_dim]
query_pass = query[..., self.rotary_dim :]
query_rot = self.apply_rotary_emb(
query_rot,
cos,
sin,
)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
key_pass = key[..., self.rotary_dim :]
key_rot = self.apply_rotary_emb(
key_rot,
cos,
sin,
)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key return query, key

View File

@@ -29,6 +29,9 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import ( from vllm.model_executor.models.interfaces import (
MultiModalEmbeddings, MultiModalEmbeddings,
@@ -158,32 +161,6 @@ class DotsOCRProcessingInfo(Qwen2VLProcessingInfo):
return processor return processor
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb_vision(
tensor: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
orig_dtype = tensor.dtype
tensor = tensor.float()
cos = freqs.cos()
sin = freqs.sin()
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
output = (tensor * cos) + (rotate_half(tensor) * sin)
output = output.to(orig_dtype)
return output
class VisionRotaryEmbedding(nn.Module): class VisionRotaryEmbedding(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0) -> None: def __init__(self, dim: int, theta: float = 10000.0) -> None:
super().__init__() super().__init__()
@@ -298,6 +275,11 @@ class DotsVisionAttention(nn.Module):
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
) )
self.apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
def forward( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
@@ -318,7 +300,11 @@ class DotsVisionAttention(nn.Module):
if rotary_pos_emb is not None: if rotary_pos_emb is not None:
qk_concat = torch.cat([q, k], dim=0) qk_concat = torch.cat([q, k], dim=0)
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) qk_rotated = self.apply_rotary_emb(
qk_concat,
rotary_pos_emb.cos(),
rotary_pos_emb.sin(),
)
q, k = torch.chunk(qk_rotated, 2, dim=0) q, k = torch.chunk(qk_rotated, 2, dim=0)
context_layer = self.attn( context_layer = self.attn(

View File

@@ -33,7 +33,7 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from einops import rearrange, repeat from einops import rearrange
from transformers import BatchFeature from transformers import BatchFeature
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
@@ -53,6 +53,9 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import ( from vllm.multimodal.inputs import (
@@ -69,7 +72,6 @@ from vllm.multimodal.processing import (
PromptUpdate, PromptUpdate,
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -89,52 +91,6 @@ logger = init_logger(__name__)
# === Vision Transformer === # # === Vision Transformer === #
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
if not interleaved:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
else:
x1, x2 = x[..., ::2], x[..., 1::2]
return rearrange(
torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
)
def apply_rotary_emb_torch(
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
"""
ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
)
sin = repeat(
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
)
return torch.cat(
[
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
x[..., ro_dim:],
],
dim=-1,
)
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
t_ = t.float()
cos = freqs.cos()
sin = freqs.sin()
apply_rotary_emb = apply_rotary_emb_torch
if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
output = apply_rotary_emb(t_, cos, sin).type_as(t)
return output
def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int): def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
"""All-gather the input tensor interleavely across model parallel group.""" """All-gather the input tensor interleavely across model parallel group."""
import torch.distributed as dist import torch.distributed as dist
@@ -200,6 +156,11 @@ class Ernie4_5_VisionAttention(nn.Module):
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
) )
self.apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim] # [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape seq_len, bs, _ = qkv.shape
@@ -244,7 +205,11 @@ class Ernie4_5_VisionAttention(nn.Module):
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)) q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
if rotary_pos_emb is not None: if rotary_pos_emb is not None:
qk_concat = torch.cat([q, k], dim=0) qk_concat = torch.cat([q, k], dim=0)
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) qk_rotated = self.apply_rotary_emb(
qk_concat,
rotary_pos_emb.cos(),
rotary_pos_emb.sin(),
)
q, k = torch.chunk(qk_rotated, 2, dim=0) q, k = torch.chunk(qk_rotated, 2, dim=0)
output = self.attn( output = self.attn(

View File

@@ -65,6 +65,9 @@ from vllm.model_executor.layers.linear import (
) )
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.rotary_embedding.common import (
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -95,7 +98,7 @@ from .interfaces import (
SupportsMultiModal, SupportsMultiModal,
SupportsPP, SupportsPP,
) )
from .qwen2_vl import _create_qwen2vl_field_factory, apply_rotary_pos_emb_vision from .qwen2_vl import _create_qwen2vl_field_factory
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
WeightsMapper, WeightsMapper,
@@ -304,6 +307,8 @@ class Glm4vVisionAttention(nn.Module):
multimodal_config=multimodal_config, multimodal_config=multimodal_config,
) )
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim] # [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape seq_len, bs, _ = qkv.shape
@@ -339,8 +344,10 @@ class Glm4vVisionAttention(nn.Module):
if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None: if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
# [2 * b, s, heads, head_dim] # [2 * b, s, heads, head_dim]
qk_concat = torch.cat([q, k], dim=0) qk_concat = torch.cat([q, k], dim=0)
qk_rotated = apply_rotary_pos_emb_vision( qk_rotated = self.apply_rotary_emb(
qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin qk_concat,
rotary_pos_emb_cos,
rotary_pos_emb_sin,
) )
q, k = torch.chunk(qk_rotated, 2, dim=0) q, k = torch.chunk(qk_rotated, 2, dim=0)

View File

@@ -30,6 +30,9 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, default_weight_loader,
maybe_remap_kv_scale_name, maybe_remap_kv_scale_name,
@@ -59,7 +62,6 @@ from vllm.multimodal.processing import (
PromptUpdate, PromptUpdate,
) )
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
@@ -341,20 +343,14 @@ def apply_rotary_pos_emb_flashatt(
cos = cos.chunk(2, dim=-1)[0].contiguous() cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous() sin = sin.chunk(2, dim=-1)[0].contiguous()
if current_platform.is_cuda(): apply_rotary_emb = ApplyRotaryEmb(
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb enforce_enable=True,
elif current_platform.is_rocm(): enable_fp32_compute=True,
from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb )
else:
# For other platforms, use PyTorch fallback
from vllm.model_executor.layers.rotary_embedding.common import (
apply_rotary_emb_torch,
)
apply_rotary_emb = partial(apply_rotary_emb_torch, is_neox_style=True) q_embed = apply_rotary_emb(q, cos, sin)
k_embed = apply_rotary_emb(k, cos, sin)
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
return q_embed, k_embed return q_embed, k_embed

View File

@@ -22,7 +22,7 @@ from typing import Annotated, Literal
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from einops import rearrange, repeat from einops import rearrange
from transformers import BatchFeature, PretrainedConfig from transformers import BatchFeature, PretrainedConfig
from transformers.activations import GELUActivation from transformers.activations import GELUActivation
from transformers.modeling_outputs import ( from transformers.modeling_outputs import (
@@ -47,7 +47,7 @@ from vllm.model_executor.layers.linear import (
) )
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import ( from vllm.model_executor.layers.rotary_embedding.common import (
dispatch_rotary_emb_function, ApplyRotaryEmb,
) )
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, default_weight_loader,
@@ -130,47 +130,6 @@ def smart_resize(
return h_bar, w_bar return h_bar, w_bar
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
if not interleaved:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
x1, x2 = x[..., ::2], x[..., 1::2]
return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
def apply_rotary_emb_torch(
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
) -> torch.Tensor:
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
"""
ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
)
sin = repeat(
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
)
return torch.cat(
[
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
x[..., ro_dim:],
],
dim=-1,
)
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
rotary_emb_function = dispatch_rotary_emb_function(default=apply_rotary_emb_torch)
t_ = t.float()
cos = freqs.cos()
sin = freqs.sin()
output = rotary_emb_function(t_, cos, sin).type_as(t)
return output
class PaddleOCRVLProcessingInfo(BaseProcessingInfo): class PaddleOCRVLProcessingInfo(BaseProcessingInfo):
def get_hf_config(self): def get_hf_config(self):
return self.ctx.get_hf_config() return self.ctx.get_hf_config()
@@ -609,6 +568,10 @@ class SiglipAttention(nn.Module):
multimodal_config=multimodal_config, multimodal_config=multimodal_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
) )
self.apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
seq_len, bs, _ = qkv.shape seq_len, bs, _ = qkv.shape
@@ -651,7 +614,11 @@ class SiglipAttention(nn.Module):
if rotary_pos_emb is not None: if rotary_pos_emb is not None:
qk_concat = torch.cat([q, k], dim=0) qk_concat = torch.cat([q, k], dim=0)
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb) qk_rotated = self.apply_rotary_emb(
qk_concat,
rotary_pos_emb.cos(),
rotary_pos_emb.sin(),
)
q, k = torch.chunk(qk_rotated, 2, dim=0) q, k = torch.chunk(qk_rotated, 2, dim=0)
context_layer = self.attn( context_layer = self.attn(

View File

@@ -60,6 +60,9 @@ from vllm.model_executor.layers.linear import (
) )
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.rotary_embedding.common import (
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.vision import should_torch_compile_mm_vit from vllm.model_executor.models.vision import should_torch_compile_mm_vit
@@ -95,7 +98,6 @@ from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
from .qwen2_vl import ( from .qwen2_vl import (
Qwen2VLMultiModalProcessor, Qwen2VLMultiModalProcessor,
Qwen2VLProcessingInfo, Qwen2VLProcessingInfo,
apply_rotary_pos_emb_vision,
) )
from .utils import ( from .utils import (
AutoWeightsLoader, AutoWeightsLoader,
@@ -353,6 +355,8 @@ class Qwen2_5_VisionAttention(nn.Module):
multimodal_config=multimodal_config, multimodal_config=multimodal_config,
) )
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
def forward( def forward(
self, self,
x: torch.Tensor, x: torch.Tensor,
@@ -378,8 +382,10 @@ class Qwen2_5_VisionAttention(nn.Module):
qk_reshaped = einops.rearrange( qk_reshaped = einops.rearrange(
qk, "b s two head head_dim -> (two b) s head head_dim", two=2 qk, "b s two head head_dim -> (two b) s head head_dim", two=2
) )
qk_rotated = apply_rotary_pos_emb_vision( qk_rotated = self.apply_rotary_emb(
qk_reshaped, cos=rotary_pos_emb_cos, sin=rotary_pos_emb_sin qk_reshaped,
rotary_pos_emb_cos,
rotary_pos_emb_sin,
) )
qk_rotated = qk_rotated.view( qk_rotated = qk_rotated.view(
2, 2,

View File

@@ -59,8 +59,7 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.rotary_embedding.common import ( from vllm.model_executor.layers.rotary_embedding.common import (
apply_rotary_emb_torch, ApplyRotaryEmb,
dispatch_rotary_emb_function,
) )
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -280,16 +279,6 @@ class Qwen2VisionMLP(nn.Module):
return x return x
def apply_rotary_pos_emb_vision(
t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
rotary_emb_function = dispatch_rotary_emb_function(
default=partial(apply_rotary_emb_torch, is_neox_style=True)
)
output = rotary_emb_function(t, cos, sin).type_as(t)
return output
class Qwen2VisionAttention(nn.Module): class Qwen2VisionAttention(nn.Module):
def __init__( def __init__(
self, self,
@@ -341,6 +330,8 @@ class Qwen2VisionAttention(nn.Module):
multimodal_config=multimodal_config, multimodal_config=multimodal_config,
) )
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim] # [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape seq_len, bs, _ = qkv.shape
@@ -387,8 +378,10 @@ class Qwen2VisionAttention(nn.Module):
# [2 * b, s, heads, head_dim] # [2 * b, s, heads, head_dim]
qk_concat = torch.cat([q, k], dim=0) qk_concat = torch.cat([q, k], dim=0)
qk_rotated = apply_rotary_pos_emb_vision( qk_rotated = self.apply_rotary_emb(
qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin qk_concat,
rotary_pos_emb_cos,
rotary_pos_emb_sin,
) )
q, k = torch.chunk(qk_rotated, 2, dim=0) q, k = torch.chunk(qk_rotated, 2, dim=0)

View File

@@ -713,17 +713,13 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
mm_counts: Mapping[str, int], mm_counts: Mapping[str, int],
) -> int: ) -> int:
target_width, target_height = self.get_image_size_with_most_features() target_width, target_height = self.get_image_size_with_most_features()
video_soft_tokens = self.get_num_video_tokens( num_video_soft_tokens = self.get_num_video_tokens(
image_width=target_width, image_width=target_width,
image_height=target_height, image_height=target_height,
num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts), num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
image_processor=None, image_processor=None,
) )
return num_video_soft_tokens
# NOTE: By default in Qwen3-VL, one video token is converted to
# "<{timestamp} seconds>" (on average 9.5 tokens) + vision_start_token + video_token + vision_end_token # noqa: E501
formatted_video_soft_tokens = video_soft_tokens * 12.5
return int(formatted_video_soft_tokens)
def _calculate_timestamps( def _calculate_timestamps(
self, indices: list[int] | torch.Tensor, video_fps: float, merge_size: int self, indices: list[int] | torch.Tensor, video_fps: float, merge_size: int

View File

@@ -6,7 +6,6 @@ within a vision language model."""
from collections.abc import Iterable from collections.abc import Iterable
import torch import torch
from einops import rearrange, repeat
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from transformers import Siglip2VisionConfig from transformers import Siglip2VisionConfig
@@ -26,6 +25,9 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding.common import (
ApplyRotaryEmb,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import current_platform from vllm.platforms import current_platform
@@ -146,40 +148,6 @@ class Siglip2VisionEmbeddings(nn.Module):
return patch_embeds return patch_embeds
# copy from flash_attn/layers/rotary.py
def rotate_half(x, interleaved=False):
if not interleaved:
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
else:
x1, x2 = x[..., ::2], x[..., 1::2]
return rearrange(
torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
)
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
"""
ro_dim = cos.shape[-1] * 2
assert ro_dim <= x.shape[-1]
cos = repeat(
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
)
sin = repeat(
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
)
return torch.cat(
[
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
x[..., ro_dim:],
],
dim=-1,
)
def apply_rotary_pos_emb( def apply_rotary_pos_emb(
q: torch.Tensor, q: torch.Tensor,
k: torch.Tensor, k: torch.Tensor,
@@ -189,14 +157,20 @@ def apply_rotary_pos_emb(
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
cos = cos.chunk(2, dim=-1)[0].contiguous() cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous() sin = sin.chunk(2, dim=-1)[0].contiguous()
if is_flash_attn_backend and current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
apply_rotary_emb_func = apply_rotary_emb apply_rotary_emb = ApplyRotaryEmb(
enforce_enable=True,
enable_fp32_compute=True,
)
if is_flash_attn_backend and not current_platform.is_cuda():
apply_rotary_emb_func = apply_rotary_emb.forward_cuda
else: else:
apply_rotary_emb_func = apply_rotary_emb_torch apply_rotary_emb_func = apply_rotary_emb.forward_native
q_embed = apply_rotary_emb_func(q.float(), cos.float(), sin.float()).type_as(q)
k_embed = apply_rotary_emb_func(k.float(), cos.float(), sin.float()).type_as(k) q_embed = apply_rotary_emb_func(q, cos, sin)
k_embed = apply_rotary_emb_func(k, cos, sin)
return q_embed, k_embed return q_embed, k_embed

View File

@@ -11,7 +11,7 @@ import torch
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention.backends.registry import AttentionBackendEnum from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig, get_current_vllm_config from vllm.config import VllmConfig
from vllm.distributed import ( from vllm.distributed import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
@@ -88,16 +88,10 @@ def get_vit_attn_backend(
""" """
Get the available attention backend for Vision Transformer. Get the available attention backend for Vision Transformer.
""" """
attn_backend = attn_backend_override
selected_backend = get_current_vllm_config().attention_config.backend
if attn_backend is None:
attn_backend = selected_backend
return current_platform.get_vit_attn_backend( return current_platform.get_vit_attn_backend(
head_size, head_size,
dtype, dtype,
backend=attn_backend, backend=attn_backend_override,
) )

View File

@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from collections import UserDict, defaultdict from collections import UserDict, defaultdict
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import cached_property, partial
from itertools import accumulate from itertools import accumulate
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
@@ -169,11 +169,42 @@ class PlaceholderRange:
between `offset` and `offset + length` to assign embeddings to. between `offset` and `offset + length` to assign embeddings to.
""" """
def get_num_embeds(self) -> int: @cached_property
def embeds_cumsum(self) -> torch.Tensor | None:
if self.is_embed is None: if self.is_embed is None:
return None
return self.is_embed.cumsum(dim=0)
@cached_property
def get_num_embeds(self) -> int:
if self.embeds_cumsum is None:
return self.length return self.length
return int(self.is_embed.sum().item()) return int(self.embeds_cumsum[-1])
def get_embeds_indices_in_range(
self, start_idx: int, end_idx: int
) -> tuple[int, int]:
"""
Returns the starting and ending indices of the embeddings of encoder outputs
in the range of [start_idx, end_idx) in the placeholders.
For example, given:
PlaceholderRange(offset=2, length=5, is_embed=[False, True, False, True, True])
If start_idx=3 and end_idx=5, the output is (1, 3) because we want to get
the second and the third embeddings from the encoder output.
"""
if self.embeds_cumsum is None:
return start_idx, end_idx
embeds_start_idx = (
int(self.embeds_cumsum[start_idx - 1]) if start_idx > 0 else 0
)
embeds_end_idx = int(self.embeds_cumsum[end_idx - 1])
return embeds_start_idx, embeds_end_idx
def extract_embeds_range(self) -> list[tuple[int, int]]: def extract_embeds_range(self) -> list[tuple[int, int]]:
"""Extract the start and end indices of the embedded region in prompt. """Extract the start and end indices of the embedded region in prompt.
@@ -188,7 +219,7 @@ class PlaceholderRange:
Returns full placeholder range if `is_embed` is `None`. Returns full placeholder range if `is_embed` is `None`.
""" """
if self.is_embed is None: if self.is_embed is None:
return [(self.offset, self.offset + self.length)] return [(self.offset, self.offset + self.length - 1)]
mask_i = self.is_embed.int() mask_i = self.is_embed.int()
starts = torch.nonzero( starts = torch.nonzero(

View File

@@ -274,15 +274,11 @@ class MultiModalProfiler(Generic[_I]):
def _get_mm_num_tokens( def _get_mm_num_tokens(
self, self,
mm_inputs: MultiModalInputs, mm_inputs: MultiModalInputs,
mm_embeddings_only: bool = True,
) -> Mapping[str, int]: ) -> Mapping[str, int]:
placeholders_by_modality = mm_inputs["mm_placeholders"] placeholders_by_modality = mm_inputs["mm_placeholders"]
return { return {
modality: sum( modality: sum(item.get_num_embeds for item in placeholders)
item.get_num_embeds() if mm_embeddings_only else item.length
for item in placeholders
)
for modality, placeholders in placeholders_by_modality.items() for modality, placeholders in placeholders_by_modality.items()
} }
@@ -328,12 +324,15 @@ class MultiModalProfiler(Generic[_I]):
multi_modal_placeholders=mm_inputs["mm_placeholders"], multi_modal_placeholders=mm_inputs["mm_placeholders"],
) )
def _get_mm_max_tokens( def get_mm_max_tokens(
self, self,
seq_len: int, seq_len: int,
mm_counts: Mapping[str, int] | None = None, mm_counts: Mapping[str, int] | None = None,
mm_embeddings_only: bool = True,
) -> Mapping[str, int]: ) -> Mapping[str, int]:
"""
Returns the maximum number of embeddings per item of each modality, excluding
any break/text tokens in-between multimodal embeddings/encoder outputs.
"""
if mm_counts is None: if mm_counts is None:
mm_counts = self.get_mm_limits() mm_counts = self.get_mm_limits()
@@ -349,21 +348,4 @@ class MultiModalProfiler(Generic[_I]):
} }
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only) return self._get_mm_num_tokens(mm_inputs)
def get_mm_max_contiguous_tokens(
self,
seq_len: int,
mm_counts: Mapping[str, int] | None = None,
) -> Mapping[str, int]:
"""
Returns the maximum length of the multimodal (image placeholders+text)
tokens, including any break/text tokens in-between image embeddings.
`<im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>`
Returns 9, even when the number of image embeddings is 6.
This is important to take into account when profiling and
initializing the encoder cache size.
"""
return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False)

View File

@@ -164,7 +164,7 @@ class MultiModalRegistry:
profiler.get_mm_limits() if profiler_limits is None else profiler_limits profiler.get_mm_limits() if profiler_limits is None else profiler_limits
) )
return profiler.get_mm_max_contiguous_tokens( return profiler.get_mm_max_tokens(
seq_len, seq_len,
{modality: 1 for modality, limit in profiler_limits.items() if limit > 0}, {modality: 1 for modality, limit in profiler_limits.items() if limit > 0},
) )

View File

@@ -39,20 +39,26 @@ class EncoderCacheManager:
space for new embeddings. space for new embeddings.
Oldest cached embeddings with no request referenced will be first evicted. Oldest cached embeddings with no request referenced will be first evicted.
NOTE: The EncoderCacheManager operates on the level of multimodal embeddings
instead of encoder tokens (i.e. all tokens that represent the multimodal data
in the input sequence). This means all break/text tokens in-between multimodal
embeddings are not considered with respect to the cache size and the number
of free slots.
Args: Args:
cache_size: Limit the size of the cache, measured by the number of cache_size: Limit the size of the cache, measured by the number of
tokens from the input sequence. encoder embeddings from the input sequence.
Attributes: Attributes:
cache_size: Total cache capacity in encoder tokens. cache_size: Total cache capacity in encoder embeddings.
num_free_slots: Current available cache capacity in encoder tokens. num_free_slots: Current available cache capacity in encoder embeddings.
num_freeable_slots: Capacity that can be immediately reclaimed by num_freeable_slots: Capacity that can be immediately reclaimed by
evicting entries with zero references (in encoder tokens). evicting entries with zero references (in encoder embeddings).
cached: Mapping from mm_hash to a set of request IDs that currently cached: Mapping from mm_hash to a set of request IDs that currently
reference the cached entry. If the set is empty, the entry exists reference the cached entry. If the set is empty, the entry exists
but is not referenced by any request and is eligible for but is not referenced by any request and is eligible for
reclamation. reclamation.
freeable: List of tuples (mm_hash, num_tokens) representing entries freeable: List of tuples (mm_hash, num_encoder_embeds) representing entries
that no running request currently needs and that can be freed to that no running request currently needs and that can be freed to
make space when needed. make space when needed.
freed: List of mm_hash strings that were actually evicted since the freed: List of mm_hash strings that were actually evicted since the
@@ -67,7 +73,7 @@ class EncoderCacheManager:
# mm_hash of mm_data => ids of requests that reference the mm_data # mm_hash of mm_data => ids of requests that reference the mm_data
self.cached: dict[str, set[str]] = {} self.cached: dict[str, set[str]] = {}
# mm_hash of mm_data => num_encoder_tokens of the mm_data # mm_hash of mm_data => num_encoder_embeds of the mm_data
self.freeable: OrderedDict[str, int] = OrderedDict() self.freeable: OrderedDict[str, int] = OrderedDict()
self.freed: list[str] = [] self.freed: list[str] = []
@@ -93,8 +99,8 @@ class EncoderCacheManager:
# Cached but currently not referenced by any request # Cached but currently not referenced by any request
if not self.cached[mm_hash]: if not self.cached[mm_hash]:
num_tokens = self.freeable.pop(mm_hash) num_encoder_embeds = self.freeable.pop(mm_hash)
self.num_freeable_slots -= num_tokens self.num_freeable_slots -= num_encoder_embeds
self.cached[mm_hash].add(request.request_id) self.cached[mm_hash].add(request.request_id)
return True return True
@@ -104,7 +110,7 @@ class EncoderCacheManager:
request: Request, request: Request,
input_id: int, input_id: int,
encoder_compute_budget: int, encoder_compute_budget: int,
num_tokens_to_schedule: int, num_embeds_to_schedule: int,
) -> bool: ) -> bool:
"""Check if there's sufficient cache space for a multimodal input. """Check if there's sufficient cache space for a multimodal input.
If there is, return True and update EncoderCacheManager state. If there is, return True and update EncoderCacheManager state.
@@ -121,9 +127,9 @@ class EncoderCacheManager:
Args: Args:
request: The request containing the multimodal input. request: The request containing the multimodal input.
input_id: Index of the multimodal input within the request. input_id: Index of the multimodal input within the request.
encoder_compute_budget: Number of encoder tokens allowed to be encoder_compute_budget: Number of encoder embeddings allowed to be
computed when this method is invoked. computed when this method is invoked.
num_tokens_to_schedule: Number of tokens already scheduled to be num_embeds_to_schedule: Number of encoder embeddings already scheduled to be
allocated with cache space when this method is invoked. allocated with cache space when this method is invoked.
Returns: Returns:
@@ -134,30 +140,30 @@ class EncoderCacheManager:
Note: This method does not allocate physical memory for the encoder Note: This method does not allocate physical memory for the encoder
output but only the state of EncoderCacheManager. output but only the state of EncoderCacheManager.
""" """
num_tokens = request.get_num_encoder_tokens(input_id) num_embeds = request.get_num_encoder_embeds(input_id)
# Not enough compute budget # Not enough compute budget
if num_tokens > encoder_compute_budget: if num_embeds > encoder_compute_budget:
return False return False
num_tokens += num_tokens_to_schedule num_embeds += num_embeds_to_schedule
# Enough free slots # Enough free slots
if num_tokens <= self.num_free_slots: if num_embeds <= self.num_free_slots:
return True return True
# Not enough reclaimable slots # Not enough reclaimable slots
if num_tokens > self.num_freeable_slots: if num_embeds > self.num_freeable_slots:
return False return False
# Not enough free slots but enough reclaimable slots # Not enough free slots but enough reclaimable slots
# NOTE: Eviction takes place here, but physical memory is not freed # NOTE: Eviction takes place here, but physical memory is not freed
# until model runner is notified by the scheduler output. # until model runner is notified by the scheduler output.
while num_tokens > self.num_free_slots: while num_embeds > self.num_free_slots:
mm_hash, num_free_token = self.freeable.popitem(last=False) mm_hash, num_free_embeds = self.freeable.popitem(last=False)
del self.cached[mm_hash] del self.cached[mm_hash]
self.freed.append(mm_hash) self.freed.append(mm_hash)
self.num_free_slots += num_free_token self.num_free_slots += num_free_embeds
return True return True
def allocate(self, request: Request, input_id: int) -> None: def allocate(self, request: Request, input_id: int) -> None:
@@ -176,16 +182,16 @@ class EncoderCacheManager:
if mm_hash not in self.cached: if mm_hash not in self.cached:
self.cached[mm_hash] = set() self.cached[mm_hash] = set()
num_encoder_tokens = request.get_num_encoder_tokens(input_id) num_encoder_embeds = request.get_num_encoder_embeds(input_id)
# NOTE: Encoder cache should always have enough space for encoder inputs # NOTE: Encoder cache should always have enough space for encoder inputs
# that are scheduled since eviction takes place at can_allocate(). # that are scheduled since eviction takes place at can_allocate().
assert self.num_free_slots >= num_encoder_tokens assert self.num_free_slots >= num_encoder_embeds
assert self.num_freeable_slots >= num_encoder_tokens assert self.num_freeable_slots >= num_encoder_embeds
self.cached[mm_hash].add(request_id) self.cached[mm_hash].add(request_id)
self.num_free_slots -= num_encoder_tokens self.num_free_slots -= num_encoder_embeds
self.num_freeable_slots -= num_encoder_tokens self.num_freeable_slots -= num_encoder_embeds
def get_cached_input_ids(self, request: Request) -> set[int]: def get_cached_input_ids(self, request: Request) -> set[int]:
"""Get all cached multimodal input IDs for a request. """Get all cached multimodal input IDs for a request.
@@ -206,7 +212,7 @@ class EncoderCacheManager:
When the reference set for the corresponding `mm_hash` becomes empty, When the reference set for the corresponding `mm_hash` becomes empty,
the entry is appended to `freeable` and `num_freeable_slots` is the entry is appended to `freeable` and `num_freeable_slots` is
increased by the number of encoder tokens for that input. increased by the number of encoder embeddings for that input.
The entry is NOT physically freed until capacity is needed (e.g., by The entry is NOT physically freed until capacity is needed (e.g., by
`can_allocate`). `can_allocate`).
@@ -218,9 +224,9 @@ class EncoderCacheManager:
return return
self.cached[mm_hash].discard(req_id) self.cached[mm_hash].discard(req_id)
if not self.cached[mm_hash]: if not self.cached[mm_hash]:
num_tokens = request.get_num_encoder_tokens(input_id) num_encoder_embeds = request.get_num_encoder_embeds(input_id)
self.freeable[mm_hash] = num_tokens self.freeable[mm_hash] = num_encoder_embeds
self.num_freeable_slots += num_tokens self.num_freeable_slots += num_encoder_embeds
def free(self, request: Request) -> None: def free(self, request: Request) -> None:
"""Free all encoder input cache reference held by *request*. """Free all encoder input cache reference held by *request*.
@@ -361,20 +367,20 @@ class EncoderDecoderCacheManager(EncoderCacheManager):
request: Request, request: Request,
input_id: int, input_id: int,
encoder_compute_budget: int, encoder_compute_budget: int,
num_tokens_to_schedule: int, num_embeds_to_schedule: int,
) -> bool: ) -> bool:
num_tokens = request.get_num_encoder_tokens(input_id) num_encoder_embeds = request.get_num_encoder_embeds(input_id)
# Not enough compute budget # Not enough compute budget
if num_tokens > encoder_compute_budget: if num_encoder_embeds > encoder_compute_budget:
return False return False
num_tokens += num_tokens_to_schedule num_encoder_embeds += num_embeds_to_schedule
# Enough free slots # Enough free slots
return num_tokens <= self.num_free_slots return num_encoder_embeds <= self.num_free_slots
def allocate(self, request: Request, input_id: int) -> None: def allocate(self, request: Request, input_id: int) -> None:
num_encoder_tokens = request.get_num_encoder_tokens(input_id) num_encoder_embeds = request.get_num_encoder_embeds(input_id)
self.num_free_slots -= num_encoder_tokens self.num_free_slots -= num_encoder_embeds
mm_hash = request.mm_features[input_id].identifier mm_hash = request.mm_features[input_id].identifier
self.freed.append(mm_hash) self.freed.append(mm_hash)
@@ -392,5 +398,5 @@ class EncoderDecoderCacheManager(EncoderCacheManager):
return freed return freed
def free_encoder_input(self, request: Request, input_id: int) -> None: def free_encoder_input(self, request: Request, input_id: int) -> None:
num_tokens = request.get_num_encoder_tokens(input_id) num_encoder_embeds = request.get_num_encoder_embeds(input_id)
self.num_free_slots += num_tokens self.num_free_slots += num_encoder_embeds

View File

@@ -349,11 +349,11 @@ class Scheduler(SchedulerInterface):
if preempted_encoder_inputs: if preempted_encoder_inputs:
# Restore encoder compute budget if the preempted # Restore encoder compute budget if the preempted
# request had encoder inputs scheduled in this step. # request had encoder inputs scheduled in this step.
num_tokens_to_restore = sum( num_embeds_to_restore = sum(
preempted_req.get_num_encoder_tokens(i) preempted_req.get_num_encoder_embeds(i)
for i in preempted_encoder_inputs for i in preempted_encoder_inputs
) )
encoder_compute_budget += num_tokens_to_restore encoder_compute_budget += num_embeds_to_restore
req_index -= 1 req_index -= 1
else: else:
preempted_req = self.running.pop() preempted_req = self.running.pop()
@@ -911,10 +911,11 @@ class Scheduler(SchedulerInterface):
# multiple encoder inputs per request), we need to create temporary # multiple encoder inputs per request), we need to create temporary
# trackers for accounting at the encoder input level. # trackers for accounting at the encoder input level.
mm_hashes_to_schedule = set() mm_hashes_to_schedule = set()
num_tokens_to_schedule = 0 num_embeds_to_schedule = 0
for i, mm_feature in enumerate(mm_features): for i, mm_feature in enumerate(mm_features):
start_pos = mm_feature.mm_position.offset start_pos = mm_feature.mm_position.offset
num_encoder_tokens = mm_feature.mm_position.length num_encoder_tokens = mm_feature.mm_position.length
num_encoder_embeds = mm_feature.mm_position.get_num_embeds
# The encoder output is needed if the two ranges overlap: # The encoder output is needed if the two ranges overlap:
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and # [num_computed_tokens, num_computed_tokens + num_new_tokens) and
@@ -970,9 +971,8 @@ class Scheduler(SchedulerInterface):
): ):
num_new_tokens = start_pos - num_computed_tokens num_new_tokens = start_pos - num_computed_tokens
break break
if not self.encoder_cache_manager.can_allocate( if not self.encoder_cache_manager.can_allocate(
request, i, encoder_compute_budget, num_tokens_to_schedule request, i, encoder_compute_budget, num_embeds_to_schedule
): ):
# The encoder cache is full or the encoder budget is exhausted. # The encoder cache is full or the encoder budget is exhausted.
# NOTE(woosuk): We assume that the encoder input tokens should # NOTE(woosuk): We assume that the encoder input tokens should
@@ -992,14 +992,31 @@ class Scheduler(SchedulerInterface):
num_new_tokens = 0 num_new_tokens = 0
break break
# Calculate the number of embeddings to schedule in the current range
# of scheduled encoder placeholder tokens.
start_idx_rel = max(0, num_computed_tokens - start_pos)
end_idx_rel = min(
num_encoder_tokens, num_computed_tokens + num_new_tokens - start_pos
)
curr_embeds_start, curr_embeds_end = (
mm_feature.mm_position.get_embeds_indices_in_range(
start_idx_rel,
end_idx_rel,
)
)
# There are no embeddings in the current range of encoder placeholder tokens
# so we can skip the encoder input.
if curr_embeds_end - curr_embeds_start == 0:
continue
if self.ec_connector is not None and remote_cache_has_item[i]: if self.ec_connector is not None and remote_cache_has_item[i]:
mm_hashes_to_schedule.add(request.mm_features[i].identifier) mm_hashes_to_schedule.add(request.mm_features[i].identifier)
external_load_encoder_input.append(i) external_load_encoder_input.append(i)
num_tokens_to_schedule += num_encoder_tokens num_embeds_to_schedule += num_encoder_embeds
continue continue
num_tokens_to_schedule += num_encoder_tokens num_embeds_to_schedule += num_encoder_embeds
encoder_compute_budget -= num_encoder_tokens encoder_compute_budget -= num_encoder_embeds
mm_hashes_to_schedule.add(request.mm_features[i].identifier) mm_hashes_to_schedule.add(request.mm_features[i].identifier)
encoder_inputs_to_schedule.append(i) encoder_inputs_to_schedule.append(i)

View File

@@ -209,10 +209,10 @@ class Request:
def get_finished_reason(self) -> FinishReason | None: def get_finished_reason(self) -> FinishReason | None:
return RequestStatus.get_finished_reason(self.status) return RequestStatus.get_finished_reason(self.status)
def get_num_encoder_tokens(self, input_id: int) -> int: def get_num_encoder_embeds(self, input_id: int) -> int:
assert input_id < len(self.mm_features) assert input_id < len(self.mm_features)
num_tokens = self.mm_features[input_id].mm_position.length num_embeds = self.mm_features[input_id].mm_position.get_num_embeds
return num_tokens return num_embeds
def record_event( def record_event(
self, self,

View File

@@ -169,9 +169,7 @@ from .utils import (
MultiModalBudget, MultiModalBudget,
add_kv_sharing_layers_to_kv_cache_groups, add_kv_sharing_layers_to_kv_cache_groups,
bind_kv_cache, bind_kv_cache,
gather_mm_placeholders,
sanity_check_mm_encoder_outputs, sanity_check_mm_encoder_outputs,
scatter_mm_placeholders,
) )
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -2187,10 +2185,7 @@ class GPUModelRunner(
# Cache the encoder outputs by mm_hash # Cache the encoder outputs by mm_hash
for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs): for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
self.encoder_cache[mm_hash] = scatter_mm_placeholders( self.encoder_cache[mm_hash] = output
output,
is_embed=pos_info.is_embed,
)
logger.debug("Finish execute for mm hash %s", mm_hash) logger.debug("Finish execute for mm hash %s", mm_hash)
self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash) self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)
@@ -2241,6 +2236,13 @@ class GPUModelRunner(
num_encoder_tokens, num_encoder_tokens,
) )
assert start_idx < end_idx assert start_idx < end_idx
curr_embeds_start, curr_embeds_end = (
pos_info.get_embeds_indices_in_range(start_idx, end_idx)
)
# If there are no embeddings in the current range, we skip
# gathering the embeddings.
if curr_embeds_start == curr_embeds_end:
continue
mm_hash = mm_feature.identifier mm_hash = mm_feature.identifier
encoder_output = self.encoder_cache.get(mm_hash, None) encoder_output = self.encoder_cache.get(mm_hash, None)
@@ -2248,16 +2250,14 @@ class GPUModelRunner(
if (is_embed := pos_info.is_embed) is not None: if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx] is_embed = is_embed[start_idx:end_idx]
mm_embeds_item = encoder_output[curr_embeds_start:curr_embeds_end]
else:
mm_embeds_item = encoder_output[start_idx:end_idx]
req_start_pos = req_start_idx + start_pos - num_computed_tokens req_start_pos = req_start_idx + start_pos - num_computed_tokens
is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = ( is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
True if is_embed is None else is_embed True if is_embed is None else is_embed
) )
mm_embeds_item = gather_mm_placeholders(
encoder_output[start_idx:end_idx],
is_embed=is_embed,
)
mm_embeds_req.append(mm_embeds_item) mm_embeds_req.append(mm_embeds_item)
if self.is_multimodal_pruning_enabled and self.uses_mrope: if self.is_multimodal_pruning_enabled and self.uses_mrope:
@@ -4467,31 +4467,8 @@ class GPUModelRunner(
dummy_encoder_outputs, dummy_encoder_outputs,
expected_num_items=max_mm_items_per_batch, expected_num_items=max_mm_items_per_batch,
) )
for i, output in enumerate(dummy_encoder_outputs):
# NOTE: This happens when encoder cache needs to store self.encoder_cache[f"tmp_{i}"] = output
# the embeddings that encoder outputs are scattered onto.
# In this case we create dummy embeddings of size
# (max_tokens_for_modality, hidden_size) and scatter
# encoder output into it.
encoder_output_shape = dummy_encoder_outputs[0].shape
max_mm_tokens_per_item = mm_budget.max_tokens_by_modality[
dummy_modality
]
if encoder_output_shape[0] < max_mm_tokens_per_item:
encoder_hidden_size = encoder_output_shape[-1]
expanded_outputs = []
for output in dummy_encoder_outputs:
expanded = output.new_zeros(
(max_mm_tokens_per_item, encoder_hidden_size)
)
num_tokens = output.shape[0]
expanded[:num_tokens].copy_(output)
expanded_outputs.append(expanded)
dummy_encoder_outputs = expanded_outputs
# Cache the dummy encoder outputs.
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
# Add `is_profile` here to pre-allocate communication buffers # Add `is_profile` here to pre-allocate communication buffers
hidden_states, last_hidden_states = self._dummy_run( hidden_states, last_hidden_states = self._dummy_run(

View File

@@ -4,10 +4,12 @@ from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
import torch import torch
from typing_extensions import deprecated
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.cache import processor_only_cache_from_config from vllm.multimodal.cache import processor_only_cache_from_config
@@ -17,6 +19,8 @@ from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
logger = init_logger(__name__)
class MultiModalBudget: class MultiModalBudget:
"""Helper class to calculate budget information for multi-modal models.""" """Helper class to calculate budget information for multi-modal models."""
@@ -198,6 +202,7 @@ def sanity_check_mm_encoder_outputs(
) )
@deprecated("`scatter_mm_placeholders` is deprecated and will be removed in v0.15.0.")
def scatter_mm_placeholders( def scatter_mm_placeholders(
embeds: torch.Tensor, embeds: torch.Tensor,
is_embed: torch.Tensor | None, is_embed: torch.Tensor | None,
@@ -226,6 +231,7 @@ def scatter_mm_placeholders(
return placeholders return placeholders
@deprecated("`gather_mm_placeholders` is deprecated and will be removed in v0.15.0.")
def gather_mm_placeholders( def gather_mm_placeholders(
placeholders: torch.Tensor, placeholders: torch.Tensor,
is_embed: torch.Tensor | None, is_embed: torch.Tensor | None,