Compare commits
6 Commits
v0.15.2rc0
...
v0.13.0rc2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f34eca5f01 | ||
|
|
4cd332f3cf | ||
|
|
16484d394c | ||
|
|
e397bd6592 | ||
|
|
6a88d590bb | ||
|
|
ad8c073131 |
@@ -523,6 +523,8 @@ CUSTOM_OPS_QUANT_RMS_NORM = ["+quant_fp8,+rms_norm"]
|
|||||||
list[tuple[Any, ...]](flat_product(MODELS_GROUP_FP8, CUSTOM_OPS_QUANT_RMS_NORM)),
|
list[tuple[Any, ...]](flat_product(MODELS_GROUP_FP8, CUSTOM_OPS_QUANT_RMS_NORM)),
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
|
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
|
||||||
|
# TODO: remove skip after we fix the fusion thoroughly
|
||||||
|
@pytest.mark.skipif(is_blackwell(), reason="Temporarily disabled on Blackwell")
|
||||||
def test_rms_group_quant(
|
def test_rms_group_quant(
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_kwargs: dict[str, Any],
|
model_kwargs: dict[str, Any],
|
||||||
@@ -562,7 +564,7 @@ def test_rms_group_quant(
|
|||||||
splitting_ops=splitting_ops,
|
splitting_ops=splitting_ops,
|
||||||
# Common
|
# Common
|
||||||
mode=CompilationMode.VLLM_COMPILE,
|
mode=CompilationMode.VLLM_COMPILE,
|
||||||
pass_config=PassConfig(eliminate_noops=True, enable_fusion=True),
|
pass_config=PassConfig(eliminate_noops=True, fuse_norm_quant=True),
|
||||||
# Inductor caches custom passes by default as well via uuid
|
# Inductor caches custom passes by default as well via uuid
|
||||||
inductor_compile_config={"force_disable_caches": True},
|
inductor_compile_config={"force_disable_caches": True},
|
||||||
)
|
)
|
||||||
|
|||||||
203
tests/kernels/core/test_apply_rotary_emb.py
Normal file
203
tests/kernels/core/test_apply_rotary_emb.py
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Tests for ApplyRotaryEmb CustomOp dispatch behavior.
|
||||||
|
|
||||||
|
This test ensures that RotaryEmbedding classes correctly call the appropriate
|
||||||
|
ApplyRotaryEmb methods based on the calling context:
|
||||||
|
|
||||||
|
1. RotaryEmbedding.forward_native() -> ApplyRotaryEmb.forward_native()
|
||||||
|
2. RotaryEmbedding.forward_cuda() -> ApplyRotaryEmb.forward() (auto-dispatch)
|
||||||
|
3. RotaryEmbedding.forward_hip() -> ApplyRotaryEmb.forward() (auto-dispatch)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.config import (
|
||||||
|
CompilationConfig,
|
||||||
|
VllmConfig,
|
||||||
|
get_cached_compilation_config,
|
||||||
|
set_current_vllm_config,
|
||||||
|
)
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
CUDA_DEVICES = ["cuda:0"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RotaryEmbeddingTestCase:
|
||||||
|
"""Test case configuration for RotaryEmbedding dispatch tests."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
rope_class: type
|
||||||
|
rope_kwargs: dict
|
||||||
|
method_name: str # forward_native, forward_cuda, forward
|
||||||
|
positions_shape: tuple # (num_tokens,) or (3, num_tokens) or (4, num_tokens)
|
||||||
|
expect_forward_native: bool # Should call ApplyRotaryEmb.forward_native()
|
||||||
|
expect_forward: bool # Should call ApplyRotaryEmb.forward()
|
||||||
|
|
||||||
|
|
||||||
|
def get_test_cases() -> list[RotaryEmbeddingTestCase]:
|
||||||
|
"""Generate test cases for all RotaryEmbedding classes."""
|
||||||
|
from vllm.model_executor.layers.rotary_embedding.ernie45_vl_rope import (
|
||||||
|
Ernie4_5_VLRotaryEmbedding,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding
|
||||||
|
from vllm.model_executor.layers.rotary_embedding.xdrope import XDRotaryEmbedding
|
||||||
|
|
||||||
|
common_kwargs = {
|
||||||
|
"head_size": 128,
|
||||||
|
"rotary_dim": 128,
|
||||||
|
"max_position_embeddings": 4096,
|
||||||
|
"base": 10000,
|
||||||
|
"is_neox_style": True,
|
||||||
|
"dtype": torch.bfloat16,
|
||||||
|
}
|
||||||
|
|
||||||
|
return [
|
||||||
|
# MRotaryEmbedding tests
|
||||||
|
RotaryEmbeddingTestCase(
|
||||||
|
name="MRotaryEmbedding.forward_native",
|
||||||
|
rope_class=MRotaryEmbedding,
|
||||||
|
rope_kwargs={**common_kwargs, "mrope_section": [16, 24, 24]},
|
||||||
|
method_name="forward_native",
|
||||||
|
positions_shape=(3, 32), # 2D for multimodal
|
||||||
|
expect_forward_native=True,
|
||||||
|
expect_forward=False,
|
||||||
|
),
|
||||||
|
RotaryEmbeddingTestCase(
|
||||||
|
name="MRotaryEmbedding.forward_cuda_1d",
|
||||||
|
rope_class=MRotaryEmbedding,
|
||||||
|
rope_kwargs={**common_kwargs, "mrope_section": [16, 24, 24]},
|
||||||
|
method_name="forward_cuda",
|
||||||
|
positions_shape=(32,), # 1D triggers apply_rotary_emb path
|
||||||
|
expect_forward_native=False,
|
||||||
|
expect_forward=True,
|
||||||
|
),
|
||||||
|
# XDRotaryEmbedding tests
|
||||||
|
RotaryEmbeddingTestCase(
|
||||||
|
name="XDRotaryEmbedding.forward",
|
||||||
|
rope_class=XDRotaryEmbedding,
|
||||||
|
rope_kwargs={
|
||||||
|
**common_kwargs,
|
||||||
|
"scaling_alpha": 1.0,
|
||||||
|
"xdrope_section": [16, 16, 16, 16],
|
||||||
|
},
|
||||||
|
method_name="forward",
|
||||||
|
positions_shape=(4, 32), # 4D for P/W/H/T
|
||||||
|
expect_forward_native=False,
|
||||||
|
expect_forward=True,
|
||||||
|
),
|
||||||
|
# Ernie4_5_VLRotaryEmbedding tests
|
||||||
|
RotaryEmbeddingTestCase(
|
||||||
|
name="Ernie4_5_VLRotaryEmbedding.forward_native",
|
||||||
|
rope_class=Ernie4_5_VLRotaryEmbedding,
|
||||||
|
rope_kwargs={**common_kwargs, "mrope_section": [22, 22, 20]},
|
||||||
|
method_name="forward_native",
|
||||||
|
positions_shape=(3, 32), # 2D for multimodal
|
||||||
|
expect_forward_native=True,
|
||||||
|
expect_forward=False,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def run_dispatch_test(
|
||||||
|
test_case: RotaryEmbeddingTestCase,
|
||||||
|
device: str,
|
||||||
|
):
|
||||||
|
"""Run a dispatch test for a RotaryEmbedding class."""
|
||||||
|
vllm_config = VllmConfig(
|
||||||
|
compilation_config=CompilationConfig(custom_ops=["all", "+apply_rotary_emb"])
|
||||||
|
)
|
||||||
|
get_cached_compilation_config.cache_clear()
|
||||||
|
|
||||||
|
with set_current_vllm_config(vllm_config):
|
||||||
|
rope = test_case.rope_class(**test_case.rope_kwargs).to(device=device)
|
||||||
|
|
||||||
|
apply_rotary_emb = rope.apply_rotary_emb
|
||||||
|
|
||||||
|
# Verify custom op is enabled
|
||||||
|
if test_case.expect_forward_native:
|
||||||
|
assert (
|
||||||
|
apply_rotary_emb._forward_method != apply_rotary_emb.forward_native
|
||||||
|
), "Test setup error: ApplyRotaryEmb custom op should be enabled"
|
||||||
|
|
||||||
|
# Setup call tracking
|
||||||
|
call_tracker = {"forward_native_called": False, "forward_called": False}
|
||||||
|
original_forward_native = apply_rotary_emb.forward_native
|
||||||
|
original_forward = apply_rotary_emb.forward
|
||||||
|
|
||||||
|
def tracked_forward_native(*args, **kwargs):
|
||||||
|
call_tracker["forward_native_called"] = True
|
||||||
|
return original_forward_native(*args, **kwargs)
|
||||||
|
|
||||||
|
def tracked_forward(*args, **kwargs):
|
||||||
|
call_tracker["forward_called"] = True
|
||||||
|
return original_forward(*args, **kwargs)
|
||||||
|
|
||||||
|
apply_rotary_emb.forward_native = tracked_forward_native
|
||||||
|
apply_rotary_emb.forward = tracked_forward
|
||||||
|
|
||||||
|
try:
|
||||||
|
num_tokens = test_case.positions_shape[-1]
|
||||||
|
num_q_heads = 8
|
||||||
|
num_kv_heads = 2
|
||||||
|
head_size = test_case.rope_kwargs["head_size"]
|
||||||
|
max_position = test_case.rope_kwargs["max_position_embeddings"]
|
||||||
|
|
||||||
|
positions = torch.randint(
|
||||||
|
0, max_position // 4, test_case.positions_shape, device=device
|
||||||
|
)
|
||||||
|
query = torch.randn(
|
||||||
|
num_tokens, num_q_heads * head_size, dtype=torch.bfloat16, device=device
|
||||||
|
)
|
||||||
|
key = torch.randn(
|
||||||
|
num_tokens,
|
||||||
|
num_kv_heads * head_size,
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call the method under test
|
||||||
|
method = getattr(rope, test_case.method_name)
|
||||||
|
method(positions, query.clone(), key.clone())
|
||||||
|
|
||||||
|
# Verify expectations
|
||||||
|
if test_case.expect_forward_native:
|
||||||
|
assert call_tracker["forward_native_called"], (
|
||||||
|
f"{test_case.name} should call ApplyRotaryEmb.forward_native()"
|
||||||
|
)
|
||||||
|
if not test_case.expect_forward:
|
||||||
|
assert not call_tracker["forward_called"], (
|
||||||
|
f"{test_case.name} should NOT call ApplyRotaryEmb.forward(). "
|
||||||
|
"Bug: when +apply_rotary_emb is enabled, forward_native() "
|
||||||
|
"incorrectly dispatches to CUDA/HIP kernels."
|
||||||
|
)
|
||||||
|
if test_case.expect_forward:
|
||||||
|
assert call_tracker["forward_called"], (
|
||||||
|
f"{test_case.name} should call ApplyRotaryEmb.forward()"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
apply_rotary_emb.forward_native = original_forward_native
|
||||||
|
apply_rotary_emb.forward = original_forward
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests."
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("test_case", get_test_cases(), ids=lambda tc: tc.name)
|
||||||
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
|
def test_rotary_embedding_dispatch(
|
||||||
|
test_case: RotaryEmbeddingTestCase,
|
||||||
|
device: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test that RotaryEmbedding classes dispatch to the correct ApplyRotaryEmb method.
|
||||||
|
|
||||||
|
- forward_native methods should call ApplyRotaryEmb.forward_native()
|
||||||
|
- forward_cuda/forward methods should call ApplyRotaryEmb.forward()
|
||||||
|
"""
|
||||||
|
run_dispatch_test(test_case, device)
|
||||||
@@ -388,6 +388,7 @@ def run_video_test(config, mm_encoder_attn_backend, video_assets, vllm_runner):
|
|||||||
"mm_encoder_attn_backend",
|
"mm_encoder_attn_backend",
|
||||||
[None] + current_platform.get_supported_vit_attn_backends(),
|
[None] + current_platform.get_supported_vit_attn_backends(),
|
||||||
)
|
)
|
||||||
|
@pytest.mark.skip(reason="Broken test due to memory segmentation fault")
|
||||||
@create_new_process_for_each_test()
|
@create_new_process_for_each_test()
|
||||||
def test_vit_backend_functionality(
|
def test_vit_backend_functionality(
|
||||||
model_key: str,
|
model_key: str,
|
||||||
|
|||||||
@@ -60,12 +60,12 @@ def test_profiling(model_id: str, max_model_len: int):
|
|||||||
total_num_patches.item() + num_tiles.item() + 3
|
total_num_patches.item() + num_tiles.item() + 3
|
||||||
) # image start, image, image end
|
) # image start, image, image end
|
||||||
|
|
||||||
profiled_tokens = profiler.get_mm_max_contiguous_tokens(
|
profiled_tokens = profiler.get_mm_max_tokens(
|
||||||
max_model_len,
|
max_model_len,
|
||||||
mm_counts=mm_counts,
|
mm_counts=mm_counts,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert total_tokens == profiled_tokens["image"]
|
assert total_num_patches == profiled_tokens["image"]
|
||||||
assert total_tokens == sum(
|
assert total_tokens == sum(
|
||||||
placeholder.length
|
placeholder.length
|
||||||
for placeholder in decoder_dummy_data.multi_modal_placeholders["image"]
|
for placeholder in decoder_dummy_data.multi_modal_placeholders["image"]
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from tempfile import NamedTemporaryFile, TemporaryDirectory
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch
|
||||||
from PIL import Image, ImageChops
|
from PIL import Image, ImageChops
|
||||||
|
|
||||||
from vllm.multimodal.image import convert_image_mode
|
from vllm.multimodal.image import convert_image_mode
|
||||||
@@ -410,6 +411,97 @@ def test_argsort_mm_positions(case):
|
|||||||
assert modality_idxs == expected_modality_idxs
|
assert modality_idxs == expected_modality_idxs
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"is_embed,expected",
|
||||||
|
[
|
||||||
|
(None, 5),
|
||||||
|
(torch.tensor([True, True, True, True, True]), 5),
|
||||||
|
(torch.tensor([False, False, False, False, False]), 0),
|
||||||
|
(torch.tensor([True, False, True, False, True]), 3),
|
||||||
|
(torch.tensor([True]), 1),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_placeholder_range_get_num_embeds(is_embed, expected):
|
||||||
|
length = len(is_embed) if is_embed is not None else 5
|
||||||
|
pr = PlaceholderRange(offset=0, length=length, is_embed=is_embed)
|
||||||
|
assert pr.get_num_embeds == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"is_embed,expected",
|
||||||
|
[
|
||||||
|
(None, None),
|
||||||
|
(
|
||||||
|
torch.tensor([False, True, False, True, True]),
|
||||||
|
torch.tensor([0, 1, 1, 2, 3]),
|
||||||
|
),
|
||||||
|
(torch.tensor([True, True, True]), torch.tensor([1, 2, 3])),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_placeholder_range_embeds_cumsum(is_embed, expected):
|
||||||
|
length = len(is_embed) if is_embed is not None else 5
|
||||||
|
pr = PlaceholderRange(offset=0, length=length, is_embed=is_embed)
|
||||||
|
|
||||||
|
if expected is None:
|
||||||
|
assert pr.embeds_cumsum is None
|
||||||
|
return
|
||||||
|
|
||||||
|
assert torch.equal(pr.embeds_cumsum, expected)
|
||||||
|
# cached_property should return the same object on repeated access
|
||||||
|
assert pr.embeds_cumsum is pr.embeds_cumsum
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"is_embed,start_idx,end_idx,expected",
|
||||||
|
[
|
||||||
|
(None, 2, 4, (2, 4)),
|
||||||
|
(
|
||||||
|
torch.tensor([False, True, False, True, True]),
|
||||||
|
3,
|
||||||
|
5,
|
||||||
|
(1, 3),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
torch.tensor([False, True, False, True, True]),
|
||||||
|
0,
|
||||||
|
2,
|
||||||
|
(0, 1),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
torch.tensor([True, False, True, False]),
|
||||||
|
2,
|
||||||
|
2,
|
||||||
|
(1, 1),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_placeholder_range_get_embeds_indices_in_range(
|
||||||
|
is_embed, start_idx, end_idx, expected
|
||||||
|
):
|
||||||
|
length = len(is_embed) if is_embed is not None else 5
|
||||||
|
pr = PlaceholderRange(offset=0, length=length, is_embed=is_embed)
|
||||||
|
assert pr.get_embeds_indices_in_range(start_idx, end_idx) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"offset,is_embed,expected",
|
||||||
|
[
|
||||||
|
(0, None, [(0, 4)]),
|
||||||
|
(
|
||||||
|
2,
|
||||||
|
torch.tensor([False, True, False, True, True]),
|
||||||
|
[(3, 3), (5, 6)],
|
||||||
|
),
|
||||||
|
(0, torch.tensor([True, True, True, True]), [(0, 3)]),
|
||||||
|
(0, torch.tensor([False, False, False, False]), []),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_placeholder_range_extract_embeds_range(offset, is_embed, expected):
|
||||||
|
length = len(is_embed) if is_embed is not None else 5
|
||||||
|
pr = PlaceholderRange(offset=offset, length=length, is_embed=is_embed)
|
||||||
|
assert pr.extract_embeds_range() == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
|
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
|
||||||
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
|
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange
|
from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange
|
||||||
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
|
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
|
||||||
@@ -23,7 +24,7 @@ class MockRequest:
|
|||||||
)
|
)
|
||||||
self.mm_features.append(feature)
|
self.mm_features.append(feature)
|
||||||
|
|
||||||
def get_num_encoder_tokens(self, input_id: int) -> int:
|
def get_num_encoder_embeds(self, input_id: int) -> int:
|
||||||
return self._token_counts[input_id]
|
return self._token_counts[input_id]
|
||||||
|
|
||||||
|
|
||||||
@@ -162,8 +163,8 @@ def test_schedule_request_multi_images_respect_space_limit():
|
|||||||
|
|
||||||
num_tokens_to_schedule = 0
|
num_tokens_to_schedule = 0
|
||||||
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
|
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
|
||||||
num_tokens_to_schedule += req.get_num_encoder_tokens(0)
|
num_tokens_to_schedule += req.get_num_encoder_embeds(0)
|
||||||
compute_budget -= req.get_num_encoder_tokens(0)
|
compute_budget -= req.get_num_encoder_embeds(0)
|
||||||
|
|
||||||
assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule)
|
assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule)
|
||||||
|
|
||||||
@@ -174,7 +175,75 @@ def test_schedule_request_multi_images_respect_compute_limit():
|
|||||||
compute_budget = 10
|
compute_budget = 10
|
||||||
num_tokens_to_schedule = 0
|
num_tokens_to_schedule = 0
|
||||||
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
|
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
|
||||||
num_tokens_to_schedule += req.get_num_encoder_tokens(0)
|
num_tokens_to_schedule += req.get_num_encoder_embeds(0)
|
||||||
compute_budget -= req.get_num_encoder_tokens(0)
|
compute_budget -= req.get_num_encoder_embeds(0)
|
||||||
|
|
||||||
assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule)
|
assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule)
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoder_cache_with_is_embed_mask():
|
||||||
|
class MockRequestWithMask(MockRequest):
|
||||||
|
def get_num_encoder_embeds(self, input_id: int) -> int:
|
||||||
|
return self.mm_features[input_id].mm_position.get_num_embeds
|
||||||
|
|
||||||
|
is_embed = torch.zeros(100, dtype=torch.bool)
|
||||||
|
is_embed[torch.tensor([5, 15, 25, 35, 45, 55, 65, 75])] = True
|
||||||
|
|
||||||
|
request = MockRequestWithMask("r1", ["img1"], [100])
|
||||||
|
request.mm_features[0] = MultiModalFeatureSpec(
|
||||||
|
data=None,
|
||||||
|
modality="image",
|
||||||
|
identifier="img1",
|
||||||
|
mm_position=PlaceholderRange(offset=0, length=100, is_embed=is_embed),
|
||||||
|
)
|
||||||
|
|
||||||
|
manager = EncoderCacheManager(cache_size=100)
|
||||||
|
manager.allocate(request, 0)
|
||||||
|
|
||||||
|
assert manager.num_free_slots == 92
|
||||||
|
assert "img1" in manager.cached
|
||||||
|
|
||||||
|
old_size = 100
|
||||||
|
new_size = request.mm_features[0].mm_position.get_num_embeds
|
||||||
|
assert new_size == 8
|
||||||
|
savings_ratio = old_size / new_size
|
||||||
|
assert savings_ratio == 12.5
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoder_cache_mask_based_retrieval():
|
||||||
|
class MockRequestWithMask(MockRequest):
|
||||||
|
def get_num_encoder_embeds(self, input_id: int) -> int:
|
||||||
|
return self.mm_features[input_id].mm_position.get_num_embeds
|
||||||
|
|
||||||
|
is_embed = torch.tensor(
|
||||||
|
[False, False, True, True, False, True, True, True, False, False]
|
||||||
|
)
|
||||||
|
|
||||||
|
request = MockRequestWithMask("r1", ["img1"], [10])
|
||||||
|
request.mm_features[0] = MultiModalFeatureSpec(
|
||||||
|
data=None,
|
||||||
|
modality="image",
|
||||||
|
identifier="img1",
|
||||||
|
mm_position=PlaceholderRange(offset=0, length=10, is_embed=is_embed),
|
||||||
|
)
|
||||||
|
|
||||||
|
manager = EncoderCacheManager(cache_size=50)
|
||||||
|
manager.allocate(request, 0)
|
||||||
|
|
||||||
|
assert request.mm_features[0].mm_position.get_num_embeds == 5
|
||||||
|
|
||||||
|
start_idx = 2
|
||||||
|
end_idx = 8
|
||||||
|
num_embeds_before = is_embed[:start_idx].sum().item()
|
||||||
|
num_embeds_in_range = is_embed[start_idx:end_idx].sum().item()
|
||||||
|
|
||||||
|
assert num_embeds_before == 0
|
||||||
|
assert num_embeds_in_range == 5
|
||||||
|
|
||||||
|
start_idx = 0
|
||||||
|
end_idx = 5
|
||||||
|
num_embeds_before = is_embed[:start_idx].sum().item() if start_idx > 0 else 0
|
||||||
|
num_embeds_in_range = is_embed[start_idx:end_idx].sum().item()
|
||||||
|
|
||||||
|
assert num_embeds_before == 0
|
||||||
|
assert num_embeds_in_range == 2
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ class MockRequest:
|
|||||||
)
|
)
|
||||||
self.mm_features.append(feature)
|
self.mm_features.append(feature)
|
||||||
|
|
||||||
def get_num_encoder_tokens(self, input_id: int) -> int:
|
def get_num_encoder_embeds(self, input_id: int) -> int:
|
||||||
assert input_id < len(self._token_counts)
|
assert input_id < len(self._token_counts)
|
||||||
return self._token_counts[input_id]
|
return self._token_counts[input_id]
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import einops
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.torch_utils import direct_register_custom_op
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
|
|
||||||
|
|
||||||
@@ -89,6 +90,13 @@ def torch_sdpa_wrapper(
|
|||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
cu_seqlens: torch.Tensor,
|
cu_seqlens: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
# Never remove the contiguous logic for ROCm
|
||||||
|
# Without it, hallucinations occur with the backend
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
q = q.contiguous()
|
||||||
|
k = k.contiguous()
|
||||||
|
v = v.contiguous()
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
|
|
||||||
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||||
|
|||||||
@@ -144,7 +144,7 @@ class ECExampleConnector(ECConnectorBase):
|
|||||||
Update ECConnector state after encoder cache allocation.
|
Update ECConnector state after encoder cache allocation.
|
||||||
"""
|
"""
|
||||||
mm_hash = request.mm_features[index].identifier
|
mm_hash = request.mm_features[index].identifier
|
||||||
num_encoder_token = request.get_num_encoder_tokens(index)
|
num_encoder_token = request.get_num_encoder_embeds(index)
|
||||||
# Insert mm_hash only if this block has not been recorded yet.
|
# Insert mm_hash only if this block has not been recorded yet.
|
||||||
self._mm_datas_need_loads[mm_hash] = num_encoder_token
|
self._mm_datas_need_loads[mm_hash] = num_encoder_token
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import torch
|
|||||||
from vllm._aiter_ops import rocm_aiter_ops
|
from vllm._aiter_ops import rocm_aiter_ops
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
|
|
||||||
from .common import apply_rotary_emb_torch
|
from .common import ApplyRotaryEmb
|
||||||
|
|
||||||
|
|
||||||
@CustomOp.register("rotary_embedding")
|
@CustomOp.register("rotary_embedding")
|
||||||
@@ -49,6 +49,10 @@ class RotaryEmbeddingBase(CustomOp):
|
|||||||
rocm_aiter_ops.is_triton_rotary_embed_enabled()
|
rocm_aiter_ops.is_triton_rotary_embed_enabled()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.apply_rotary_emb = ApplyRotaryEmb(
|
||||||
|
is_neox_style=self.is_neox_style,
|
||||||
|
)
|
||||||
|
|
||||||
def _compute_inv_freq(self, base: float) -> torch.Tensor:
|
def _compute_inv_freq(self, base: float) -> torch.Tensor:
|
||||||
"""Compute the inverse frequency."""
|
"""Compute the inverse frequency."""
|
||||||
# NOTE(woosuk): To exactly match the HF implementation, we need to
|
# NOTE(woosuk): To exactly match the HF implementation, we need to
|
||||||
@@ -123,7 +127,12 @@ class RotaryEmbedding(RotaryEmbeddingBase):
|
|||||||
query = query.view(num_tokens, -1, head_size)
|
query = query.view(num_tokens, -1, head_size)
|
||||||
query_rot = query[..., :rotary_dim]
|
query_rot = query[..., :rotary_dim]
|
||||||
query_pass = query[..., rotary_dim:]
|
query_pass = query[..., rotary_dim:]
|
||||||
query_rot = apply_rotary_emb_torch(query_rot, cos, sin, is_neox_style)
|
query_rot = ApplyRotaryEmb.forward_static(
|
||||||
|
query_rot,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
is_neox_style,
|
||||||
|
)
|
||||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||||
|
|
||||||
# key may be None in some cases, e.g. cross-layer KV sharing
|
# key may be None in some cases, e.g. cross-layer KV sharing
|
||||||
@@ -132,7 +141,12 @@ class RotaryEmbedding(RotaryEmbeddingBase):
|
|||||||
key = key.view(num_tokens, -1, head_size)
|
key = key.view(num_tokens, -1, head_size)
|
||||||
key_rot = key[..., :rotary_dim]
|
key_rot = key[..., :rotary_dim]
|
||||||
key_pass = key[..., rotary_dim:]
|
key_pass = key[..., rotary_dim:]
|
||||||
key_rot = apply_rotary_emb_torch(key_rot, cos, sin, is_neox_style)
|
key_rot = ApplyRotaryEmb.forward_static(
|
||||||
|
key_rot,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
is_neox_style,
|
||||||
|
)
|
||||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
|
|||||||
@@ -2,19 +2,14 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from collections.abc import Callable
|
|
||||||
from functools import cache
|
|
||||||
from importlib.util import find_spec
|
from importlib.util import find_spec
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
from vllm.utils.torch_utils import direct_register_custom_op
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
|
|
||||||
if current_platform.is_cuda():
|
|
||||||
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -32,71 +27,6 @@ def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
|
|||||||
return x.flatten(-2)
|
return x.flatten(-2)
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb_torch(
|
|
||||||
x: torch.Tensor,
|
|
||||||
cos: torch.Tensor,
|
|
||||||
sin: torch.Tensor,
|
|
||||||
is_neox_style: bool,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
cos = cos.unsqueeze(-2).to(x.dtype)
|
|
||||||
sin = sin.unsqueeze(-2).to(x.dtype)
|
|
||||||
if is_neox_style:
|
|
||||||
x1, x2 = torch.chunk(x, 2, dim=-1)
|
|
||||||
else:
|
|
||||||
x1 = x[..., ::2]
|
|
||||||
x2 = x[..., 1::2]
|
|
||||||
o1 = x1 * cos - x2 * sin
|
|
||||||
o2 = x2 * cos + x1 * sin
|
|
||||||
if is_neox_style:
|
|
||||||
return torch.cat((o1, o2), dim=-1)
|
|
||||||
else:
|
|
||||||
return torch.stack((o1, o2), dim=-1).flatten(-2)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb_dispatch(
|
|
||||||
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, is_neox_style: bool
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
x: [num_tokens, num_heads, head_size]
|
|
||||||
cos: [num_tokens, head_size // 2]
|
|
||||||
sin: [num_tokens, head_size // 2]
|
|
||||||
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
|
|
||||||
positional embeddings.
|
|
||||||
"""
|
|
||||||
if current_platform.is_cuda():
|
|
||||||
return apply_rotary_emb(x.unsqueeze(0), cos, sin, not is_neox_style).squeeze(0)
|
|
||||||
else:
|
|
||||||
return apply_rotary_emb_torch(x, cos, sin, is_neox_style)
|
|
||||||
|
|
||||||
|
|
||||||
@cache
|
|
||||||
def dispatch_rotary_emb_function(
|
|
||||||
default: Callable[..., torch.Tensor] | None = None,
|
|
||||||
) -> Callable[..., torch.Tensor]:
|
|
||||||
if current_platform.is_cuda():
|
|
||||||
return apply_rotary_emb
|
|
||||||
|
|
||||||
# if torch compile is not enabled
|
|
||||||
# use rotary embedding function from flash_attn package
|
|
||||||
# otherwise use the naive pytorch embedding implementation
|
|
||||||
# is faster when torch compile is enabled.
|
|
||||||
if current_platform.is_rocm() and not torch.compiler.is_compiling():
|
|
||||||
if find_spec("flash_attn") is not None:
|
|
||||||
from flash_attn.ops.triton.rotary import apply_rotary
|
|
||||||
|
|
||||||
return apply_rotary
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"flash_attn is not installed. Falling back to PyTorch "
|
|
||||||
"implementation for rotary embeddings."
|
|
||||||
)
|
|
||||||
if default is not None:
|
|
||||||
return default
|
|
||||||
|
|
||||||
return apply_rotary_emb_torch
|
|
||||||
|
|
||||||
|
|
||||||
# yarn functions
|
# yarn functions
|
||||||
# Inverse dim formula to find dim based on number of rotations
|
# Inverse dim formula to find dim based on number of rotations
|
||||||
def yarn_find_correction_dim(
|
def yarn_find_correction_dim(
|
||||||
@@ -186,3 +116,155 @@ direct_register_custom_op(
|
|||||||
mutates_args=["query", "key"], # These tensors are modified in-place
|
mutates_args=["query", "key"], # These tensors are modified in-place
|
||||||
fake_impl=_flashinfer_rotary_embedding_fake,
|
fake_impl=_flashinfer_rotary_embedding_fake,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@CustomOp.register("apply_rotary_emb")
|
||||||
|
class ApplyRotaryEmb(CustomOp):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
enforce_enable: bool = False,
|
||||||
|
is_neox_style: bool = True,
|
||||||
|
enable_fp32_compute: bool = False,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(enforce_enable)
|
||||||
|
self.is_neox_style = is_neox_style
|
||||||
|
self.enable_fp32_compute = enable_fp32_compute
|
||||||
|
|
||||||
|
self.apply_rotary_emb_flash_attn = None
|
||||||
|
if find_spec("flash_attn") is not None:
|
||||||
|
from flash_attn.ops.triton.rotary import apply_rotary
|
||||||
|
|
||||||
|
self.apply_rotary_emb_flash_attn = apply_rotary
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward_static(
|
||||||
|
x: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
is_neox_style: bool = True,
|
||||||
|
enable_fp32_compute: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: [batch_size (optional), seq_len, num_heads, head_size]
|
||||||
|
cos: [seq_len, head_size // 2]
|
||||||
|
sin: [seq_len, head_size // 2]
|
||||||
|
is_neox_style: Whether to use the Neox-style or GPT-J-style.
|
||||||
|
enable_fp32_compute: Temporarily convert x, cos, sin to FP32 dtype
|
||||||
|
for higher accuracy.
|
||||||
|
"""
|
||||||
|
origin_dtype = x.dtype
|
||||||
|
if enable_fp32_compute:
|
||||||
|
x = x.float()
|
||||||
|
|
||||||
|
cos = cos.unsqueeze(-2).to(x.dtype)
|
||||||
|
sin = sin.unsqueeze(-2).to(x.dtype)
|
||||||
|
|
||||||
|
if is_neox_style:
|
||||||
|
x1, x2 = torch.chunk(x, 2, dim=-1)
|
||||||
|
else:
|
||||||
|
x1 = x[..., ::2]
|
||||||
|
x2 = x[..., 1::2]
|
||||||
|
|
||||||
|
o1 = x1 * cos - x2 * sin
|
||||||
|
o2 = x2 * cos + x1 * sin
|
||||||
|
|
||||||
|
if is_neox_style:
|
||||||
|
output = torch.cat((o1, o2), dim=-1)
|
||||||
|
else:
|
||||||
|
output = torch.stack((o1, o2), dim=-1).flatten(-2)
|
||||||
|
|
||||||
|
if enable_fp32_compute:
|
||||||
|
output = output.to(origin_dtype)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def forward_native(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
output = self.forward_static(
|
||||||
|
x, cos, sin, self.is_neox_style, self.enable_fp32_compute
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def forward_cuda(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
||||||
|
|
||||||
|
origin_dtype = x.dtype
|
||||||
|
if self.enable_fp32_compute:
|
||||||
|
x = x.float()
|
||||||
|
cos = cos.float()
|
||||||
|
sin = sin.float()
|
||||||
|
|
||||||
|
origin_shape = x.shape
|
||||||
|
if len(origin_shape) == 3:
|
||||||
|
# x: [seq_len, num_heads, head_size]
|
||||||
|
x = x.unsqueeze(0)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Arguments of apply_rotary_emb() in vllm_flash_attn:
|
||||||
|
x: [batch_size, seq_len, nheads, headdim]
|
||||||
|
cos, sin: [seqlen_rotary, rotary_dim / 2]
|
||||||
|
interleaved: defalut as False (Neox-style).
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
interleaved = not self.is_neox_style
|
||||||
|
output = apply_rotary_emb(x, cos, sin, interleaved)
|
||||||
|
|
||||||
|
if len(origin_shape) == 3:
|
||||||
|
output = output.squeeze(0)
|
||||||
|
if self.enable_fp32_compute:
|
||||||
|
output = output.to(origin_dtype)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def forward_hip(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if self.apply_rotary_emb_flash_attn is not None:
|
||||||
|
origin_dtype = x.dtype
|
||||||
|
if self.enable_fp32_compute:
|
||||||
|
x = x.float()
|
||||||
|
cos = cos.float()
|
||||||
|
sin = sin.float()
|
||||||
|
|
||||||
|
origin_shape = x.shape
|
||||||
|
if len(origin_shape) == 3:
|
||||||
|
# x: [seq_len, num_heads, head_size]
|
||||||
|
x = x.unsqueeze(0)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Arguments of apply_rotary() in flash_attn:
|
||||||
|
x: [batch_size, seq_len, nheads, headdim]
|
||||||
|
cos, sin: [seqlen_rotary, rotary_dim / 2]
|
||||||
|
interleaved: defalut as False (Neox-style).
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
interleaved = not self.is_neox_style
|
||||||
|
output = self.apply_rotary_emb_flash_attn(
|
||||||
|
x, cos, sin, interleaved=interleaved
|
||||||
|
).type_as(x)
|
||||||
|
|
||||||
|
if len(origin_shape) == 3:
|
||||||
|
output = output.squeeze(0)
|
||||||
|
if self.enable_fp32_compute:
|
||||||
|
output = output.to(origin_dtype)
|
||||||
|
else:
|
||||||
|
# Falling back to PyTorch native implementation.
|
||||||
|
output = self.forward_native(x, cos, sin)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
s = f"is_neox_style={self.is_neox_style}"
|
||||||
|
s += f"enable_fp32_compute={self.enable_fp32_compute}"
|
||||||
|
return s
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .common import apply_rotary_emb_dispatch
|
|
||||||
from .mrope import MRotaryEmbedding
|
from .mrope import MRotaryEmbedding
|
||||||
|
|
||||||
|
|
||||||
@@ -55,14 +54,22 @@ class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
|
|||||||
query = query.view(num_tokens, -1, self.head_size)
|
query = query.view(num_tokens, -1, self.head_size)
|
||||||
query_rot = query[..., : self.rotary_dim]
|
query_rot = query[..., : self.rotary_dim]
|
||||||
query_pass = query[..., self.rotary_dim :]
|
query_pass = query[..., self.rotary_dim :]
|
||||||
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
|
query_rot = self.apply_rotary_emb.forward_native(
|
||||||
|
query_rot,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
)
|
||||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||||
|
|
||||||
key_shape = key.shape
|
key_shape = key.shape
|
||||||
key = key.view(num_tokens, -1, self.head_size)
|
key = key.view(num_tokens, -1, self.head_size)
|
||||||
key_rot = key[..., : self.rotary_dim]
|
key_rot = key[..., : self.rotary_dim]
|
||||||
key_pass = key[..., self.rotary_dim :]
|
key_pass = key[..., self.rotary_dim :]
|
||||||
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
|
key_rot = self.apply_rotary_emb.forward_native(
|
||||||
|
key_rot,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
)
|
||||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import torch
|
|||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
from .base import RotaryEmbeddingBase
|
from .base import RotaryEmbeddingBase
|
||||||
from .common import apply_rotary_emb_dispatch
|
|
||||||
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale
|
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale
|
||||||
|
|
||||||
|
|
||||||
@@ -301,14 +300,22 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
|
|||||||
query = query.view(num_tokens, -1, self.head_size)
|
query = query.view(num_tokens, -1, self.head_size)
|
||||||
query_rot = query[..., : self.rotary_dim]
|
query_rot = query[..., : self.rotary_dim]
|
||||||
query_pass = query[..., self.rotary_dim :]
|
query_pass = query[..., self.rotary_dim :]
|
||||||
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
|
query_rot = self.apply_rotary_emb.forward_native(
|
||||||
|
query_rot,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
)
|
||||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||||
|
|
||||||
key_shape = key.shape
|
key_shape = key.shape
|
||||||
key = key.view(num_tokens, -1, self.head_size)
|
key = key.view(num_tokens, -1, self.head_size)
|
||||||
key_rot = key[..., : self.rotary_dim]
|
key_rot = key[..., : self.rotary_dim]
|
||||||
key_pass = key[..., self.rotary_dim :]
|
key_pass = key[..., self.rotary_dim :]
|
||||||
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
|
key_rot = self.apply_rotary_emb.forward_native(
|
||||||
|
key_rot,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
)
|
||||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
@@ -347,13 +354,21 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
|
|||||||
query = query.view(num_tokens, -1, self.head_size)
|
query = query.view(num_tokens, -1, self.head_size)
|
||||||
query_rot = query[..., : self.rotary_dim]
|
query_rot = query[..., : self.rotary_dim]
|
||||||
query_pass = query[..., self.rotary_dim :]
|
query_pass = query[..., self.rotary_dim :]
|
||||||
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
|
query_rot = self.apply_rotary_emb(
|
||||||
|
query_rot,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
)
|
||||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||||
|
|
||||||
key = key.view(num_tokens, -1, self.head_size)
|
key = key.view(num_tokens, -1, self.head_size)
|
||||||
key_rot = key[..., : self.rotary_dim]
|
key_rot = key[..., : self.rotary_dim]
|
||||||
key_pass = key[..., self.rotary_dim :]
|
key_pass = key[..., self.rotary_dim :]
|
||||||
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
|
key_rot = self.apply_rotary_emb(
|
||||||
|
key_rot,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
)
|
||||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .common import apply_rotary_emb_dispatch
|
|
||||||
from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
|
from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
|
||||||
|
|
||||||
|
|
||||||
@@ -36,7 +35,7 @@ class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding):
|
|||||||
dtype,
|
dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward_native(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
@@ -68,14 +67,73 @@ class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding):
|
|||||||
query = query.view(num_tokens, -1, self.head_size)
|
query = query.view(num_tokens, -1, self.head_size)
|
||||||
query_rot = query[..., : self.rotary_dim]
|
query_rot = query[..., : self.rotary_dim]
|
||||||
query_pass = query[..., self.rotary_dim :]
|
query_pass = query[..., self.rotary_dim :]
|
||||||
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
|
query_rot = self.apply_rotary_emb.forward_native(
|
||||||
|
query_rot,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
)
|
||||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||||
|
|
||||||
key_shape = key.shape
|
key_shape = key.shape
|
||||||
key = key.view(num_tokens, -1, self.head_size)
|
key = key.view(num_tokens, -1, self.head_size)
|
||||||
key_rot = key[..., : self.rotary_dim]
|
key_rot = key[..., : self.rotary_dim]
|
||||||
key_pass = key[..., self.rotary_dim :]
|
key_pass = key[..., self.rotary_dim :]
|
||||||
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
|
key_rot = self.apply_rotary_emb.forward_native(
|
||||||
|
key_rot,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
)
|
||||||
|
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||||
|
return query, key
|
||||||
|
|
||||||
|
def forward_cuda(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor | None = None,
|
||||||
|
offsets: torch.Tensor | None = None,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||||
|
"""PyTorch-native implementation equivalent to forward().
|
||||||
|
|
||||||
|
Args:
|
||||||
|
positions:
|
||||||
|
[4, num_tokens] (P/W/H/T positions with multimodal inputs)
|
||||||
|
query: [num_tokens, num_heads * head_size]
|
||||||
|
key: [num_tokens, num_kv_heads * head_size]
|
||||||
|
"""
|
||||||
|
assert positions.ndim == 2
|
||||||
|
assert key is not None
|
||||||
|
|
||||||
|
num_tokens = positions.shape[-1]
|
||||||
|
cos_sin = self.cos_sin_cache[positions]
|
||||||
|
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||||
|
cos = torch.cat(
|
||||||
|
[m[i] for i, m in enumerate(cos.split(self.xdrope_section, dim=-1))], dim=-1
|
||||||
|
)
|
||||||
|
sin = torch.cat(
|
||||||
|
[m[i] for i, m in enumerate(sin.split(self.xdrope_section, dim=-1))], dim=-1
|
||||||
|
)
|
||||||
|
|
||||||
|
query_shape = query.shape
|
||||||
|
query = query.view(num_tokens, -1, self.head_size)
|
||||||
|
query_rot = query[..., : self.rotary_dim]
|
||||||
|
query_pass = query[..., self.rotary_dim :]
|
||||||
|
query_rot = self.apply_rotary_emb(
|
||||||
|
query_rot,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
)
|
||||||
|
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||||
|
|
||||||
|
key_shape = key.shape
|
||||||
|
key = key.view(num_tokens, -1, self.head_size)
|
||||||
|
key_rot = key[..., : self.rotary_dim]
|
||||||
|
key_pass = key[..., self.rotary_dim :]
|
||||||
|
key_rot = self.apply_rotary_emb(
|
||||||
|
key_rot,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
)
|
||||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
|
|||||||
@@ -29,6 +29,9 @@ from vllm.model_executor.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||||
|
ApplyRotaryEmb,
|
||||||
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.interfaces import (
|
from vllm.model_executor.models.interfaces import (
|
||||||
MultiModalEmbeddings,
|
MultiModalEmbeddings,
|
||||||
@@ -158,32 +161,6 @@ class DotsOCRProcessingInfo(Qwen2VLProcessingInfo):
|
|||||||
return processor
|
return processor
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
|
|
||||||
"""Rotates half the hidden dims of the input."""
|
|
||||||
x1 = x[..., : x.shape[-1] // 2]
|
|
||||||
x2 = x[..., x.shape[-1] // 2 :]
|
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb_vision(
|
|
||||||
tensor: torch.Tensor, freqs: torch.Tensor
|
|
||||||
) -> torch.Tensor:
|
|
||||||
orig_dtype = tensor.dtype
|
|
||||||
tensor = tensor.float()
|
|
||||||
|
|
||||||
cos = freqs.cos()
|
|
||||||
sin = freqs.sin()
|
|
||||||
|
|
||||||
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
|
||||||
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
|
||||||
|
|
||||||
output = (tensor * cos) + (rotate_half(tensor) * sin)
|
|
||||||
|
|
||||||
output = output.to(orig_dtype)
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class VisionRotaryEmbedding(nn.Module):
|
class VisionRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -298,6 +275,11 @@ class DotsVisionAttention(nn.Module):
|
|||||||
prefix=f"{prefix}.attn",
|
prefix=f"{prefix}.attn",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.apply_rotary_emb = ApplyRotaryEmb(
|
||||||
|
enforce_enable=True,
|
||||||
|
enable_fp32_compute=True,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -318,7 +300,11 @@ class DotsVisionAttention(nn.Module):
|
|||||||
|
|
||||||
if rotary_pos_emb is not None:
|
if rotary_pos_emb is not None:
|
||||||
qk_concat = torch.cat([q, k], dim=0)
|
qk_concat = torch.cat([q, k], dim=0)
|
||||||
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
|
qk_rotated = self.apply_rotary_emb(
|
||||||
|
qk_concat,
|
||||||
|
rotary_pos_emb.cos(),
|
||||||
|
rotary_pos_emb.sin(),
|
||||||
|
)
|
||||||
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
||||||
|
|
||||||
context_layer = self.attn(
|
context_layer = self.attn(
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange
|
||||||
from transformers import BatchFeature
|
from transformers import BatchFeature
|
||||||
|
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
@@ -53,6 +53,9 @@ from vllm.model_executor.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||||
|
ApplyRotaryEmb,
|
||||||
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.inputs import (
|
from vllm.multimodal.inputs import (
|
||||||
@@ -69,7 +72,6 @@ from vllm.multimodal.processing import (
|
|||||||
PromptUpdate,
|
PromptUpdate,
|
||||||
)
|
)
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
@@ -89,52 +91,6 @@ logger = init_logger(__name__)
|
|||||||
# === Vision Transformer === #
|
# === Vision Transformer === #
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
|
|
||||||
if not interleaved:
|
|
||||||
x1, x2 = x.chunk(2, dim=-1)
|
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
|
||||||
else:
|
|
||||||
x1, x2 = x[..., ::2], x[..., 1::2]
|
|
||||||
return rearrange(
|
|
||||||
torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb_torch(
|
|
||||||
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
x: (batch_size, seqlen, nheads, headdim)
|
|
||||||
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
|
|
||||||
"""
|
|
||||||
ro_dim = cos.shape[-1] * 2
|
|
||||||
assert ro_dim <= x.shape[-1]
|
|
||||||
cos = repeat(
|
|
||||||
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
|
||||||
)
|
|
||||||
sin = repeat(
|
|
||||||
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
|
||||||
)
|
|
||||||
return torch.cat(
|
|
||||||
[
|
|
||||||
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
|
|
||||||
x[..., ro_dim:],
|
|
||||||
],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
|
|
||||||
t_ = t.float()
|
|
||||||
cos = freqs.cos()
|
|
||||||
sin = freqs.sin()
|
|
||||||
apply_rotary_emb = apply_rotary_emb_torch
|
|
||||||
if current_platform.is_cuda():
|
|
||||||
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
|
||||||
output = apply_rotary_emb(t_, cos, sin).type_as(t)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
|
def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
|
||||||
"""All-gather the input tensor interleavely across model parallel group."""
|
"""All-gather the input tensor interleavely across model parallel group."""
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@@ -200,6 +156,11 @@ class Ernie4_5_VisionAttention(nn.Module):
|
|||||||
prefix=f"{prefix}.attn",
|
prefix=f"{prefix}.attn",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.apply_rotary_emb = ApplyRotaryEmb(
|
||||||
|
enforce_enable=True,
|
||||||
|
enable_fp32_compute=True,
|
||||||
|
)
|
||||||
|
|
||||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||||
# [s, b, 3 * head * head_dim]
|
# [s, b, 3 * head * head_dim]
|
||||||
seq_len, bs, _ = qkv.shape
|
seq_len, bs, _ = qkv.shape
|
||||||
@@ -244,7 +205,11 @@ class Ernie4_5_VisionAttention(nn.Module):
|
|||||||
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
|
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
|
||||||
if rotary_pos_emb is not None:
|
if rotary_pos_emb is not None:
|
||||||
qk_concat = torch.cat([q, k], dim=0)
|
qk_concat = torch.cat([q, k], dim=0)
|
||||||
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
|
qk_rotated = self.apply_rotary_emb(
|
||||||
|
qk_concat,
|
||||||
|
rotary_pos_emb.cos(),
|
||||||
|
rotary_pos_emb.sin(),
|
||||||
|
)
|
||||||
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
||||||
|
|
||||||
output = self.attn(
|
output = self.attn(
|
||||||
|
|||||||
@@ -65,6 +65,9 @@ from vllm.model_executor.layers.linear import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||||
|
ApplyRotaryEmb,
|
||||||
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
@@ -95,7 +98,7 @@ from .interfaces import (
|
|||||||
SupportsMultiModal,
|
SupportsMultiModal,
|
||||||
SupportsPP,
|
SupportsPP,
|
||||||
)
|
)
|
||||||
from .qwen2_vl import _create_qwen2vl_field_factory, apply_rotary_pos_emb_vision
|
from .qwen2_vl import _create_qwen2vl_field_factory
|
||||||
from .utils import (
|
from .utils import (
|
||||||
AutoWeightsLoader,
|
AutoWeightsLoader,
|
||||||
WeightsMapper,
|
WeightsMapper,
|
||||||
@@ -304,6 +307,8 @@ class Glm4vVisionAttention(nn.Module):
|
|||||||
multimodal_config=multimodal_config,
|
multimodal_config=multimodal_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
|
||||||
|
|
||||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||||
# [s, b, 3 * head * head_dim]
|
# [s, b, 3 * head * head_dim]
|
||||||
seq_len, bs, _ = qkv.shape
|
seq_len, bs, _ = qkv.shape
|
||||||
@@ -339,8 +344,10 @@ class Glm4vVisionAttention(nn.Module):
|
|||||||
if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
|
if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
|
||||||
# [2 * b, s, heads, head_dim]
|
# [2 * b, s, heads, head_dim]
|
||||||
qk_concat = torch.cat([q, k], dim=0)
|
qk_concat = torch.cat([q, k], dim=0)
|
||||||
qk_rotated = apply_rotary_pos_emb_vision(
|
qk_rotated = self.apply_rotary_emb(
|
||||||
qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
|
qk_concat,
|
||||||
|
rotary_pos_emb_cos,
|
||||||
|
rotary_pos_emb_sin,
|
||||||
)
|
)
|
||||||
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
||||||
|
|
||||||
|
|||||||
@@ -30,6 +30,9 @@ from vllm.model_executor.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||||
|
ApplyRotaryEmb,
|
||||||
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import (
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
default_weight_loader,
|
default_weight_loader,
|
||||||
maybe_remap_kv_scale_name,
|
maybe_remap_kv_scale_name,
|
||||||
@@ -59,7 +62,6 @@ from vllm.multimodal.processing import (
|
|||||||
PromptUpdate,
|
PromptUpdate,
|
||||||
)
|
)
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
@@ -341,20 +343,14 @@ def apply_rotary_pos_emb_flashatt(
|
|||||||
cos = cos.chunk(2, dim=-1)[0].contiguous()
|
cos = cos.chunk(2, dim=-1)[0].contiguous()
|
||||||
sin = sin.chunk(2, dim=-1)[0].contiguous()
|
sin = sin.chunk(2, dim=-1)[0].contiguous()
|
||||||
|
|
||||||
if current_platform.is_cuda():
|
apply_rotary_emb = ApplyRotaryEmb(
|
||||||
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
enforce_enable=True,
|
||||||
elif current_platform.is_rocm():
|
enable_fp32_compute=True,
|
||||||
from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb
|
)
|
||||||
else:
|
|
||||||
# For other platforms, use PyTorch fallback
|
|
||||||
from vllm.model_executor.layers.rotary_embedding.common import (
|
|
||||||
apply_rotary_emb_torch,
|
|
||||||
)
|
|
||||||
|
|
||||||
apply_rotary_emb = partial(apply_rotary_emb_torch, is_neox_style=True)
|
q_embed = apply_rotary_emb(q, cos, sin)
|
||||||
|
k_embed = apply_rotary_emb(k, cos, sin)
|
||||||
|
|
||||||
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
|
|
||||||
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
|
|
||||||
return q_embed, k_embed
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from typing import Annotated, Literal
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange
|
||||||
from transformers import BatchFeature, PretrainedConfig
|
from transformers import BatchFeature, PretrainedConfig
|
||||||
from transformers.activations import GELUActivation
|
from transformers.activations import GELUActivation
|
||||||
from transformers.modeling_outputs import (
|
from transformers.modeling_outputs import (
|
||||||
@@ -47,7 +47,7 @@ from vllm.model_executor.layers.linear import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.rotary_embedding.common import (
|
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||||
dispatch_rotary_emb_function,
|
ApplyRotaryEmb,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import (
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
default_weight_loader,
|
default_weight_loader,
|
||||||
@@ -130,47 +130,6 @@ def smart_resize(
|
|||||||
return h_bar, w_bar
|
return h_bar, w_bar
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
|
|
||||||
if not interleaved:
|
|
||||||
x1, x2 = x.chunk(2, dim=-1)
|
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
|
||||||
x1, x2 = x[..., ::2], x[..., 1::2]
|
|
||||||
return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb_torch(
|
|
||||||
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
x: (batch_size, seqlen, nheads, headdim)
|
|
||||||
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
|
|
||||||
"""
|
|
||||||
ro_dim = cos.shape[-1] * 2
|
|
||||||
assert ro_dim <= x.shape[-1]
|
|
||||||
cos = repeat(
|
|
||||||
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
|
||||||
)
|
|
||||||
sin = repeat(
|
|
||||||
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
|
||||||
)
|
|
||||||
return torch.cat(
|
|
||||||
[
|
|
||||||
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
|
|
||||||
x[..., ro_dim:],
|
|
||||||
],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
|
|
||||||
rotary_emb_function = dispatch_rotary_emb_function(default=apply_rotary_emb_torch)
|
|
||||||
t_ = t.float()
|
|
||||||
cos = freqs.cos()
|
|
||||||
sin = freqs.sin()
|
|
||||||
output = rotary_emb_function(t_, cos, sin).type_as(t)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class PaddleOCRVLProcessingInfo(BaseProcessingInfo):
|
class PaddleOCRVLProcessingInfo(BaseProcessingInfo):
|
||||||
def get_hf_config(self):
|
def get_hf_config(self):
|
||||||
return self.ctx.get_hf_config()
|
return self.ctx.get_hf_config()
|
||||||
@@ -609,6 +568,10 @@ class SiglipAttention(nn.Module):
|
|||||||
multimodal_config=multimodal_config,
|
multimodal_config=multimodal_config,
|
||||||
prefix=f"{prefix}.attn",
|
prefix=f"{prefix}.attn",
|
||||||
)
|
)
|
||||||
|
self.apply_rotary_emb = ApplyRotaryEmb(
|
||||||
|
enforce_enable=True,
|
||||||
|
enable_fp32_compute=True,
|
||||||
|
)
|
||||||
|
|
||||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||||
seq_len, bs, _ = qkv.shape
|
seq_len, bs, _ = qkv.shape
|
||||||
@@ -651,7 +614,11 @@ class SiglipAttention(nn.Module):
|
|||||||
|
|
||||||
if rotary_pos_emb is not None:
|
if rotary_pos_emb is not None:
|
||||||
qk_concat = torch.cat([q, k], dim=0)
|
qk_concat = torch.cat([q, k], dim=0)
|
||||||
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
|
qk_rotated = self.apply_rotary_emb(
|
||||||
|
qk_concat,
|
||||||
|
rotary_pos_emb.cos(),
|
||||||
|
rotary_pos_emb.sin(),
|
||||||
|
)
|
||||||
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
||||||
|
|
||||||
context_layer = self.attn(
|
context_layer = self.attn(
|
||||||
|
|||||||
@@ -60,6 +60,9 @@ from vllm.model_executor.layers.linear import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||||
|
ApplyRotaryEmb,
|
||||||
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||||
from vllm.model_executor.models.vision import should_torch_compile_mm_vit
|
from vllm.model_executor.models.vision import should_torch_compile_mm_vit
|
||||||
@@ -95,7 +98,6 @@ from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
|
|||||||
from .qwen2_vl import (
|
from .qwen2_vl import (
|
||||||
Qwen2VLMultiModalProcessor,
|
Qwen2VLMultiModalProcessor,
|
||||||
Qwen2VLProcessingInfo,
|
Qwen2VLProcessingInfo,
|
||||||
apply_rotary_pos_emb_vision,
|
|
||||||
)
|
)
|
||||||
from .utils import (
|
from .utils import (
|
||||||
AutoWeightsLoader,
|
AutoWeightsLoader,
|
||||||
@@ -353,6 +355,8 @@ class Qwen2_5_VisionAttention(nn.Module):
|
|||||||
multimodal_config=multimodal_config,
|
multimodal_config=multimodal_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@@ -378,8 +382,10 @@ class Qwen2_5_VisionAttention(nn.Module):
|
|||||||
qk_reshaped = einops.rearrange(
|
qk_reshaped = einops.rearrange(
|
||||||
qk, "b s two head head_dim -> (two b) s head head_dim", two=2
|
qk, "b s two head head_dim -> (two b) s head head_dim", two=2
|
||||||
)
|
)
|
||||||
qk_rotated = apply_rotary_pos_emb_vision(
|
qk_rotated = self.apply_rotary_emb(
|
||||||
qk_reshaped, cos=rotary_pos_emb_cos, sin=rotary_pos_emb_sin
|
qk_reshaped,
|
||||||
|
rotary_pos_emb_cos,
|
||||||
|
rotary_pos_emb_sin,
|
||||||
)
|
)
|
||||||
qk_rotated = qk_rotated.view(
|
qk_rotated = qk_rotated.view(
|
||||||
2,
|
2,
|
||||||
|
|||||||
@@ -59,8 +59,7 @@ from vllm.model_executor.layers.linear import (
|
|||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.rotary_embedding.common import (
|
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||||
apply_rotary_emb_torch,
|
ApplyRotaryEmb,
|
||||||
dispatch_rotary_emb_function,
|
|
||||||
)
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||||
@@ -280,16 +279,6 @@ class Qwen2VisionMLP(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb_vision(
|
|
||||||
t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
|
||||||
) -> torch.Tensor:
|
|
||||||
rotary_emb_function = dispatch_rotary_emb_function(
|
|
||||||
default=partial(apply_rotary_emb_torch, is_neox_style=True)
|
|
||||||
)
|
|
||||||
output = rotary_emb_function(t, cos, sin).type_as(t)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen2VisionAttention(nn.Module):
|
class Qwen2VisionAttention(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -341,6 +330,8 @@ class Qwen2VisionAttention(nn.Module):
|
|||||||
multimodal_config=multimodal_config,
|
multimodal_config=multimodal_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
|
||||||
|
|
||||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||||
# [s, b, 3 * head * head_dim]
|
# [s, b, 3 * head * head_dim]
|
||||||
seq_len, bs, _ = qkv.shape
|
seq_len, bs, _ = qkv.shape
|
||||||
@@ -387,8 +378,10 @@ class Qwen2VisionAttention(nn.Module):
|
|||||||
|
|
||||||
# [2 * b, s, heads, head_dim]
|
# [2 * b, s, heads, head_dim]
|
||||||
qk_concat = torch.cat([q, k], dim=0)
|
qk_concat = torch.cat([q, k], dim=0)
|
||||||
qk_rotated = apply_rotary_pos_emb_vision(
|
qk_rotated = self.apply_rotary_emb(
|
||||||
qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
|
qk_concat,
|
||||||
|
rotary_pos_emb_cos,
|
||||||
|
rotary_pos_emb_sin,
|
||||||
)
|
)
|
||||||
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
||||||
|
|
||||||
|
|||||||
@@ -713,17 +713,13 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
|
|||||||
mm_counts: Mapping[str, int],
|
mm_counts: Mapping[str, int],
|
||||||
) -> int:
|
) -> int:
|
||||||
target_width, target_height = self.get_image_size_with_most_features()
|
target_width, target_height = self.get_image_size_with_most_features()
|
||||||
video_soft_tokens = self.get_num_video_tokens(
|
num_video_soft_tokens = self.get_num_video_tokens(
|
||||||
image_width=target_width,
|
image_width=target_width,
|
||||||
image_height=target_height,
|
image_height=target_height,
|
||||||
num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
|
num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
|
||||||
image_processor=None,
|
image_processor=None,
|
||||||
)
|
)
|
||||||
|
return num_video_soft_tokens
|
||||||
# NOTE: By default in Qwen3-VL, one video token is converted to
|
|
||||||
# "<{timestamp} seconds>" (on average 9.5 tokens) + vision_start_token + video_token + vision_end_token # noqa: E501
|
|
||||||
formatted_video_soft_tokens = video_soft_tokens * 12.5
|
|
||||||
return int(formatted_video_soft_tokens)
|
|
||||||
|
|
||||||
def _calculate_timestamps(
|
def _calculate_timestamps(
|
||||||
self, indices: list[int] | torch.Tensor, video_fps: float, merge_size: int
|
self, indices: list[int] | torch.Tensor, video_fps: float, merge_size: int
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ within a vision language model."""
|
|||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from einops import rearrange, repeat
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from transformers import Siglip2VisionConfig
|
from transformers import Siglip2VisionConfig
|
||||||
@@ -26,6 +25,9 @@ from vllm.model_executor.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||||
|
ApplyRotaryEmb,
|
||||||
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
@@ -146,40 +148,6 @@ class Siglip2VisionEmbeddings(nn.Module):
|
|||||||
return patch_embeds
|
return patch_embeds
|
||||||
|
|
||||||
|
|
||||||
# copy from flash_attn/layers/rotary.py
|
|
||||||
def rotate_half(x, interleaved=False):
|
|
||||||
if not interleaved:
|
|
||||||
x1, x2 = x.chunk(2, dim=-1)
|
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
|
||||||
else:
|
|
||||||
x1, x2 = x[..., ::2], x[..., 1::2]
|
|
||||||
return rearrange(
|
|
||||||
torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
|
|
||||||
"""
|
|
||||||
x: (batch_size, seqlen, nheads, headdim)
|
|
||||||
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
|
|
||||||
"""
|
|
||||||
ro_dim = cos.shape[-1] * 2
|
|
||||||
assert ro_dim <= x.shape[-1]
|
|
||||||
cos = repeat(
|
|
||||||
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
|
||||||
)
|
|
||||||
sin = repeat(
|
|
||||||
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
|
||||||
)
|
|
||||||
return torch.cat(
|
|
||||||
[
|
|
||||||
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
|
|
||||||
x[..., ro_dim:],
|
|
||||||
],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb(
|
def apply_rotary_pos_emb(
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
k: torch.Tensor,
|
k: torch.Tensor,
|
||||||
@@ -189,14 +157,20 @@ def apply_rotary_pos_emb(
|
|||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
cos = cos.chunk(2, dim=-1)[0].contiguous()
|
cos = cos.chunk(2, dim=-1)[0].contiguous()
|
||||||
sin = sin.chunk(2, dim=-1)[0].contiguous()
|
sin = sin.chunk(2, dim=-1)[0].contiguous()
|
||||||
if is_flash_attn_backend and current_platform.is_cuda():
|
|
||||||
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
|
||||||
|
|
||||||
apply_rotary_emb_func = apply_rotary_emb
|
apply_rotary_emb = ApplyRotaryEmb(
|
||||||
|
enforce_enable=True,
|
||||||
|
enable_fp32_compute=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_flash_attn_backend and not current_platform.is_cuda():
|
||||||
|
apply_rotary_emb_func = apply_rotary_emb.forward_cuda
|
||||||
else:
|
else:
|
||||||
apply_rotary_emb_func = apply_rotary_emb_torch
|
apply_rotary_emb_func = apply_rotary_emb.forward_native
|
||||||
q_embed = apply_rotary_emb_func(q.float(), cos.float(), sin.float()).type_as(q)
|
|
||||||
k_embed = apply_rotary_emb_func(k.float(), cos.float(), sin.float()).type_as(k)
|
q_embed = apply_rotary_emb_func(q, cos, sin)
|
||||||
|
k_embed = apply_rotary_emb_func(k, cos, sin)
|
||||||
|
|
||||||
return q_embed, k_embed
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import torch
|
|||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
from vllm.config import VllmConfig, get_current_vllm_config
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import (
|
from vllm.distributed import (
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
@@ -88,16 +88,10 @@ def get_vit_attn_backend(
|
|||||||
"""
|
"""
|
||||||
Get the available attention backend for Vision Transformer.
|
Get the available attention backend for Vision Transformer.
|
||||||
"""
|
"""
|
||||||
attn_backend = attn_backend_override
|
|
||||||
|
|
||||||
selected_backend = get_current_vllm_config().attention_config.backend
|
|
||||||
if attn_backend is None:
|
|
||||||
attn_backend = selected_backend
|
|
||||||
|
|
||||||
return current_platform.get_vit_attn_backend(
|
return current_platform.get_vit_attn_backend(
|
||||||
head_size,
|
head_size,
|
||||||
dtype,
|
dtype,
|
||||||
backend=attn_backend,
|
backend=attn_backend_override,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
|
|||||||
from collections import UserDict, defaultdict
|
from collections import UserDict, defaultdict
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import partial
|
from functools import cached_property, partial
|
||||||
from itertools import accumulate
|
from itertools import accumulate
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
@@ -169,11 +169,42 @@ class PlaceholderRange:
|
|||||||
between `offset` and `offset + length` to assign embeddings to.
|
between `offset` and `offset + length` to assign embeddings to.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_num_embeds(self) -> int:
|
@cached_property
|
||||||
|
def embeds_cumsum(self) -> torch.Tensor | None:
|
||||||
if self.is_embed is None:
|
if self.is_embed is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self.is_embed.cumsum(dim=0)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def get_num_embeds(self) -> int:
|
||||||
|
if self.embeds_cumsum is None:
|
||||||
return self.length
|
return self.length
|
||||||
|
|
||||||
return int(self.is_embed.sum().item())
|
return int(self.embeds_cumsum[-1])
|
||||||
|
|
||||||
|
def get_embeds_indices_in_range(
|
||||||
|
self, start_idx: int, end_idx: int
|
||||||
|
) -> tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Returns the starting and ending indices of the embeddings of encoder outputs
|
||||||
|
in the range of [start_idx, end_idx) in the placeholders.
|
||||||
|
|
||||||
|
For example, given:
|
||||||
|
PlaceholderRange(offset=2, length=5, is_embed=[False, True, False, True, True])
|
||||||
|
|
||||||
|
If start_idx=3 and end_idx=5, the output is (1, 3) because we want to get
|
||||||
|
the second and the third embeddings from the encoder output.
|
||||||
|
"""
|
||||||
|
if self.embeds_cumsum is None:
|
||||||
|
return start_idx, end_idx
|
||||||
|
|
||||||
|
embeds_start_idx = (
|
||||||
|
int(self.embeds_cumsum[start_idx - 1]) if start_idx > 0 else 0
|
||||||
|
)
|
||||||
|
embeds_end_idx = int(self.embeds_cumsum[end_idx - 1])
|
||||||
|
|
||||||
|
return embeds_start_idx, embeds_end_idx
|
||||||
|
|
||||||
def extract_embeds_range(self) -> list[tuple[int, int]]:
|
def extract_embeds_range(self) -> list[tuple[int, int]]:
|
||||||
"""Extract the start and end indices of the embedded region in prompt.
|
"""Extract the start and end indices of the embedded region in prompt.
|
||||||
@@ -188,7 +219,7 @@ class PlaceholderRange:
|
|||||||
Returns full placeholder range if `is_embed` is `None`.
|
Returns full placeholder range if `is_embed` is `None`.
|
||||||
"""
|
"""
|
||||||
if self.is_embed is None:
|
if self.is_embed is None:
|
||||||
return [(self.offset, self.offset + self.length)]
|
return [(self.offset, self.offset + self.length - 1)]
|
||||||
|
|
||||||
mask_i = self.is_embed.int()
|
mask_i = self.is_embed.int()
|
||||||
starts = torch.nonzero(
|
starts = torch.nonzero(
|
||||||
|
|||||||
@@ -274,15 +274,11 @@ class MultiModalProfiler(Generic[_I]):
|
|||||||
def _get_mm_num_tokens(
|
def _get_mm_num_tokens(
|
||||||
self,
|
self,
|
||||||
mm_inputs: MultiModalInputs,
|
mm_inputs: MultiModalInputs,
|
||||||
mm_embeddings_only: bool = True,
|
|
||||||
) -> Mapping[str, int]:
|
) -> Mapping[str, int]:
|
||||||
placeholders_by_modality = mm_inputs["mm_placeholders"]
|
placeholders_by_modality = mm_inputs["mm_placeholders"]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
modality: sum(
|
modality: sum(item.get_num_embeds for item in placeholders)
|
||||||
item.get_num_embeds() if mm_embeddings_only else item.length
|
|
||||||
for item in placeholders
|
|
||||||
)
|
|
||||||
for modality, placeholders in placeholders_by_modality.items()
|
for modality, placeholders in placeholders_by_modality.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -328,12 +324,15 @@ class MultiModalProfiler(Generic[_I]):
|
|||||||
multi_modal_placeholders=mm_inputs["mm_placeholders"],
|
multi_modal_placeholders=mm_inputs["mm_placeholders"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_mm_max_tokens(
|
def get_mm_max_tokens(
|
||||||
self,
|
self,
|
||||||
seq_len: int,
|
seq_len: int,
|
||||||
mm_counts: Mapping[str, int] | None = None,
|
mm_counts: Mapping[str, int] | None = None,
|
||||||
mm_embeddings_only: bool = True,
|
|
||||||
) -> Mapping[str, int]:
|
) -> Mapping[str, int]:
|
||||||
|
"""
|
||||||
|
Returns the maximum number of embeddings per item of each modality, excluding
|
||||||
|
any break/text tokens in-between multimodal embeddings/encoder outputs.
|
||||||
|
"""
|
||||||
if mm_counts is None:
|
if mm_counts is None:
|
||||||
mm_counts = self.get_mm_limits()
|
mm_counts = self.get_mm_limits()
|
||||||
|
|
||||||
@@ -349,21 +348,4 @@ class MultiModalProfiler(Generic[_I]):
|
|||||||
}
|
}
|
||||||
|
|
||||||
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
|
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
|
||||||
return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only)
|
return self._get_mm_num_tokens(mm_inputs)
|
||||||
|
|
||||||
def get_mm_max_contiguous_tokens(
|
|
||||||
self,
|
|
||||||
seq_len: int,
|
|
||||||
mm_counts: Mapping[str, int] | None = None,
|
|
||||||
) -> Mapping[str, int]:
|
|
||||||
"""
|
|
||||||
Returns the maximum length of the multimodal (image placeholders+text)
|
|
||||||
tokens, including any break/text tokens in-between image embeddings.
|
|
||||||
|
|
||||||
`<im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>`
|
|
||||||
Returns 9, even when the number of image embeddings is 6.
|
|
||||||
|
|
||||||
This is important to take into account when profiling and
|
|
||||||
initializing the encoder cache size.
|
|
||||||
"""
|
|
||||||
return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False)
|
|
||||||
|
|||||||
@@ -164,7 +164,7 @@ class MultiModalRegistry:
|
|||||||
profiler.get_mm_limits() if profiler_limits is None else profiler_limits
|
profiler.get_mm_limits() if profiler_limits is None else profiler_limits
|
||||||
)
|
)
|
||||||
|
|
||||||
return profiler.get_mm_max_contiguous_tokens(
|
return profiler.get_mm_max_tokens(
|
||||||
seq_len,
|
seq_len,
|
||||||
{modality: 1 for modality, limit in profiler_limits.items() if limit > 0},
|
{modality: 1 for modality, limit in profiler_limits.items() if limit > 0},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -39,20 +39,26 @@ class EncoderCacheManager:
|
|||||||
space for new embeddings.
|
space for new embeddings.
|
||||||
Oldest cached embeddings with no request referenced will be first evicted.
|
Oldest cached embeddings with no request referenced will be first evicted.
|
||||||
|
|
||||||
|
NOTE: The EncoderCacheManager operates on the level of multimodal embeddings
|
||||||
|
instead of encoder tokens (i.e. all tokens that represent the multimodal data
|
||||||
|
in the input sequence). This means all break/text tokens in-between multimodal
|
||||||
|
embeddings are not considered with respect to the cache size and the number
|
||||||
|
of free slots.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cache_size: Limit the size of the cache, measured by the number of
|
cache_size: Limit the size of the cache, measured by the number of
|
||||||
tokens from the input sequence.
|
encoder embeddings from the input sequence.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
cache_size: Total cache capacity in encoder tokens.
|
cache_size: Total cache capacity in encoder embeddings.
|
||||||
num_free_slots: Current available cache capacity in encoder tokens.
|
num_free_slots: Current available cache capacity in encoder embeddings.
|
||||||
num_freeable_slots: Capacity that can be immediately reclaimed by
|
num_freeable_slots: Capacity that can be immediately reclaimed by
|
||||||
evicting entries with zero references (in encoder tokens).
|
evicting entries with zero references (in encoder embeddings).
|
||||||
cached: Mapping from mm_hash to a set of request IDs that currently
|
cached: Mapping from mm_hash to a set of request IDs that currently
|
||||||
reference the cached entry. If the set is empty, the entry exists
|
reference the cached entry. If the set is empty, the entry exists
|
||||||
but is not referenced by any request and is eligible for
|
but is not referenced by any request and is eligible for
|
||||||
reclamation.
|
reclamation.
|
||||||
freeable: List of tuples (mm_hash, num_tokens) representing entries
|
freeable: List of tuples (mm_hash, num_encoder_embeds) representing entries
|
||||||
whose no current running request is needed and that can be freed to
|
whose no current running request is needed and that can be freed to
|
||||||
make space when needed.
|
make space when needed.
|
||||||
freed: List of mm_hash strings that were actually evicted since the
|
freed: List of mm_hash strings that were actually evicted since the
|
||||||
@@ -67,7 +73,7 @@ class EncoderCacheManager:
|
|||||||
# mm_hash of mm_data => ids of requests that reference the mm_data
|
# mm_hash of mm_data => ids of requests that reference the mm_data
|
||||||
self.cached: dict[str, set[str]] = {}
|
self.cached: dict[str, set[str]] = {}
|
||||||
|
|
||||||
# mm_hash of mm_data => num_encoder_tokens of the mm_data
|
# mm_hash of mm_data => num_encoder_embeds of the mm_data
|
||||||
self.freeable: OrderedDict[str, int] = OrderedDict()
|
self.freeable: OrderedDict[str, int] = OrderedDict()
|
||||||
self.freed: list[str] = []
|
self.freed: list[str] = []
|
||||||
|
|
||||||
@@ -93,8 +99,8 @@ class EncoderCacheManager:
|
|||||||
|
|
||||||
# Cached but currently not referenced by any request
|
# Cached but currently not referenced by any request
|
||||||
if not self.cached[mm_hash]:
|
if not self.cached[mm_hash]:
|
||||||
num_tokens = self.freeable.pop(mm_hash)
|
num_encoder_embeds = self.freeable.pop(mm_hash)
|
||||||
self.num_freeable_slots -= num_tokens
|
self.num_freeable_slots -= num_encoder_embeds
|
||||||
|
|
||||||
self.cached[mm_hash].add(request.request_id)
|
self.cached[mm_hash].add(request.request_id)
|
||||||
return True
|
return True
|
||||||
@@ -104,7 +110,7 @@ class EncoderCacheManager:
|
|||||||
request: Request,
|
request: Request,
|
||||||
input_id: int,
|
input_id: int,
|
||||||
encoder_compute_budget: int,
|
encoder_compute_budget: int,
|
||||||
num_tokens_to_schedule: int,
|
num_embeds_to_schedule: int,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Check if there's sufficient cache space for a multimodal input.
|
"""Check if there's sufficient cache space for a multimodal input.
|
||||||
If there is, return True and update EncoderCacheManager state.
|
If there is, return True and update EncoderCacheManager state.
|
||||||
@@ -121,9 +127,9 @@ class EncoderCacheManager:
|
|||||||
Args:
|
Args:
|
||||||
request: The request containing the multimodal input.
|
request: The request containing the multimodal input.
|
||||||
input_id: Index of the multimodal input within the request.
|
input_id: Index of the multimodal input within the request.
|
||||||
encoder_compute_budget: Number of encoder tokens allowed to be
|
encoder_compute_budget: Number of encoder embeddings allowed to be
|
||||||
computed when this method is invoked.
|
computed when this method is invoked.
|
||||||
num_tokens_to_schedule: Number of tokens already scheduled to be
|
num_embeds_to_schedule: Number of encoder embeddings already scheduled to be
|
||||||
allocated with cache space when this method is invoked.
|
allocated with cache space when this method is invoked.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -134,30 +140,30 @@ class EncoderCacheManager:
|
|||||||
Note: This method does not allocate physical memory for the encoder
|
Note: This method does not allocate physical memory for the encoder
|
||||||
output but only the state of EncoderCacheManager.
|
output but only the state of EncoderCacheManager.
|
||||||
"""
|
"""
|
||||||
num_tokens = request.get_num_encoder_tokens(input_id)
|
num_embeds = request.get_num_encoder_embeds(input_id)
|
||||||
|
|
||||||
# Not enough compute budget
|
# Not enough compute budget
|
||||||
if num_tokens > encoder_compute_budget:
|
if num_embeds > encoder_compute_budget:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
num_tokens += num_tokens_to_schedule
|
num_embeds += num_embeds_to_schedule
|
||||||
|
|
||||||
# Enough free slots
|
# Enough free slots
|
||||||
if num_tokens <= self.num_free_slots:
|
if num_embeds <= self.num_free_slots:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Not enough reclaimable slots
|
# Not enough reclaimable slots
|
||||||
if num_tokens > self.num_freeable_slots:
|
if num_embeds > self.num_freeable_slots:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Not enough free slots but enough reclaimable slots
|
# Not enough free slots but enough reclaimable slots
|
||||||
# NOTE: Eviction takes place here, but physical memory is not freed
|
# NOTE: Eviction takes place here, but physical memory is not freed
|
||||||
# until model runner is notified by the scheduler output.
|
# until model runner is notified by the scheduler output.
|
||||||
while num_tokens > self.num_free_slots:
|
while num_embeds > self.num_free_slots:
|
||||||
mm_hash, num_free_token = self.freeable.popitem(last=False)
|
mm_hash, num_free_embeds = self.freeable.popitem(last=False)
|
||||||
del self.cached[mm_hash]
|
del self.cached[mm_hash]
|
||||||
self.freed.append(mm_hash)
|
self.freed.append(mm_hash)
|
||||||
self.num_free_slots += num_free_token
|
self.num_free_slots += num_free_embeds
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def allocate(self, request: Request, input_id: int) -> None:
|
def allocate(self, request: Request, input_id: int) -> None:
|
||||||
@@ -176,16 +182,16 @@ class EncoderCacheManager:
|
|||||||
if mm_hash not in self.cached:
|
if mm_hash not in self.cached:
|
||||||
self.cached[mm_hash] = set()
|
self.cached[mm_hash] = set()
|
||||||
|
|
||||||
num_encoder_tokens = request.get_num_encoder_tokens(input_id)
|
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
|
||||||
|
|
||||||
# NOTE: Encoder cache should always have enough space for encoder inputs
|
# NOTE: Encoder cache should always have enough space for encoder inputs
|
||||||
# that are scheduled since eviction takes place at can_allocate().
|
# that are scheduled since eviction takes place at can_allocate().
|
||||||
assert self.num_free_slots >= num_encoder_tokens
|
assert self.num_free_slots >= num_encoder_embeds
|
||||||
assert self.num_freeable_slots >= num_encoder_tokens
|
assert self.num_freeable_slots >= num_encoder_embeds
|
||||||
|
|
||||||
self.cached[mm_hash].add(request_id)
|
self.cached[mm_hash].add(request_id)
|
||||||
self.num_free_slots -= num_encoder_tokens
|
self.num_free_slots -= num_encoder_embeds
|
||||||
self.num_freeable_slots -= num_encoder_tokens
|
self.num_freeable_slots -= num_encoder_embeds
|
||||||
|
|
||||||
def get_cached_input_ids(self, request: Request) -> set[int]:
|
def get_cached_input_ids(self, request: Request) -> set[int]:
|
||||||
"""Get all cached multimodal input IDs for a request.
|
"""Get all cached multimodal input IDs for a request.
|
||||||
@@ -206,7 +212,7 @@ class EncoderCacheManager:
|
|||||||
|
|
||||||
When the reference set for the corresponding `mm_hash` becomes empty,
|
When the reference set for the corresponding `mm_hash` becomes empty,
|
||||||
the entry is appended to `freeable` and `num_freeable_slots` is
|
the entry is appended to `freeable` and `num_freeable_slots` is
|
||||||
increased by the number of encoder tokens for that input.
|
increased by the number of encoder embeddings for that input.
|
||||||
|
|
||||||
The entry is NOT physically freed until capacity is needed (e.g., by
|
The entry is NOT physically freed until capacity is needed (e.g., by
|
||||||
`can_allocate`).
|
`can_allocate`).
|
||||||
@@ -218,9 +224,9 @@ class EncoderCacheManager:
|
|||||||
return
|
return
|
||||||
self.cached[mm_hash].discard(req_id)
|
self.cached[mm_hash].discard(req_id)
|
||||||
if not self.cached[mm_hash]:
|
if not self.cached[mm_hash]:
|
||||||
num_tokens = request.get_num_encoder_tokens(input_id)
|
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
|
||||||
self.freeable[mm_hash] = num_tokens
|
self.freeable[mm_hash] = num_encoder_embeds
|
||||||
self.num_freeable_slots += num_tokens
|
self.num_freeable_slots += num_encoder_embeds
|
||||||
|
|
||||||
def free(self, request: Request) -> None:
|
def free(self, request: Request) -> None:
|
||||||
"""Free all encoder input cache reference held by *request*.
|
"""Free all encoder input cache reference held by *request*.
|
||||||
@@ -361,20 +367,20 @@ class EncoderDecoderCacheManager(EncoderCacheManager):
|
|||||||
request: Request,
|
request: Request,
|
||||||
input_id: int,
|
input_id: int,
|
||||||
encoder_compute_budget: int,
|
encoder_compute_budget: int,
|
||||||
num_tokens_to_schedule: int,
|
num_embeds_to_schedule: int,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
num_tokens = request.get_num_encoder_tokens(input_id)
|
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
|
||||||
# Not enough compute budget
|
# Not enough compute budget
|
||||||
if num_tokens > encoder_compute_budget:
|
if num_encoder_embeds > encoder_compute_budget:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
num_tokens += num_tokens_to_schedule
|
num_encoder_embeds += num_embeds_to_schedule
|
||||||
# Enough free slots
|
# Enough free slots
|
||||||
return num_tokens <= self.num_free_slots
|
return num_encoder_embeds <= self.num_free_slots
|
||||||
|
|
||||||
def allocate(self, request: Request, input_id: int) -> None:
|
def allocate(self, request: Request, input_id: int) -> None:
|
||||||
num_encoder_tokens = request.get_num_encoder_tokens(input_id)
|
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
|
||||||
self.num_free_slots -= num_encoder_tokens
|
self.num_free_slots -= num_encoder_embeds
|
||||||
|
|
||||||
mm_hash = request.mm_features[input_id].identifier
|
mm_hash = request.mm_features[input_id].identifier
|
||||||
self.freed.append(mm_hash)
|
self.freed.append(mm_hash)
|
||||||
@@ -392,5 +398,5 @@ class EncoderDecoderCacheManager(EncoderCacheManager):
|
|||||||
return freed
|
return freed
|
||||||
|
|
||||||
def free_encoder_input(self, request: Request, input_id: int) -> None:
|
def free_encoder_input(self, request: Request, input_id: int) -> None:
|
||||||
num_tokens = request.get_num_encoder_tokens(input_id)
|
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
|
||||||
self.num_free_slots += num_tokens
|
self.num_free_slots += num_encoder_embeds
|
||||||
|
|||||||
@@ -349,11 +349,11 @@ class Scheduler(SchedulerInterface):
|
|||||||
if preempted_encoder_inputs:
|
if preempted_encoder_inputs:
|
||||||
# Restore encoder compute budget if the preempted
|
# Restore encoder compute budget if the preempted
|
||||||
# request had encoder inputs scheduled in this step.
|
# request had encoder inputs scheduled in this step.
|
||||||
num_tokens_to_restore = sum(
|
num_embeds_to_restore = sum(
|
||||||
preempted_req.get_num_encoder_tokens(i)
|
preempted_req.get_num_encoder_embeds(i)
|
||||||
for i in preempted_encoder_inputs
|
for i in preempted_encoder_inputs
|
||||||
)
|
)
|
||||||
encoder_compute_budget += num_tokens_to_restore
|
encoder_compute_budget += num_embeds_to_restore
|
||||||
req_index -= 1
|
req_index -= 1
|
||||||
else:
|
else:
|
||||||
preempted_req = self.running.pop()
|
preempted_req = self.running.pop()
|
||||||
@@ -911,10 +911,11 @@ class Scheduler(SchedulerInterface):
|
|||||||
# multiple encoder inputs per request), we need to create temporary
|
# multiple encoder inputs per request), we need to create temporary
|
||||||
# trackers for accounting at the encoder input level.
|
# trackers for accounting at the encoder input level.
|
||||||
mm_hashes_to_schedule = set()
|
mm_hashes_to_schedule = set()
|
||||||
num_tokens_to_schedule = 0
|
num_embeds_to_schedule = 0
|
||||||
for i, mm_feature in enumerate(mm_features):
|
for i, mm_feature in enumerate(mm_features):
|
||||||
start_pos = mm_feature.mm_position.offset
|
start_pos = mm_feature.mm_position.offset
|
||||||
num_encoder_tokens = mm_feature.mm_position.length
|
num_encoder_tokens = mm_feature.mm_position.length
|
||||||
|
num_encoder_embeds = mm_feature.mm_position.get_num_embeds
|
||||||
|
|
||||||
# The encoder output is needed if the two ranges overlap:
|
# The encoder output is needed if the two ranges overlap:
|
||||||
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
|
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
|
||||||
@@ -970,9 +971,8 @@ class Scheduler(SchedulerInterface):
|
|||||||
):
|
):
|
||||||
num_new_tokens = start_pos - num_computed_tokens
|
num_new_tokens = start_pos - num_computed_tokens
|
||||||
break
|
break
|
||||||
|
|
||||||
if not self.encoder_cache_manager.can_allocate(
|
if not self.encoder_cache_manager.can_allocate(
|
||||||
request, i, encoder_compute_budget, num_tokens_to_schedule
|
request, i, encoder_compute_budget, num_embeds_to_schedule
|
||||||
):
|
):
|
||||||
# The encoder cache is full or the encoder budget is exhausted.
|
# The encoder cache is full or the encoder budget is exhausted.
|
||||||
# NOTE(woosuk): We assume that the encoder input tokens should
|
# NOTE(woosuk): We assume that the encoder input tokens should
|
||||||
@@ -992,14 +992,31 @@ class Scheduler(SchedulerInterface):
|
|||||||
num_new_tokens = 0
|
num_new_tokens = 0
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Calculate the number of embeddings to schedule in the current range
|
||||||
|
# of scheduled encoder placholder tokens.
|
||||||
|
start_idx_rel = max(0, num_computed_tokens - start_pos)
|
||||||
|
end_idx_rel = min(
|
||||||
|
num_encoder_tokens, num_computed_tokens + num_new_tokens - start_pos
|
||||||
|
)
|
||||||
|
curr_embeds_start, curr_embeds_end = (
|
||||||
|
mm_feature.mm_position.get_embeds_indices_in_range(
|
||||||
|
start_idx_rel,
|
||||||
|
end_idx_rel,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# There's no embeddings in the current range of encoder placeholder tokens
|
||||||
|
# so we can skip the encoder input.
|
||||||
|
if curr_embeds_end - curr_embeds_start == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
if self.ec_connector is not None and remote_cache_has_item[i]:
|
if self.ec_connector is not None and remote_cache_has_item[i]:
|
||||||
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
|
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
|
||||||
external_load_encoder_input.append(i)
|
external_load_encoder_input.append(i)
|
||||||
num_tokens_to_schedule += num_encoder_tokens
|
num_embeds_to_schedule += num_encoder_embeds
|
||||||
continue
|
continue
|
||||||
|
|
||||||
num_tokens_to_schedule += num_encoder_tokens
|
num_embeds_to_schedule += num_encoder_embeds
|
||||||
encoder_compute_budget -= num_encoder_tokens
|
encoder_compute_budget -= num_encoder_embeds
|
||||||
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
|
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
|
||||||
encoder_inputs_to_schedule.append(i)
|
encoder_inputs_to_schedule.append(i)
|
||||||
|
|
||||||
|
|||||||
@@ -209,10 +209,10 @@ class Request:
|
|||||||
def get_finished_reason(self) -> FinishReason | None:
|
def get_finished_reason(self) -> FinishReason | None:
|
||||||
return RequestStatus.get_finished_reason(self.status)
|
return RequestStatus.get_finished_reason(self.status)
|
||||||
|
|
||||||
def get_num_encoder_tokens(self, input_id: int) -> int:
|
def get_num_encoder_embeds(self, input_id: int) -> int:
|
||||||
assert input_id < len(self.mm_features)
|
assert input_id < len(self.mm_features)
|
||||||
num_tokens = self.mm_features[input_id].mm_position.length
|
num_embeds = self.mm_features[input_id].mm_position.get_num_embeds
|
||||||
return num_tokens
|
return num_embeds
|
||||||
|
|
||||||
def record_event(
|
def record_event(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -169,9 +169,7 @@ from .utils import (
|
|||||||
MultiModalBudget,
|
MultiModalBudget,
|
||||||
add_kv_sharing_layers_to_kv_cache_groups,
|
add_kv_sharing_layers_to_kv_cache_groups,
|
||||||
bind_kv_cache,
|
bind_kv_cache,
|
||||||
gather_mm_placeholders,
|
|
||||||
sanity_check_mm_encoder_outputs,
|
sanity_check_mm_encoder_outputs,
|
||||||
scatter_mm_placeholders,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -2187,10 +2185,7 @@ class GPUModelRunner(
|
|||||||
|
|
||||||
# Cache the encoder outputs by mm_hash
|
# Cache the encoder outputs by mm_hash
|
||||||
for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
|
for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
|
||||||
self.encoder_cache[mm_hash] = scatter_mm_placeholders(
|
self.encoder_cache[mm_hash] = output
|
||||||
output,
|
|
||||||
is_embed=pos_info.is_embed,
|
|
||||||
)
|
|
||||||
logger.debug("Finish execute for mm hash %s", mm_hash)
|
logger.debug("Finish execute for mm hash %s", mm_hash)
|
||||||
self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)
|
self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)
|
||||||
|
|
||||||
@@ -2241,6 +2236,13 @@ class GPUModelRunner(
|
|||||||
num_encoder_tokens,
|
num_encoder_tokens,
|
||||||
)
|
)
|
||||||
assert start_idx < end_idx
|
assert start_idx < end_idx
|
||||||
|
curr_embeds_start, curr_embeds_end = (
|
||||||
|
pos_info.get_embeds_indices_in_range(start_idx, end_idx)
|
||||||
|
)
|
||||||
|
# If there are no embeddings in the current range, we skip
|
||||||
|
# gathering the embeddings.
|
||||||
|
if curr_embeds_start == curr_embeds_end:
|
||||||
|
continue
|
||||||
|
|
||||||
mm_hash = mm_feature.identifier
|
mm_hash = mm_feature.identifier
|
||||||
encoder_output = self.encoder_cache.get(mm_hash, None)
|
encoder_output = self.encoder_cache.get(mm_hash, None)
|
||||||
@@ -2248,16 +2250,14 @@ class GPUModelRunner(
|
|||||||
|
|
||||||
if (is_embed := pos_info.is_embed) is not None:
|
if (is_embed := pos_info.is_embed) is not None:
|
||||||
is_embed = is_embed[start_idx:end_idx]
|
is_embed = is_embed[start_idx:end_idx]
|
||||||
|
mm_embeds_item = encoder_output[curr_embeds_start:curr_embeds_end]
|
||||||
|
else:
|
||||||
|
mm_embeds_item = encoder_output[start_idx:end_idx]
|
||||||
|
|
||||||
req_start_pos = req_start_idx + start_pos - num_computed_tokens
|
req_start_pos = req_start_idx + start_pos - num_computed_tokens
|
||||||
is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
|
is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
|
||||||
True if is_embed is None else is_embed
|
True if is_embed is None else is_embed
|
||||||
)
|
)
|
||||||
|
|
||||||
mm_embeds_item = gather_mm_placeholders(
|
|
||||||
encoder_output[start_idx:end_idx],
|
|
||||||
is_embed=is_embed,
|
|
||||||
)
|
|
||||||
mm_embeds_req.append(mm_embeds_item)
|
mm_embeds_req.append(mm_embeds_item)
|
||||||
|
|
||||||
if self.is_multimodal_pruning_enabled and self.uses_mrope:
|
if self.is_multimodal_pruning_enabled and self.uses_mrope:
|
||||||
@@ -4467,31 +4467,8 @@ class GPUModelRunner(
|
|||||||
dummy_encoder_outputs,
|
dummy_encoder_outputs,
|
||||||
expected_num_items=max_mm_items_per_batch,
|
expected_num_items=max_mm_items_per_batch,
|
||||||
)
|
)
|
||||||
|
for i, output in enumerate(dummy_encoder_outputs):
|
||||||
# NOTE: This happens when encoder cache needs to store
|
self.encoder_cache[f"tmp_{i}"] = output
|
||||||
# the embeddings that encoder outputs are scattered onto.
|
|
||||||
# In this case we create dummy embeddings of size
|
|
||||||
# (max_tokens_for_modality, hidden_size) and scatter
|
|
||||||
# encoder output into it.
|
|
||||||
encoder_output_shape = dummy_encoder_outputs[0].shape
|
|
||||||
max_mm_tokens_per_item = mm_budget.max_tokens_by_modality[
|
|
||||||
dummy_modality
|
|
||||||
]
|
|
||||||
if encoder_output_shape[0] < max_mm_tokens_per_item:
|
|
||||||
encoder_hidden_size = encoder_output_shape[-1]
|
|
||||||
expanded_outputs = []
|
|
||||||
for output in dummy_encoder_outputs:
|
|
||||||
expanded = output.new_zeros(
|
|
||||||
(max_mm_tokens_per_item, encoder_hidden_size)
|
|
||||||
)
|
|
||||||
num_tokens = output.shape[0]
|
|
||||||
expanded[:num_tokens].copy_(output)
|
|
||||||
expanded_outputs.append(expanded)
|
|
||||||
|
|
||||||
dummy_encoder_outputs = expanded_outputs
|
|
||||||
|
|
||||||
# Cache the dummy encoder outputs.
|
|
||||||
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
|
|
||||||
|
|
||||||
# Add `is_profile` here to pre-allocate communication buffers
|
# Add `is_profile` here to pre-allocate communication buffers
|
||||||
hidden_states, last_hidden_states = self._dummy_run(
|
hidden_states, last_hidden_states = self._dummy_run(
|
||||||
|
|||||||
@@ -4,10 +4,12 @@ from collections import defaultdict
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from typing_extensions import deprecated
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import AttentionBackend
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
|
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
|
||||||
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
||||||
from vllm.model_executor.models.utils import extract_layer_index
|
from vllm.model_executor.models.utils import extract_layer_index
|
||||||
from vllm.multimodal.cache import processor_only_cache_from_config
|
from vllm.multimodal.cache import processor_only_cache_from_config
|
||||||
@@ -17,6 +19,8 @@ from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
|||||||
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
|
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
|
||||||
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
|
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MultiModalBudget:
|
class MultiModalBudget:
|
||||||
"""Helper class to calculate budget information for multi-modal models."""
|
"""Helper class to calculate budget information for multi-modal models."""
|
||||||
@@ -198,6 +202,7 @@ def sanity_check_mm_encoder_outputs(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@deprecated("`scatter_mm_placeholders` is deprecated and will be removed in v0.15.0.")
|
||||||
def scatter_mm_placeholders(
|
def scatter_mm_placeholders(
|
||||||
embeds: torch.Tensor,
|
embeds: torch.Tensor,
|
||||||
is_embed: torch.Tensor | None,
|
is_embed: torch.Tensor | None,
|
||||||
@@ -226,6 +231,7 @@ def scatter_mm_placeholders(
|
|||||||
return placeholders
|
return placeholders
|
||||||
|
|
||||||
|
|
||||||
|
@deprecated("`gather_mm_placeholders` is deprecated and will be removed in v0.15.0.")
|
||||||
def gather_mm_placeholders(
|
def gather_mm_placeholders(
|
||||||
placeholders: torch.Tensor,
|
placeholders: torch.Tensor,
|
||||||
is_embed: torch.Tensor | None,
|
is_embed: torch.Tensor | None,
|
||||||
|
|||||||
Reference in New Issue
Block a user