Compare commits
12 Commits
v0.15.0rc0
...
v0.13.0rc4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
55f1fc1b1b | ||
|
|
17f3988094 | ||
|
|
682c38583c | ||
|
|
f124b56786 | ||
|
|
d78e128b8b | ||
|
|
761b730dcb | ||
|
|
f34eca5f01 | ||
|
|
4cd332f3cf | ||
|
|
16484d394c | ||
|
|
e397bd6592 | ||
|
|
6a88d590bb | ||
|
|
ad8c073131 |
@@ -1223,6 +1223,8 @@ steps:
|
||||
# FIXIT: find out which code initialize cuda before running the test
|
||||
# before the fix, we need to use spawn to test it
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
# Alot of these tests are on the edge of OOMing
|
||||
- export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
||||
# There is some Tensor Parallelism related processing logic in LoRA that
|
||||
# requires multi-GPU testing for validation.
|
||||
- pytest -v -s -x lora/test_chatglm3_tp.py
|
||||
|
||||
@@ -523,6 +523,8 @@ CUSTOM_OPS_QUANT_RMS_NORM = ["+quant_fp8,+rms_norm"]
|
||||
list[tuple[Any, ...]](flat_product(MODELS_GROUP_FP8, CUSTOM_OPS_QUANT_RMS_NORM)),
|
||||
)
|
||||
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
|
||||
# TODO: remove skip after we fix the fusion thoroughly
|
||||
@pytest.mark.skipif(is_blackwell(), reason="Temporarily disabled on Blackwell")
|
||||
def test_rms_group_quant(
|
||||
model_name: str,
|
||||
model_kwargs: dict[str, Any],
|
||||
@@ -562,7 +564,7 @@ def test_rms_group_quant(
|
||||
splitting_ops=splitting_ops,
|
||||
# Common
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config=PassConfig(eliminate_noops=True, enable_fusion=True),
|
||||
pass_config=PassConfig(eliminate_noops=True, fuse_norm_quant=True),
|
||||
# Inductor caches custom passes by default as well via uuid
|
||||
inductor_compile_config={"force_disable_caches": True},
|
||||
)
|
||||
|
||||
@@ -254,7 +254,9 @@ async def test_single_chat_session_input_audio(
|
||||
async def test_chat_streaming_audio(
|
||||
client: openai.AsyncOpenAI, model_name: str, audio_url: str
|
||||
):
|
||||
messages = dummy_messages_from_audio_url(audio_url)
|
||||
messages = dummy_messages_from_audio_url(
|
||||
audio_url, "What's a short title for this audio?"
|
||||
)
|
||||
|
||||
# test single completion
|
||||
chat_completion = await client.chat.completions.create(
|
||||
|
||||
203
tests/kernels/core/test_apply_rotary_emb.py
Normal file
203
tests/kernels/core/test_apply_rotary_emb.py
Normal file
@@ -0,0 +1,203 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Tests for ApplyRotaryEmb CustomOp dispatch behavior.
|
||||
|
||||
This test ensures that RotaryEmbedding classes correctly call the appropriate
|
||||
ApplyRotaryEmb methods based on the calling context:
|
||||
|
||||
1. RotaryEmbedding.forward_native() -> ApplyRotaryEmb.forward_native()
|
||||
2. RotaryEmbedding.forward_cuda() -> ApplyRotaryEmb.forward() (auto-dispatch)
|
||||
3. RotaryEmbedding.forward_hip() -> ApplyRotaryEmb.forward() (auto-dispatch)
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
VllmConfig,
|
||||
get_cached_compilation_config,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
CUDA_DEVICES = ["cuda:0"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RotaryEmbeddingTestCase:
|
||||
"""Test case configuration for RotaryEmbedding dispatch tests."""
|
||||
|
||||
name: str
|
||||
rope_class: type
|
||||
rope_kwargs: dict
|
||||
method_name: str # forward_native, forward_cuda, forward
|
||||
positions_shape: tuple # (num_tokens,) or (3, num_tokens) or (4, num_tokens)
|
||||
expect_forward_native: bool # Should call ApplyRotaryEmb.forward_native()
|
||||
expect_forward: bool # Should call ApplyRotaryEmb.forward()
|
||||
|
||||
|
||||
def get_test_cases() -> list[RotaryEmbeddingTestCase]:
|
||||
"""Generate test cases for all RotaryEmbedding classes."""
|
||||
from vllm.model_executor.layers.rotary_embedding.ernie45_vl_rope import (
|
||||
Ernie4_5_VLRotaryEmbedding,
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding
|
||||
from vllm.model_executor.layers.rotary_embedding.xdrope import XDRotaryEmbedding
|
||||
|
||||
common_kwargs = {
|
||||
"head_size": 128,
|
||||
"rotary_dim": 128,
|
||||
"max_position_embeddings": 4096,
|
||||
"base": 10000,
|
||||
"is_neox_style": True,
|
||||
"dtype": torch.bfloat16,
|
||||
}
|
||||
|
||||
return [
|
||||
# MRotaryEmbedding tests
|
||||
RotaryEmbeddingTestCase(
|
||||
name="MRotaryEmbedding.forward_native",
|
||||
rope_class=MRotaryEmbedding,
|
||||
rope_kwargs={**common_kwargs, "mrope_section": [16, 24, 24]},
|
||||
method_name="forward_native",
|
||||
positions_shape=(3, 32), # 2D for multimodal
|
||||
expect_forward_native=True,
|
||||
expect_forward=False,
|
||||
),
|
||||
RotaryEmbeddingTestCase(
|
||||
name="MRotaryEmbedding.forward_cuda_1d",
|
||||
rope_class=MRotaryEmbedding,
|
||||
rope_kwargs={**common_kwargs, "mrope_section": [16, 24, 24]},
|
||||
method_name="forward_cuda",
|
||||
positions_shape=(32,), # 1D triggers apply_rotary_emb path
|
||||
expect_forward_native=False,
|
||||
expect_forward=True,
|
||||
),
|
||||
# XDRotaryEmbedding tests
|
||||
RotaryEmbeddingTestCase(
|
||||
name="XDRotaryEmbedding.forward",
|
||||
rope_class=XDRotaryEmbedding,
|
||||
rope_kwargs={
|
||||
**common_kwargs,
|
||||
"scaling_alpha": 1.0,
|
||||
"xdrope_section": [16, 16, 16, 16],
|
||||
},
|
||||
method_name="forward",
|
||||
positions_shape=(4, 32), # 4D for P/W/H/T
|
||||
expect_forward_native=False,
|
||||
expect_forward=True,
|
||||
),
|
||||
# Ernie4_5_VLRotaryEmbedding tests
|
||||
RotaryEmbeddingTestCase(
|
||||
name="Ernie4_5_VLRotaryEmbedding.forward_native",
|
||||
rope_class=Ernie4_5_VLRotaryEmbedding,
|
||||
rope_kwargs={**common_kwargs, "mrope_section": [22, 22, 20]},
|
||||
method_name="forward_native",
|
||||
positions_shape=(3, 32), # 2D for multimodal
|
||||
expect_forward_native=True,
|
||||
expect_forward=False,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def run_dispatch_test(
|
||||
test_case: RotaryEmbeddingTestCase,
|
||||
device: str,
|
||||
):
|
||||
"""Run a dispatch test for a RotaryEmbedding class."""
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(custom_ops=["all", "+apply_rotary_emb"])
|
||||
)
|
||||
get_cached_compilation_config.cache_clear()
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
rope = test_case.rope_class(**test_case.rope_kwargs).to(device=device)
|
||||
|
||||
apply_rotary_emb = rope.apply_rotary_emb
|
||||
|
||||
# Verify custom op is enabled
|
||||
if test_case.expect_forward_native:
|
||||
assert (
|
||||
apply_rotary_emb._forward_method != apply_rotary_emb.forward_native
|
||||
), "Test setup error: ApplyRotaryEmb custom op should be enabled"
|
||||
|
||||
# Setup call tracking
|
||||
call_tracker = {"forward_native_called": False, "forward_called": False}
|
||||
original_forward_native = apply_rotary_emb.forward_native
|
||||
original_forward = apply_rotary_emb.forward
|
||||
|
||||
def tracked_forward_native(*args, **kwargs):
|
||||
call_tracker["forward_native_called"] = True
|
||||
return original_forward_native(*args, **kwargs)
|
||||
|
||||
def tracked_forward(*args, **kwargs):
|
||||
call_tracker["forward_called"] = True
|
||||
return original_forward(*args, **kwargs)
|
||||
|
||||
apply_rotary_emb.forward_native = tracked_forward_native
|
||||
apply_rotary_emb.forward = tracked_forward
|
||||
|
||||
try:
|
||||
num_tokens = test_case.positions_shape[-1]
|
||||
num_q_heads = 8
|
||||
num_kv_heads = 2
|
||||
head_size = test_case.rope_kwargs["head_size"]
|
||||
max_position = test_case.rope_kwargs["max_position_embeddings"]
|
||||
|
||||
positions = torch.randint(
|
||||
0, max_position // 4, test_case.positions_shape, device=device
|
||||
)
|
||||
query = torch.randn(
|
||||
num_tokens, num_q_heads * head_size, dtype=torch.bfloat16, device=device
|
||||
)
|
||||
key = torch.randn(
|
||||
num_tokens,
|
||||
num_kv_heads * head_size,
|
||||
dtype=torch.bfloat16,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Call the method under test
|
||||
method = getattr(rope, test_case.method_name)
|
||||
method(positions, query.clone(), key.clone())
|
||||
|
||||
# Verify expectations
|
||||
if test_case.expect_forward_native:
|
||||
assert call_tracker["forward_native_called"], (
|
||||
f"{test_case.name} should call ApplyRotaryEmb.forward_native()"
|
||||
)
|
||||
if not test_case.expect_forward:
|
||||
assert not call_tracker["forward_called"], (
|
||||
f"{test_case.name} should NOT call ApplyRotaryEmb.forward(). "
|
||||
"Bug: when +apply_rotary_emb is enabled, forward_native() "
|
||||
"incorrectly dispatches to CUDA/HIP kernels."
|
||||
)
|
||||
if test_case.expect_forward:
|
||||
assert call_tracker["forward_called"], (
|
||||
f"{test_case.name} should call ApplyRotaryEmb.forward()"
|
||||
)
|
||||
finally:
|
||||
apply_rotary_emb.forward_native = original_forward_native
|
||||
apply_rotary_emb.forward = original_forward
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests."
|
||||
)
|
||||
@pytest.mark.parametrize("test_case", get_test_cases(), ids=lambda tc: tc.name)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_rotary_embedding_dispatch(
|
||||
test_case: RotaryEmbeddingTestCase,
|
||||
device: str,
|
||||
):
|
||||
"""
|
||||
Test that RotaryEmbedding classes dispatch to the correct ApplyRotaryEmb method.
|
||||
|
||||
- forward_native methods should call ApplyRotaryEmb.forward_native()
|
||||
- forward_cuda/forward methods should call ApplyRotaryEmb.forward()
|
||||
"""
|
||||
run_dispatch_test(test_case, device)
|
||||
@@ -1,17 +1,23 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Literal, NamedTuple
|
||||
import os
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import pytest
|
||||
from huggingface_hub import hf_hub_download
|
||||
from pytest import MarkDecorator
|
||||
from transformers import AutoModelForImageTextToText
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.multimodal.image import rescale_image_size
|
||||
from vllm.utils.torch_utils import set_default_torch_num_threads
|
||||
|
||||
from ....conftest import PromptImageInput, VllmRunner
|
||||
from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner
|
||||
from ...utils import check_logprobs_close
|
||||
|
||||
|
||||
@@ -21,9 +27,10 @@ class GGUFMMTestConfig(NamedTuple):
|
||||
gguf_backbone: str
|
||||
gguf_mmproj: str
|
||||
prompt: list[str]
|
||||
mm_data: dict[Literal["images"], PromptImageInput]
|
||||
image_names: list[str] # Store names, load PIL images at runtime
|
||||
max_model_len: int = 4096
|
||||
marks: list[MarkDecorator] = []
|
||||
mm_processor_kwargs: dict[str, Any] = {}
|
||||
|
||||
@property
|
||||
def gguf_model(self):
|
||||
@@ -31,27 +38,75 @@ class GGUFMMTestConfig(NamedTuple):
|
||||
return hf_hub_download(self.gguf_repo, filename=self.gguf_backbone)
|
||||
|
||||
|
||||
# Common prompts aligned with test_common.py "gemma3" entry format
|
||||
_GEMMA3_PROMPTS = IMAGE_ASSETS.prompts(
|
||||
{
|
||||
"stop_sign": (
|
||||
"<bos><start_of_turn>user\n"
|
||||
"<start_of_image>What's the content in the center of the image?"
|
||||
"<end_of_turn>\n<start_of_turn>model\n"
|
||||
),
|
||||
"cherry_blossom": (
|
||||
"<bos><start_of_turn>user\n"
|
||||
"<start_of_image>What is the season?"
|
||||
"<end_of_turn>\n<start_of_turn>model\n"
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
# Image asset names - load at runtime to avoid pickle issues with subprocess
|
||||
_GEMMA3_IMAGE_NAMES = ["stop_sign", "cherry_blossom"]
|
||||
|
||||
# Regular multimodal (no pan-and-scan) - uses QAT Q4_0 GGUF
|
||||
GEMMA3_CONFIG = GGUFMMTestConfig(
|
||||
original_model="google/gemma-3-4b-it",
|
||||
gguf_repo="google/gemma-3-4b-it-qat-q4_0-gguf",
|
||||
gguf_backbone="gemma-3-4b-it-q4_0.gguf",
|
||||
gguf_mmproj="mmproj-model-f16-4B.gguf",
|
||||
prompt=["<start_of_image>Describe this image in detail:"],
|
||||
mm_data={"images": [ImageAsset("stop_sign").pil_image]},
|
||||
prompt=_GEMMA3_PROMPTS,
|
||||
image_names=_GEMMA3_IMAGE_NAMES,
|
||||
max_model_len=4096,
|
||||
marks=[pytest.mark.core_model],
|
||||
mm_processor_kwargs={},
|
||||
)
|
||||
|
||||
MODELS_TO_TEST = [GEMMA3_CONFIG]
|
||||
# Pan-and-scan multimodal - uses unquantized BF16 GGUF
|
||||
GEMMA3_CONFIG_PAN_AND_SCAN = GGUFMMTestConfig(
|
||||
original_model="google/gemma-3-4b-it",
|
||||
gguf_repo="unsloth/gemma-3-4b-it-GGUF",
|
||||
gguf_backbone="gemma-3-4b-it-BF16.gguf",
|
||||
gguf_mmproj="mmproj-BF16.gguf",
|
||||
prompt=_GEMMA3_PROMPTS,
|
||||
image_names=_GEMMA3_IMAGE_NAMES,
|
||||
max_model_len=4096,
|
||||
marks=[pytest.mark.core_model],
|
||||
mm_processor_kwargs={"do_pan_and_scan": True},
|
||||
)
|
||||
|
||||
MODELS_TO_TEST = [GEMMA3_CONFIG, GEMMA3_CONFIG_PAN_AND_SCAN]
|
||||
|
||||
|
||||
def run_multimodal_gguf_test(
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
model: GGUFMMTestConfig,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
):
|
||||
# Run gguf model.
|
||||
# Load images at runtime (inside subprocess) to avoid pickle issues
|
||||
images = [ImageAsset(name).pil_image for name in model.image_names]
|
||||
size_factors = [0.25, 0.5, 1.0]
|
||||
inputs_per_image = [
|
||||
(
|
||||
[prompt for _ in size_factors],
|
||||
[rescale_image_size(image, factor) for factor in size_factors],
|
||||
)
|
||||
for image, prompt in zip(images, model.prompt)
|
||||
]
|
||||
|
||||
# NOTE: Run vLLM first to avoid CUDA init issues with multiprocessing fork.
|
||||
# Run GGUF model via vLLM.
|
||||
with (
|
||||
set_default_torch_num_threads(1),
|
||||
vllm_runner(
|
||||
@@ -60,35 +115,42 @@ def run_multimodal_gguf_test(
|
||||
tokenizer_name=model.original_model,
|
||||
dtype=dtype,
|
||||
max_model_len=model.max_model_len,
|
||||
mm_processor_kwargs=model.mm_processor_kwargs,
|
||||
) as gguf_model,
|
||||
):
|
||||
gguf_outputs = gguf_model.generate_greedy_logprobs(
|
||||
prompts=model.prompt,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
**model.mm_data,
|
||||
)
|
||||
gguf_outputs_per_case = [
|
||||
gguf_model.generate_greedy_logprobs(
|
||||
prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=images,
|
||||
)
|
||||
for prompts, images in inputs_per_image
|
||||
]
|
||||
|
||||
# Run unquantized model.
|
||||
with vllm_runner(
|
||||
model_name=model.original_model,
|
||||
enforce_eager=True, # faster tests
|
||||
# Then run HfRunner for HuggingFace baseline comparison.
|
||||
with hf_runner(
|
||||
model.original_model,
|
||||
dtype=dtype,
|
||||
max_model_len=model.max_model_len,
|
||||
) as original_model:
|
||||
original_outputs = original_model.generate_greedy_logprobs(
|
||||
prompts=model.prompt,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
**model.mm_data,
|
||||
)
|
||||
auto_cls=AutoModelForImageTextToText,
|
||||
) as hf_model:
|
||||
hf_outputs_per_case = [
|
||||
hf_model.generate_greedy_logprobs_limit(
|
||||
prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=images,
|
||||
)
|
||||
for prompts, images in inputs_per_image
|
||||
]
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=original_outputs,
|
||||
outputs_1_lst=gguf_outputs,
|
||||
name_0="original",
|
||||
name_1="gguf",
|
||||
)
|
||||
for hf_outputs, gguf_outputs in zip(hf_outputs_per_case, gguf_outputs_per_case):
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=gguf_outputs,
|
||||
name_0="hf",
|
||||
name_1="gguf",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
@@ -105,11 +167,14 @@ def run_multimodal_gguf_test(
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [32])
|
||||
@pytest.mark.parametrize("num_logprobs", [10])
|
||||
def test_models(
|
||||
def test_gemma3_mm_gguf(
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
model: GGUFMMTestConfig,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
run_multimodal_gguf_test(vllm_runner, model, dtype, max_tokens, num_logprobs)
|
||||
run_multimodal_gguf_test(
|
||||
hf_runner, vllm_runner, model, dtype, max_tokens, num_logprobs
|
||||
)
|
||||
|
||||
@@ -388,6 +388,7 @@ def run_video_test(config, mm_encoder_attn_backend, video_assets, vllm_runner):
|
||||
"mm_encoder_attn_backend",
|
||||
[None] + current_platform.get_supported_vit_attn_backends(),
|
||||
)
|
||||
@pytest.mark.skip(reason="Broken test due to memory segmentation fault")
|
||||
@create_new_process_for_each_test()
|
||||
def test_vit_backend_functionality(
|
||||
model_key: str,
|
||||
|
||||
@@ -60,12 +60,12 @@ def test_profiling(model_id: str, max_model_len: int):
|
||||
total_num_patches.item() + num_tiles.item() + 3
|
||||
) # image start, image, image end
|
||||
|
||||
profiled_tokens = profiler.get_mm_max_contiguous_tokens(
|
||||
profiled_tokens = profiler.get_mm_max_tokens(
|
||||
max_model_len,
|
||||
mm_counts=mm_counts,
|
||||
)
|
||||
|
||||
assert total_tokens == profiled_tokens["image"]
|
||||
assert total_num_patches == profiled_tokens["image"]
|
||||
assert total_tokens == sum(
|
||||
placeholder.length
|
||||
for placeholder in decoder_dummy_data.multi_modal_placeholders["image"]
|
||||
|
||||
@@ -9,6 +9,7 @@ from tempfile import NamedTemporaryFile, TemporaryDirectory
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from PIL import Image, ImageChops
|
||||
|
||||
from vllm.multimodal.image import convert_image_mode
|
||||
@@ -410,6 +411,97 @@ def test_argsort_mm_positions(case):
|
||||
assert modality_idxs == expected_modality_idxs
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"is_embed,expected",
|
||||
[
|
||||
(None, 5),
|
||||
(torch.tensor([True, True, True, True, True]), 5),
|
||||
(torch.tensor([False, False, False, False, False]), 0),
|
||||
(torch.tensor([True, False, True, False, True]), 3),
|
||||
(torch.tensor([True]), 1),
|
||||
],
|
||||
)
|
||||
def test_placeholder_range_get_num_embeds(is_embed, expected):
|
||||
length = len(is_embed) if is_embed is not None else 5
|
||||
pr = PlaceholderRange(offset=0, length=length, is_embed=is_embed)
|
||||
assert pr.get_num_embeds == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"is_embed,expected",
|
||||
[
|
||||
(None, None),
|
||||
(
|
||||
torch.tensor([False, True, False, True, True]),
|
||||
torch.tensor([0, 1, 1, 2, 3]),
|
||||
),
|
||||
(torch.tensor([True, True, True]), torch.tensor([1, 2, 3])),
|
||||
],
|
||||
)
|
||||
def test_placeholder_range_embeds_cumsum(is_embed, expected):
|
||||
length = len(is_embed) if is_embed is not None else 5
|
||||
pr = PlaceholderRange(offset=0, length=length, is_embed=is_embed)
|
||||
|
||||
if expected is None:
|
||||
assert pr.embeds_cumsum is None
|
||||
return
|
||||
|
||||
assert torch.equal(pr.embeds_cumsum, expected)
|
||||
# cached_property should return the same object on repeated access
|
||||
assert pr.embeds_cumsum is pr.embeds_cumsum
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"is_embed,start_idx,end_idx,expected",
|
||||
[
|
||||
(None, 2, 4, (2, 4)),
|
||||
(
|
||||
torch.tensor([False, True, False, True, True]),
|
||||
3,
|
||||
5,
|
||||
(1, 3),
|
||||
),
|
||||
(
|
||||
torch.tensor([False, True, False, True, True]),
|
||||
0,
|
||||
2,
|
||||
(0, 1),
|
||||
),
|
||||
(
|
||||
torch.tensor([True, False, True, False]),
|
||||
2,
|
||||
2,
|
||||
(1, 1),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_placeholder_range_get_embeds_indices_in_range(
|
||||
is_embed, start_idx, end_idx, expected
|
||||
):
|
||||
length = len(is_embed) if is_embed is not None else 5
|
||||
pr = PlaceholderRange(offset=0, length=length, is_embed=is_embed)
|
||||
assert pr.get_embeds_indices_in_range(start_idx, end_idx) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"offset,is_embed,expected",
|
||||
[
|
||||
(0, None, [(0, 4)]),
|
||||
(
|
||||
2,
|
||||
torch.tensor([False, True, False, True, True]),
|
||||
[(3, 3), (5, 6)],
|
||||
),
|
||||
(0, torch.tensor([True, True, True, True]), [(0, 3)]),
|
||||
(0, torch.tensor([False, False, False, False]), []),
|
||||
],
|
||||
)
|
||||
def test_placeholder_range_extract_embeds_range(offset, is_embed, expected):
|
||||
length = len(is_embed) if is_embed is not None else 5
|
||||
pr = PlaceholderRange(offset=offset, length=length, is_embed=is_embed)
|
||||
assert pr.extract_embeds_range() == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
|
||||
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange
|
||||
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
|
||||
@@ -23,7 +24,7 @@ class MockRequest:
|
||||
)
|
||||
self.mm_features.append(feature)
|
||||
|
||||
def get_num_encoder_tokens(self, input_id: int) -> int:
|
||||
def get_num_encoder_embeds(self, input_id: int) -> int:
|
||||
return self._token_counts[input_id]
|
||||
|
||||
|
||||
@@ -162,8 +163,8 @@ def test_schedule_request_multi_images_respect_space_limit():
|
||||
|
||||
num_tokens_to_schedule = 0
|
||||
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
|
||||
num_tokens_to_schedule += req.get_num_encoder_tokens(0)
|
||||
compute_budget -= req.get_num_encoder_tokens(0)
|
||||
num_tokens_to_schedule += req.get_num_encoder_embeds(0)
|
||||
compute_budget -= req.get_num_encoder_embeds(0)
|
||||
|
||||
assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule)
|
||||
|
||||
@@ -174,7 +175,75 @@ def test_schedule_request_multi_images_respect_compute_limit():
|
||||
compute_budget = 10
|
||||
num_tokens_to_schedule = 0
|
||||
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
|
||||
num_tokens_to_schedule += req.get_num_encoder_tokens(0)
|
||||
compute_budget -= req.get_num_encoder_tokens(0)
|
||||
num_tokens_to_schedule += req.get_num_encoder_embeds(0)
|
||||
compute_budget -= req.get_num_encoder_embeds(0)
|
||||
|
||||
assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule)
|
||||
|
||||
|
||||
def test_encoder_cache_with_is_embed_mask():
|
||||
class MockRequestWithMask(MockRequest):
|
||||
def get_num_encoder_embeds(self, input_id: int) -> int:
|
||||
return self.mm_features[input_id].mm_position.get_num_embeds
|
||||
|
||||
is_embed = torch.zeros(100, dtype=torch.bool)
|
||||
is_embed[torch.tensor([5, 15, 25, 35, 45, 55, 65, 75])] = True
|
||||
|
||||
request = MockRequestWithMask("r1", ["img1"], [100])
|
||||
request.mm_features[0] = MultiModalFeatureSpec(
|
||||
data=None,
|
||||
modality="image",
|
||||
identifier="img1",
|
||||
mm_position=PlaceholderRange(offset=0, length=100, is_embed=is_embed),
|
||||
)
|
||||
|
||||
manager = EncoderCacheManager(cache_size=100)
|
||||
manager.allocate(request, 0)
|
||||
|
||||
assert manager.num_free_slots == 92
|
||||
assert "img1" in manager.cached
|
||||
|
||||
old_size = 100
|
||||
new_size = request.mm_features[0].mm_position.get_num_embeds
|
||||
assert new_size == 8
|
||||
savings_ratio = old_size / new_size
|
||||
assert savings_ratio == 12.5
|
||||
|
||||
|
||||
def test_encoder_cache_mask_based_retrieval():
|
||||
class MockRequestWithMask(MockRequest):
|
||||
def get_num_encoder_embeds(self, input_id: int) -> int:
|
||||
return self.mm_features[input_id].mm_position.get_num_embeds
|
||||
|
||||
is_embed = torch.tensor(
|
||||
[False, False, True, True, False, True, True, True, False, False]
|
||||
)
|
||||
|
||||
request = MockRequestWithMask("r1", ["img1"], [10])
|
||||
request.mm_features[0] = MultiModalFeatureSpec(
|
||||
data=None,
|
||||
modality="image",
|
||||
identifier="img1",
|
||||
mm_position=PlaceholderRange(offset=0, length=10, is_embed=is_embed),
|
||||
)
|
||||
|
||||
manager = EncoderCacheManager(cache_size=50)
|
||||
manager.allocate(request, 0)
|
||||
|
||||
assert request.mm_features[0].mm_position.get_num_embeds == 5
|
||||
|
||||
start_idx = 2
|
||||
end_idx = 8
|
||||
num_embeds_before = is_embed[:start_idx].sum().item()
|
||||
num_embeds_in_range = is_embed[start_idx:end_idx].sum().item()
|
||||
|
||||
assert num_embeds_before == 0
|
||||
assert num_embeds_in_range == 5
|
||||
|
||||
start_idx = 0
|
||||
end_idx = 5
|
||||
num_embeds_before = is_embed[:start_idx].sum().item() if start_idx > 0 else 0
|
||||
num_embeds_in_range = is_embed[start_idx:end_idx].sum().item()
|
||||
|
||||
assert num_embeds_before == 0
|
||||
assert num_embeds_in_range == 2
|
||||
|
||||
@@ -38,7 +38,7 @@ class MockRequest:
|
||||
)
|
||||
self.mm_features.append(feature)
|
||||
|
||||
def get_num_encoder_tokens(self, input_id: int) -> int:
|
||||
def get_num_encoder_embeds(self, input_id: int) -> int:
|
||||
assert input_id < len(self._token_counts)
|
||||
return self._token_counts[input_id]
|
||||
|
||||
|
||||
@@ -86,6 +86,9 @@ def kernel_unified_attention_2d(
|
||||
USE_SOFTCAP: tl.constexpr, # bool
|
||||
USE_SINKS: tl.constexpr, # bool
|
||||
SLIDING_WINDOW: tl.constexpr, # int
|
||||
USE_MM_PREFIX: tl.constexpr, # bool
|
||||
MAX_MM_RANGES: tl.constexpr, # int
|
||||
mm_prefix_range_ptr, # [num_seqs] - prefix length for each sequence
|
||||
stride_k_cache_0: tl.int64, # int
|
||||
stride_k_cache_1: tl.int64, # int
|
||||
stride_k_cache_2: tl.int64, # int
|
||||
@@ -270,7 +273,38 @@ def kernel_unified_attention_2d(
|
||||
else:
|
||||
V = V_load
|
||||
|
||||
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
|
||||
# Compute attention mask: causal by default (key <= query)
|
||||
query_abs_pos = context_len + query_pos[:, None]
|
||||
seq_mask = seq_offset[None, :] <= query_abs_pos
|
||||
|
||||
# Apply sliding window to base mask BEFORE mm_prefix OR.
|
||||
# Order must match FlexAttention: (causal AND sliding_window) OR mm_prefix
|
||||
if SLIDING_WINDOW > 0:
|
||||
seq_mask = seq_mask & ((query_abs_pos - seq_offset) < SLIDING_WINDOW)
|
||||
|
||||
# PrefixLM: extend mask with bidirectional ranges for multimodal tokens.
|
||||
# Applied AFTER sliding window so mm_prefix ranges override SW restriction.
|
||||
if USE_MM_PREFIX:
|
||||
for i in range(MAX_MM_RANGES):
|
||||
range_start = tl.load(
|
||||
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2
|
||||
)
|
||||
range_end = tl.load(
|
||||
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2 + 1
|
||||
)
|
||||
|
||||
is_valid = range_start < range_end
|
||||
q_in_range = (
|
||||
(query_abs_pos >= range_start)
|
||||
& (query_abs_pos <= range_end)
|
||||
& is_valid
|
||||
)
|
||||
k_in_range = (
|
||||
(seq_offset[None, :] >= range_start)
|
||||
& (seq_offset[None, :] <= range_end)
|
||||
& is_valid
|
||||
)
|
||||
seq_mask |= q_in_range & k_in_range
|
||||
|
||||
# S : (BLOCK_M, TILE_SIZE)
|
||||
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
|
||||
@@ -284,13 +318,6 @@ def kernel_unified_attention_2d(
|
||||
query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
|
||||
)
|
||||
|
||||
if SLIDING_WINDOW > 0:
|
||||
S = tl.where(
|
||||
(context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW,
|
||||
S,
|
||||
float("-inf"),
|
||||
)
|
||||
|
||||
if USE_ALIBI_SLOPES:
|
||||
S += alibi_slope[:, None] * (seq_offset - context_len)
|
||||
|
||||
@@ -398,6 +425,9 @@ def kernel_unified_attention_3d(
|
||||
num_seqs: tl.int32,
|
||||
BLOCK_M: tl.constexpr, # int
|
||||
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int
|
||||
USE_MM_PREFIX: tl.constexpr, # bool
|
||||
MAX_MM_RANGES: tl.constexpr, # int
|
||||
mm_prefix_range_ptr, # [num_seqs] - prefix length for each sequence
|
||||
):
|
||||
q_block_global_idx = tl.program_id(0)
|
||||
kv_head_idx = tl.program_id(1)
|
||||
@@ -559,7 +589,38 @@ def kernel_unified_attention_3d(
|
||||
else:
|
||||
V = V_load
|
||||
|
||||
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
|
||||
# Compute attention mask: causal by default (key <= query)
|
||||
query_abs_pos = context_len + query_pos[:, None]
|
||||
seq_mask = seq_offset[None, :] <= query_abs_pos
|
||||
|
||||
# Apply sliding window to base mask BEFORE mm_prefix OR.
|
||||
# Order must match FlexAttention: (causal AND sliding_window) OR mm_prefix
|
||||
if SLIDING_WINDOW > 0:
|
||||
seq_mask = seq_mask & ((query_abs_pos - seq_offset) < SLIDING_WINDOW)
|
||||
|
||||
# PrefixLM: extend mask with bidirectional ranges for multimodal tokens.
|
||||
# Applied AFTER sliding window so mm_prefix ranges override SW restriction.
|
||||
if USE_MM_PREFIX:
|
||||
for i in range(MAX_MM_RANGES):
|
||||
range_start = tl.load(
|
||||
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2
|
||||
)
|
||||
range_end = tl.load(
|
||||
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2 + 1
|
||||
)
|
||||
|
||||
is_valid = range_start < range_end
|
||||
q_in_range = (
|
||||
(query_abs_pos >= range_start)
|
||||
& (query_abs_pos <= range_end)
|
||||
& is_valid
|
||||
)
|
||||
k_in_range = (
|
||||
(seq_offset[None, :] >= range_start)
|
||||
& (seq_offset[None, :] <= range_end)
|
||||
& is_valid
|
||||
)
|
||||
seq_mask |= q_in_range & k_in_range
|
||||
|
||||
# S : (BLOCK_M, TILE_SIZE)
|
||||
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
|
||||
@@ -572,13 +633,6 @@ def kernel_unified_attention_3d(
|
||||
query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
|
||||
)
|
||||
|
||||
if SLIDING_WINDOW > 0:
|
||||
S = tl.where(
|
||||
(context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW,
|
||||
S,
|
||||
float("-inf"),
|
||||
)
|
||||
|
||||
if USE_ALIBI_SLOPES:
|
||||
S += alibi_slope[:, None] * (seq_offset - context_len)
|
||||
|
||||
@@ -732,6 +786,43 @@ def reduce_segments(
|
||||
tl.store(output_ptr + output_offset, acc, mask=dim_mask)
|
||||
|
||||
|
||||
def _is_gemma3_attention(head_size: int, sliding_window: int) -> bool:
|
||||
"""Detect Gemma3 models via unique (head_size, sliding_window) signature.
|
||||
|
||||
Gemma3 models are the only ones using sliding_window=1024 with
|
||||
head_size 128 (27B) or 256 (1B, 4B, 12B). Other SWA models use
|
||||
different window sizes (Mistral=4096, Phi-3=2047).
|
||||
"""
|
||||
return sliding_window == 1024 and head_size in (128, 256)
|
||||
|
||||
|
||||
def _get_tile_size(
|
||||
head_size: int,
|
||||
sliding_window: int,
|
||||
element_size: int,
|
||||
is_mm_prefix: bool,
|
||||
is_prefill: bool,
|
||||
) -> int:
|
||||
"""Select tile size with Gemma3-specific optimization.
|
||||
|
||||
For Gemma3, use 32 for both prefill and decode to better utilize
|
||||
the larger head dimension (128/256). For other models, use
|
||||
the default vLLM behavior.
|
||||
"""
|
||||
if is_mm_prefix:
|
||||
# Multimodal bidirectional attention needs a larger tile size
|
||||
return 64
|
||||
|
||||
if _is_gemma3_attention(head_size, sliding_window):
|
||||
# Gemma3: use 32 for decode (default is 16)
|
||||
return 32
|
||||
|
||||
# Default behavior
|
||||
if is_prefill:
|
||||
return 32
|
||||
return 16 if element_size >= 2 else 32
|
||||
|
||||
|
||||
def unified_attention(
|
||||
q,
|
||||
k,
|
||||
@@ -759,6 +850,8 @@ def unified_attention(
|
||||
qq_bias=None,
|
||||
# Optional tensor for sinks
|
||||
sinks=None,
|
||||
# Optional tensor for prefix lengths (PrefixLM support)
|
||||
mm_prefix_range=None,
|
||||
):
|
||||
assert causal, "Only causal attention is supported"
|
||||
assert q_descale is None, "Q scales not supported"
|
||||
@@ -766,6 +859,17 @@ def unified_attention(
|
||||
if sinks is not None:
|
||||
assert sinks.shape[0] == q.shape[1], "Sinks must be num_query_heads size"
|
||||
|
||||
use_mm_prefix = False
|
||||
max_mm_ranges = 0
|
||||
if mm_prefix_range is not None:
|
||||
if mm_prefix_range.ndim == 3:
|
||||
use_mm_prefix = True
|
||||
max_mm_ranges = mm_prefix_range.shape[1]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported mm_prefix_range shape: {mm_prefix_range.shape}"
|
||||
)
|
||||
|
||||
use_alibi_slopes = alibi_slopes is not None
|
||||
use_qq_bias = qq_bias is not None
|
||||
|
||||
@@ -792,11 +896,23 @@ def unified_attention(
|
||||
# = floor(q.shape[0] / BLOCK_Q) + num_seqs
|
||||
total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
|
||||
|
||||
# Assigning default tile sizes for prefill and decode.
|
||||
# Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1)
|
||||
# and at least 16 for all other data types.
|
||||
TILE_SIZE_PREFILL = 32
|
||||
TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32
|
||||
# Tile sizes for prefill and decode. Gemma3 models use optimized values.
|
||||
# Note: tile size must be at least 32 for fp8 (element_size == 1).
|
||||
sliding_window_val = 1 + window_size[0] if window_size[0] >= 0 else 0
|
||||
TILE_SIZE_PREFILL = _get_tile_size(
|
||||
head_size,
|
||||
sliding_window_val,
|
||||
q.element_size(),
|
||||
is_mm_prefix=use_mm_prefix,
|
||||
is_prefill=True,
|
||||
)
|
||||
TILE_SIZE_DECODE = _get_tile_size(
|
||||
head_size,
|
||||
sliding_window_val,
|
||||
q.element_size(),
|
||||
is_mm_prefix=use_mm_prefix,
|
||||
is_prefill=False,
|
||||
)
|
||||
|
||||
# Launch the 2D kernel if
|
||||
# 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or
|
||||
@@ -847,6 +963,9 @@ def unified_attention(
|
||||
USE_QQ_BIAS=use_qq_bias,
|
||||
USE_SOFTCAP=(softcap > 0),
|
||||
USE_SINKS=(sinks is not None),
|
||||
USE_MM_PREFIX=use_mm_prefix,
|
||||
MAX_MM_RANGES=max_mm_ranges,
|
||||
mm_prefix_range_ptr=mm_prefix_range,
|
||||
SLIDING_WINDOW=(1 + window_size[0]),
|
||||
stride_k_cache_0=k.stride(0),
|
||||
stride_k_cache_1=k.stride(1),
|
||||
@@ -895,6 +1014,9 @@ def unified_attention(
|
||||
USE_QQ_BIAS=use_qq_bias,
|
||||
USE_SOFTCAP=(softcap > 0),
|
||||
USE_SINKS=(sinks is not None),
|
||||
USE_MM_PREFIX=use_mm_prefix,
|
||||
MAX_MM_RANGES=max_mm_ranges,
|
||||
mm_prefix_range_ptr=mm_prefix_range,
|
||||
SLIDING_WINDOW=(1 + window_size[0]),
|
||||
stride_k_cache_0=k.stride(0),
|
||||
stride_k_cache_1=k.stride(1),
|
||||
|
||||
@@ -16,6 +16,7 @@ import einops
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
|
||||
@@ -89,6 +90,13 @@ def torch_sdpa_wrapper(
|
||||
v: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# Never remove the contiguous logic for ROCm
|
||||
# Without it, hallucinations occur with the backend
|
||||
if current_platform.is_rocm():
|
||||
q = q.contiguous()
|
||||
k = k.contiguous()
|
||||
v = v.contiguous()
|
||||
|
||||
outputs = []
|
||||
|
||||
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
|
||||
@@ -144,7 +144,7 @@ class ECExampleConnector(ECConnectorBase):
|
||||
Update ECConnector state after encoder cache allocation.
|
||||
"""
|
||||
mm_hash = request.mm_features[index].identifier
|
||||
num_encoder_token = request.get_num_encoder_tokens(index)
|
||||
num_encoder_token = request.get_num_encoder_embeds(index)
|
||||
# Insert mm_hash only if this block has not been recorded yet.
|
||||
self._mm_datas_need_loads[mm_hash] = num_encoder_token
|
||||
|
||||
|
||||
@@ -795,7 +795,10 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
top_k,
|
||||
global_num_experts,
|
||||
local_num_experts,
|
||||
expert_tokens_meta,
|
||||
# expert_tokens_meta help in allocating optimal/minimal
|
||||
# amount of workspace. Mark it None, so we allocate for
|
||||
# the worst-case scenario.
|
||||
expert_tokens_meta=None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -27,6 +27,10 @@ from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
|
||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8LinearMethod
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
maybe_create_device_identity,
|
||||
)
|
||||
from vllm.model_executor.parameter import ModelWeightParameter
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -305,6 +309,37 @@ class XPUFp8LinearMethod(Fp8LinearMethod):
|
||||
def __init__(self, quant_config: Fp8Config):
|
||||
super().__init__(quant_config)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
maybe_create_device_identity()
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.orig_dtype = params_dtype
|
||||
layer.weight_block_size = None
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
# If checkpoint not serialized fp8, quantize the weights.
|
||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||
|
||||
@@ -7,7 +7,7 @@ import torch
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
|
||||
from .common import apply_rotary_emb_torch
|
||||
from .common import ApplyRotaryEmb
|
||||
|
||||
|
||||
@CustomOp.register("rotary_embedding")
|
||||
@@ -49,6 +49,10 @@ class RotaryEmbeddingBase(CustomOp):
|
||||
rocm_aiter_ops.is_triton_rotary_embed_enabled()
|
||||
)
|
||||
|
||||
self.apply_rotary_emb = ApplyRotaryEmb(
|
||||
is_neox_style=self.is_neox_style,
|
||||
)
|
||||
|
||||
def _compute_inv_freq(self, base: float) -> torch.Tensor:
|
||||
"""Compute the inverse frequency."""
|
||||
# NOTE(woosuk): To exactly match the HF implementation, we need to
|
||||
@@ -123,7 +127,12 @@ class RotaryEmbedding(RotaryEmbeddingBase):
|
||||
query = query.view(num_tokens, -1, head_size)
|
||||
query_rot = query[..., :rotary_dim]
|
||||
query_pass = query[..., rotary_dim:]
|
||||
query_rot = apply_rotary_emb_torch(query_rot, cos, sin, is_neox_style)
|
||||
query_rot = ApplyRotaryEmb.forward_static(
|
||||
query_rot,
|
||||
cos,
|
||||
sin,
|
||||
is_neox_style,
|
||||
)
|
||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||
|
||||
# key may be None in some cases, e.g. cross-layer KV sharing
|
||||
@@ -132,7 +141,12 @@ class RotaryEmbedding(RotaryEmbeddingBase):
|
||||
key = key.view(num_tokens, -1, head_size)
|
||||
key_rot = key[..., :rotary_dim]
|
||||
key_pass = key[..., rotary_dim:]
|
||||
key_rot = apply_rotary_emb_torch(key_rot, cos, sin, is_neox_style)
|
||||
key_rot = ApplyRotaryEmb.forward_static(
|
||||
key_rot,
|
||||
cos,
|
||||
sin,
|
||||
is_neox_style,
|
||||
)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||
return query, key
|
||||
|
||||
|
||||
@@ -2,19 +2,14 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
from collections.abc import Callable
|
||||
from functools import cache
|
||||
from importlib.util import find_spec
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -32,71 +27,6 @@ def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
|
||||
return x.flatten(-2)
|
||||
|
||||
|
||||
def apply_rotary_emb_torch(
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
is_neox_style: bool,
|
||||
) -> torch.Tensor:
|
||||
cos = cos.unsqueeze(-2).to(x.dtype)
|
||||
sin = sin.unsqueeze(-2).to(x.dtype)
|
||||
if is_neox_style:
|
||||
x1, x2 = torch.chunk(x, 2, dim=-1)
|
||||
else:
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
o1 = x1 * cos - x2 * sin
|
||||
o2 = x2 * cos + x1 * sin
|
||||
if is_neox_style:
|
||||
return torch.cat((o1, o2), dim=-1)
|
||||
else:
|
||||
return torch.stack((o1, o2), dim=-1).flatten(-2)
|
||||
|
||||
|
||||
def apply_rotary_emb_dispatch(
|
||||
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, is_neox_style: bool
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: [num_tokens, num_heads, head_size]
|
||||
cos: [num_tokens, head_size // 2]
|
||||
sin: [num_tokens, head_size // 2]
|
||||
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
|
||||
positional embeddings.
|
||||
"""
|
||||
if current_platform.is_cuda():
|
||||
return apply_rotary_emb(x.unsqueeze(0), cos, sin, not is_neox_style).squeeze(0)
|
||||
else:
|
||||
return apply_rotary_emb_torch(x, cos, sin, is_neox_style)
|
||||
|
||||
|
||||
@cache
|
||||
def dispatch_rotary_emb_function(
|
||||
default: Callable[..., torch.Tensor] | None = None,
|
||||
) -> Callable[..., torch.Tensor]:
|
||||
if current_platform.is_cuda():
|
||||
return apply_rotary_emb
|
||||
|
||||
# if torch compile is not enabled
|
||||
# use rotary embedding function from flash_attn package
|
||||
# otherwise use the naive pytorch embedding implementation
|
||||
# is faster when torch compile is enabled.
|
||||
if current_platform.is_rocm() and not torch.compiler.is_compiling():
|
||||
if find_spec("flash_attn") is not None:
|
||||
from flash_attn.ops.triton.rotary import apply_rotary
|
||||
|
||||
return apply_rotary
|
||||
else:
|
||||
logger.warning(
|
||||
"flash_attn is not installed. Falling back to PyTorch "
|
||||
"implementation for rotary embeddings."
|
||||
)
|
||||
if default is not None:
|
||||
return default
|
||||
|
||||
return apply_rotary_emb_torch
|
||||
|
||||
|
||||
# yarn functions
|
||||
# Inverse dim formula to find dim based on number of rotations
|
||||
def yarn_find_correction_dim(
|
||||
@@ -186,3 +116,164 @@ direct_register_custom_op(
|
||||
mutates_args=["query", "key"], # These tensors are modified in-place
|
||||
fake_impl=_flashinfer_rotary_embedding_fake,
|
||||
)
|
||||
|
||||
|
||||
@CustomOp.register("apply_rotary_emb")
|
||||
class ApplyRotaryEmb(CustomOp):
|
||||
def __init__(
|
||||
self,
|
||||
enforce_enable: bool = False,
|
||||
is_neox_style: bool = True,
|
||||
enable_fp32_compute: bool = False,
|
||||
) -> None:
|
||||
super().__init__(enforce_enable)
|
||||
self.is_neox_style = is_neox_style
|
||||
self.enable_fp32_compute = enable_fp32_compute
|
||||
|
||||
self.apply_rotary_emb_flash_attn = None
|
||||
if find_spec("flash_attn") is not None:
|
||||
from flash_attn.ops.triton.rotary import apply_rotary
|
||||
|
||||
self.apply_rotary_emb_flash_attn = apply_rotary
|
||||
|
||||
@staticmethod
|
||||
def forward_static(
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
is_neox_style: bool = True,
|
||||
enable_fp32_compute: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: [batch_size (optional), seq_len, num_heads, head_size]
|
||||
cos: [seq_len, head_size // 2]
|
||||
sin: [seq_len, head_size // 2]
|
||||
is_neox_style: Whether to use the Neox-style or GPT-J-style.
|
||||
enable_fp32_compute: Temporarily convert x, cos, sin to FP32 dtype
|
||||
for higher accuracy.
|
||||
"""
|
||||
origin_dtype = x.dtype
|
||||
if enable_fp32_compute:
|
||||
x = x.float()
|
||||
|
||||
cos = cos.unsqueeze(-2).to(x.dtype)
|
||||
sin = sin.unsqueeze(-2).to(x.dtype)
|
||||
|
||||
if is_neox_style:
|
||||
x1, x2 = torch.chunk(x, 2, dim=-1)
|
||||
else:
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
|
||||
o1 = x1 * cos - x2 * sin
|
||||
o2 = x2 * cos + x1 * sin
|
||||
|
||||
if is_neox_style:
|
||||
output = torch.cat((o1, o2), dim=-1)
|
||||
else:
|
||||
output = torch.stack((o1, o2), dim=-1).flatten(-2)
|
||||
|
||||
if enable_fp32_compute:
|
||||
output = output.to(origin_dtype)
|
||||
return output
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
output = self.forward_static(
|
||||
x, cos, sin, self.is_neox_style, self.enable_fp32_compute
|
||||
)
|
||||
return output
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
||||
|
||||
origin_dtype = x.dtype
|
||||
if self.enable_fp32_compute:
|
||||
x = x.float()
|
||||
cos = cos.float()
|
||||
sin = sin.float()
|
||||
|
||||
origin_shape = x.shape
|
||||
if len(origin_shape) == 3:
|
||||
# x: [seq_len, num_heads, head_size]
|
||||
x = x.unsqueeze(0)
|
||||
|
||||
"""
|
||||
Arguments of apply_rotary_emb() in vllm_flash_attn:
|
||||
x: [batch_size, seq_len, nheads, headdim]
|
||||
cos, sin: [seqlen_rotary, rotary_dim / 2]
|
||||
interleaved: defalut as False (Neox-style).
|
||||
...
|
||||
"""
|
||||
interleaved = not self.is_neox_style
|
||||
output = apply_rotary_emb(x, cos, sin, interleaved)
|
||||
|
||||
if len(origin_shape) == 3:
|
||||
output = output.squeeze(0)
|
||||
if self.enable_fp32_compute:
|
||||
output = output.to(origin_dtype)
|
||||
return output
|
||||
|
||||
def forward_hip(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
if self.apply_rotary_emb_flash_attn is not None:
|
||||
origin_dtype = x.dtype
|
||||
if self.enable_fp32_compute:
|
||||
x = x.float()
|
||||
cos = cos.float()
|
||||
sin = sin.float()
|
||||
|
||||
origin_shape = x.shape
|
||||
if len(origin_shape) == 3:
|
||||
# x: [seq_len, num_heads, head_size]
|
||||
x = x.unsqueeze(0)
|
||||
|
||||
"""
|
||||
Arguments of apply_rotary() in flash_attn:
|
||||
x: [batch_size, seq_len, nheads, headdim]
|
||||
cos, sin: [seqlen_rotary, rotary_dim / 2]
|
||||
interleaved: defalut as False (Neox-style).
|
||||
...
|
||||
"""
|
||||
interleaved = not self.is_neox_style
|
||||
output = self.apply_rotary_emb_flash_attn(
|
||||
x, cos, sin, interleaved=interleaved
|
||||
).type_as(x)
|
||||
|
||||
if len(origin_shape) == 3:
|
||||
output = output.squeeze(0)
|
||||
if self.enable_fp32_compute:
|
||||
output = output.to(origin_dtype)
|
||||
else:
|
||||
# Falling back to PyTorch native implementation.
|
||||
output = self.forward_native(x, cos, sin)
|
||||
|
||||
return output
|
||||
|
||||
def forward_cpu(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# TODO (bigPYJ1151): need to enable fused CPU ROPE here
|
||||
return self.forward_native(x, cos, sin)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
s = f"is_neox_style={self.is_neox_style}"
|
||||
s += f"enable_fp32_compute={self.enable_fp32_compute}"
|
||||
return s
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
|
||||
import torch
|
||||
|
||||
from .common import apply_rotary_emb_dispatch
|
||||
from .mrope import MRotaryEmbedding
|
||||
|
||||
|
||||
@@ -55,14 +54,22 @@ class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
|
||||
query = query.view(num_tokens, -1, self.head_size)
|
||||
query_rot = query[..., : self.rotary_dim]
|
||||
query_pass = query[..., self.rotary_dim :]
|
||||
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
|
||||
query_rot = self.apply_rotary_emb.forward_native(
|
||||
query_rot,
|
||||
cos,
|
||||
sin,
|
||||
)
|
||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||
|
||||
key_shape = key.shape
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
key_rot = key[..., : self.rotary_dim]
|
||||
key_pass = key[..., self.rotary_dim :]
|
||||
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
|
||||
key_rot = self.apply_rotary_emb.forward_native(
|
||||
key_rot,
|
||||
cos,
|
||||
sin,
|
||||
)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||
return query, key
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .base import RotaryEmbeddingBase
|
||||
from .common import apply_rotary_emb_dispatch
|
||||
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale
|
||||
|
||||
|
||||
@@ -301,14 +300,22 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
|
||||
query = query.view(num_tokens, -1, self.head_size)
|
||||
query_rot = query[..., : self.rotary_dim]
|
||||
query_pass = query[..., self.rotary_dim :]
|
||||
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
|
||||
query_rot = self.apply_rotary_emb.forward_native(
|
||||
query_rot,
|
||||
cos,
|
||||
sin,
|
||||
)
|
||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||
|
||||
key_shape = key.shape
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
key_rot = key[..., : self.rotary_dim]
|
||||
key_pass = key[..., self.rotary_dim :]
|
||||
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
|
||||
key_rot = self.apply_rotary_emb.forward_native(
|
||||
key_rot,
|
||||
cos,
|
||||
sin,
|
||||
)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||
return query, key
|
||||
|
||||
@@ -347,13 +354,21 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
|
||||
query = query.view(num_tokens, -1, self.head_size)
|
||||
query_rot = query[..., : self.rotary_dim]
|
||||
query_pass = query[..., self.rotary_dim :]
|
||||
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
|
||||
query_rot = self.apply_rotary_emb(
|
||||
query_rot,
|
||||
cos,
|
||||
sin,
|
||||
)
|
||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
key_rot = key[..., : self.rotary_dim]
|
||||
key_pass = key[..., self.rotary_dim :]
|
||||
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
|
||||
key_rot = self.apply_rotary_emb(
|
||||
key_rot,
|
||||
cos,
|
||||
sin,
|
||||
)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||
return query, key
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .common import apply_rotary_emb_dispatch
|
||||
from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
|
||||
|
||||
|
||||
@@ -36,7 +35,7 @@ class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding):
|
||||
dtype,
|
||||
)
|
||||
|
||||
def forward(
|
||||
def forward_native(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
@@ -68,14 +67,73 @@ class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding):
|
||||
query = query.view(num_tokens, -1, self.head_size)
|
||||
query_rot = query[..., : self.rotary_dim]
|
||||
query_pass = query[..., self.rotary_dim :]
|
||||
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
|
||||
query_rot = self.apply_rotary_emb.forward_native(
|
||||
query_rot,
|
||||
cos,
|
||||
sin,
|
||||
)
|
||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||
|
||||
key_shape = key.shape
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
key_rot = key[..., : self.rotary_dim]
|
||||
key_pass = key[..., self.rotary_dim :]
|
||||
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
|
||||
key_rot = self.apply_rotary_emb.forward_native(
|
||||
key_rot,
|
||||
cos,
|
||||
sin,
|
||||
)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||
return query, key
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor | None = None,
|
||||
offsets: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
"""PyTorch-native implementation equivalent to forward().
|
||||
|
||||
Args:
|
||||
positions:
|
||||
[4, num_tokens] (P/W/H/T positions with multimodal inputs)
|
||||
query: [num_tokens, num_heads * head_size]
|
||||
key: [num_tokens, num_kv_heads * head_size]
|
||||
"""
|
||||
assert positions.ndim == 2
|
||||
assert key is not None
|
||||
|
||||
num_tokens = positions.shape[-1]
|
||||
cos_sin = self.cos_sin_cache[positions]
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
cos = torch.cat(
|
||||
[m[i] for i, m in enumerate(cos.split(self.xdrope_section, dim=-1))], dim=-1
|
||||
)
|
||||
sin = torch.cat(
|
||||
[m[i] for i, m in enumerate(sin.split(self.xdrope_section, dim=-1))], dim=-1
|
||||
)
|
||||
|
||||
query_shape = query.shape
|
||||
query = query.view(num_tokens, -1, self.head_size)
|
||||
query_rot = query[..., : self.rotary_dim]
|
||||
query_pass = query[..., self.rotary_dim :]
|
||||
query_rot = self.apply_rotary_emb(
|
||||
query_rot,
|
||||
cos,
|
||||
sin,
|
||||
)
|
||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||
|
||||
key_shape = key.shape
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
key_rot = key[..., : self.rotary_dim]
|
||||
key_pass = key[..., self.rotary_dim :]
|
||||
key_rot = self.apply_rotary_emb(
|
||||
key_rot,
|
||||
cos,
|
||||
sin,
|
||||
)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||
return query, key
|
||||
|
||||
|
||||
@@ -29,6 +29,9 @@ from vllm.model_executor.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||
ApplyRotaryEmb,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
@@ -158,32 +161,6 @@ class DotsOCRProcessingInfo(Qwen2VLProcessingInfo):
|
||||
return processor
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_vision(
|
||||
tensor: torch.Tensor, freqs: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
orig_dtype = tensor.dtype
|
||||
tensor = tensor.float()
|
||||
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
|
||||
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
||||
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
||||
|
||||
output = (tensor * cos) + (rotate_half(tensor) * sin)
|
||||
|
||||
output = output.to(orig_dtype)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class VisionRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
||||
super().__init__()
|
||||
@@ -298,6 +275,11 @@ class DotsVisionAttention(nn.Module):
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
self.apply_rotary_emb = ApplyRotaryEmb(
|
||||
enforce_enable=True,
|
||||
enable_fp32_compute=True,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -318,7 +300,11 @@ class DotsVisionAttention(nn.Module):
|
||||
|
||||
if rotary_pos_emb is not None:
|
||||
qk_concat = torch.cat([q, k], dim=0)
|
||||
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
|
||||
qk_rotated = self.apply_rotary_emb(
|
||||
qk_concat,
|
||||
rotary_pos_emb.cos(),
|
||||
rotary_pos_emb.sin(),
|
||||
)
|
||||
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
||||
|
||||
context_layer = self.attn(
|
||||
|
||||
@@ -33,7 +33,7 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from einops import rearrange
|
||||
from transformers import BatchFeature
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
@@ -53,6 +53,9 @@ from vllm.model_executor.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||
ApplyRotaryEmb,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (
|
||||
@@ -69,7 +72,6 @@ from vllm.multimodal.processing import (
|
||||
PromptUpdate,
|
||||
)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
@@ -89,52 +91,6 @@ logger = init_logger(__name__)
|
||||
# === Vision Transformer === #
|
||||
|
||||
|
||||
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
|
||||
if not interleaved:
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
else:
|
||||
x1, x2 = x[..., ::2], x[..., 1::2]
|
||||
return rearrange(
|
||||
torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
|
||||
)
|
||||
|
||||
|
||||
def apply_rotary_emb_torch(
|
||||
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
x: (batch_size, seqlen, nheads, headdim)
|
||||
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
|
||||
"""
|
||||
ro_dim = cos.shape[-1] * 2
|
||||
assert ro_dim <= x.shape[-1]
|
||||
cos = repeat(
|
||||
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
||||
)
|
||||
sin = repeat(
|
||||
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
||||
)
|
||||
return torch.cat(
|
||||
[
|
||||
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
|
||||
x[..., ro_dim:],
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
|
||||
t_ = t.float()
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
apply_rotary_emb = apply_rotary_emb_torch
|
||||
if current_platform.is_cuda():
|
||||
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
||||
output = apply_rotary_emb(t_, cos, sin).type_as(t)
|
||||
return output
|
||||
|
||||
|
||||
def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
|
||||
"""All-gather the input tensor interleavely across model parallel group."""
|
||||
import torch.distributed as dist
|
||||
@@ -200,6 +156,11 @@ class Ernie4_5_VisionAttention(nn.Module):
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
self.apply_rotary_emb = ApplyRotaryEmb(
|
||||
enforce_enable=True,
|
||||
enable_fp32_compute=True,
|
||||
)
|
||||
|
||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||
# [s, b, 3 * head * head_dim]
|
||||
seq_len, bs, _ = qkv.shape
|
||||
@@ -244,7 +205,11 @@ class Ernie4_5_VisionAttention(nn.Module):
|
||||
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
|
||||
if rotary_pos_emb is not None:
|
||||
qk_concat = torch.cat([q, k], dim=0)
|
||||
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
|
||||
qk_rotated = self.apply_rotary_emb(
|
||||
qk_concat,
|
||||
rotary_pos_emb.cos(),
|
||||
rotary_pos_emb.sin(),
|
||||
)
|
||||
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
||||
|
||||
output = self.attn(
|
||||
|
||||
@@ -19,7 +19,6 @@ from collections.abc import Iterable
|
||||
from itertools import islice
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers import Gemma3TextConfig
|
||||
|
||||
@@ -223,77 +222,9 @@ class Gemma3Attention(nn.Module):
|
||||
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
|
||||
if not kwargs.get("has_images", False):
|
||||
# Fast path for text-only inputs. The performance for the text-only
|
||||
# inputs are not affected by the naive attention below.
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
# NOTE(woosuk): Gemma3 uses bidirectional attention between image tokens
|
||||
# that correspond to the same image while using causal attention
|
||||
# otherwise. Current attention backends cannot handle this pattern, so
|
||||
# we temporarily use a naive attention implementation with mask tensors.
|
||||
|
||||
# We intentionally keep the attention backend as-is and only override
|
||||
# `attn_output` with the naive implementation's output. This minimizes
|
||||
# changes to existing model runners and attention backends. The call to
|
||||
# `self.attn(q, k, v)` is only used to populate the KV cache - its
|
||||
# output is discarded and overwritten below. While this duplicates
|
||||
# computation, it maintains compatibility.
|
||||
# TODO(woosuk): Optimize by implementing custom attention kernels.
|
||||
attn_output = self.naive_attn_with_masks(q, k, v, out=attn_output, **kwargs)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
def naive_attn_with_masks(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
# NOTE(woosuk): As described in the comment above, this code is not
|
||||
# meant to be performant. It is only meant to be correct.
|
||||
q = q.view(-1, self.num_heads, self.head_dim)
|
||||
# Expand the key and value to handle GQA.
|
||||
num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
k = k.view(-1, self.num_kv_heads, self.head_dim)
|
||||
k = k.repeat_interleave(num_queries_per_kv, dim=-2)
|
||||
v = v.view(-1, self.num_kv_heads, self.head_dim)
|
||||
v = v.repeat_interleave(num_queries_per_kv, dim=-2)
|
||||
|
||||
if self.is_sliding:
|
||||
attn_masks = kwargs["local_attn_masks"]
|
||||
else:
|
||||
attn_masks = kwargs["global_attn_masks"]
|
||||
|
||||
seq_lens = kwargs["seq_lens"]
|
||||
start_idx = 0
|
||||
for seq_len, attn_mask in zip(seq_lens, attn_masks):
|
||||
end_idx = start_idx + seq_len
|
||||
query = q[start_idx:end_idx].unsqueeze(0)
|
||||
key = k[start_idx:end_idx].unsqueeze(0)
|
||||
value = v[start_idx:end_idx].unsqueeze(0)
|
||||
|
||||
# Transpose.
|
||||
query = query.transpose(1, 2)
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
output = F.scaled_dot_product_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
self.scaling,
|
||||
)
|
||||
output = output.transpose(1, 2).flatten(-2, -1)
|
||||
out[start_idx:end_idx] = output
|
||||
start_idx = end_idx
|
||||
return out
|
||||
|
||||
|
||||
class Gemma3DecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
|
||||
@@ -65,6 +65,9 @@ from vllm.model_executor.layers.linear import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||
ApplyRotaryEmb,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
@@ -95,7 +98,7 @@ from .interfaces import (
|
||||
SupportsMultiModal,
|
||||
SupportsPP,
|
||||
)
|
||||
from .qwen2_vl import _create_qwen2vl_field_factory, apply_rotary_pos_emb_vision
|
||||
from .qwen2_vl import _create_qwen2vl_field_factory
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
WeightsMapper,
|
||||
@@ -304,6 +307,8 @@ class Glm4vVisionAttention(nn.Module):
|
||||
multimodal_config=multimodal_config,
|
||||
)
|
||||
|
||||
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
|
||||
|
||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||
# [s, b, 3 * head * head_dim]
|
||||
seq_len, bs, _ = qkv.shape
|
||||
@@ -339,8 +344,10 @@ class Glm4vVisionAttention(nn.Module):
|
||||
if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
|
||||
# [2 * b, s, heads, head_dim]
|
||||
qk_concat = torch.cat([q, k], dim=0)
|
||||
qk_rotated = apply_rotary_pos_emb_vision(
|
||||
qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
|
||||
qk_rotated = self.apply_rotary_emb(
|
||||
qk_concat,
|
||||
rotary_pos_emb_cos,
|
||||
rotary_pos_emb_sin,
|
||||
)
|
||||
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
||||
|
||||
|
||||
@@ -30,6 +30,9 @@ from vllm.model_executor.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||
ApplyRotaryEmb,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
maybe_remap_kv_scale_name,
|
||||
@@ -59,7 +62,6 @@ from vllm.multimodal.processing import (
|
||||
PromptUpdate,
|
||||
)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
@@ -341,20 +343,14 @@ def apply_rotary_pos_emb_flashatt(
|
||||
cos = cos.chunk(2, dim=-1)[0].contiguous()
|
||||
sin = sin.chunk(2, dim=-1)[0].contiguous()
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
||||
elif current_platform.is_rocm():
|
||||
from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb
|
||||
else:
|
||||
# For other platforms, use PyTorch fallback
|
||||
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||
apply_rotary_emb_torch,
|
||||
)
|
||||
apply_rotary_emb = ApplyRotaryEmb(
|
||||
enforce_enable=True,
|
||||
enable_fp32_compute=True,
|
||||
)
|
||||
|
||||
apply_rotary_emb = partial(apply_rotary_emb_torch, is_neox_style=True)
|
||||
q_embed = apply_rotary_emb(q, cos, sin)
|
||||
k_embed = apply_rotary_emb(k, cos, sin)
|
||||
|
||||
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
|
||||
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ from typing import Annotated, Literal
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange, repeat
|
||||
from einops import rearrange
|
||||
from transformers import BatchFeature, PretrainedConfig
|
||||
from transformers.activations import GELUActivation
|
||||
from transformers.modeling_outputs import (
|
||||
@@ -47,7 +47,7 @@ from vllm.model_executor.layers.linear import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||
dispatch_rotary_emb_function,
|
||||
ApplyRotaryEmb,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
@@ -130,47 +130,6 @@ def smart_resize(
|
||||
return h_bar, w_bar
|
||||
|
||||
|
||||
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
|
||||
if not interleaved:
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
x1, x2 = x[..., ::2], x[..., 1::2]
|
||||
return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
|
||||
|
||||
|
||||
def apply_rotary_emb_torch(
|
||||
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
x: (batch_size, seqlen, nheads, headdim)
|
||||
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
|
||||
"""
|
||||
ro_dim = cos.shape[-1] * 2
|
||||
assert ro_dim <= x.shape[-1]
|
||||
cos = repeat(
|
||||
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
||||
)
|
||||
sin = repeat(
|
||||
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
||||
)
|
||||
return torch.cat(
|
||||
[
|
||||
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
|
||||
x[..., ro_dim:],
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
|
||||
rotary_emb_function = dispatch_rotary_emb_function(default=apply_rotary_emb_torch)
|
||||
t_ = t.float()
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
output = rotary_emb_function(t_, cos, sin).type_as(t)
|
||||
return output
|
||||
|
||||
|
||||
class PaddleOCRVLProcessingInfo(BaseProcessingInfo):
|
||||
def get_hf_config(self):
|
||||
return self.ctx.get_hf_config()
|
||||
@@ -609,6 +568,10 @@ class SiglipAttention(nn.Module):
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
self.apply_rotary_emb = ApplyRotaryEmb(
|
||||
enforce_enable=True,
|
||||
enable_fp32_compute=True,
|
||||
)
|
||||
|
||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||
seq_len, bs, _ = qkv.shape
|
||||
@@ -651,7 +614,11 @@ class SiglipAttention(nn.Module):
|
||||
|
||||
if rotary_pos_emb is not None:
|
||||
qk_concat = torch.cat([q, k], dim=0)
|
||||
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
|
||||
qk_rotated = self.apply_rotary_emb(
|
||||
qk_concat,
|
||||
rotary_pos_emb.cos(),
|
||||
rotary_pos_emb.sin(),
|
||||
)
|
||||
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
||||
|
||||
context_layer = self.attn(
|
||||
|
||||
@@ -60,6 +60,9 @@ from vllm.model_executor.layers.linear import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||
ApplyRotaryEmb,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.vision import should_torch_compile_mm_vit
|
||||
@@ -95,7 +98,6 @@ from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
|
||||
from .qwen2_vl import (
|
||||
Qwen2VLMultiModalProcessor,
|
||||
Qwen2VLProcessingInfo,
|
||||
apply_rotary_pos_emb_vision,
|
||||
)
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
@@ -353,6 +355,8 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
multimodal_config=multimodal_config,
|
||||
)
|
||||
|
||||
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
@@ -378,8 +382,10 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
qk_reshaped = einops.rearrange(
|
||||
qk, "b s two head head_dim -> (two b) s head head_dim", two=2
|
||||
)
|
||||
qk_rotated = apply_rotary_pos_emb_vision(
|
||||
qk_reshaped, cos=rotary_pos_emb_cos, sin=rotary_pos_emb_sin
|
||||
qk_rotated = self.apply_rotary_emb(
|
||||
qk_reshaped,
|
||||
rotary_pos_emb_cos,
|
||||
rotary_pos_emb_sin,
|
||||
)
|
||||
qk_rotated = qk_rotated.view(
|
||||
2,
|
||||
|
||||
@@ -59,8 +59,7 @@ from vllm.model_executor.layers.linear import (
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||
apply_rotary_emb_torch,
|
||||
dispatch_rotary_emb_function,
|
||||
ApplyRotaryEmb,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
@@ -280,16 +279,6 @@ class Qwen2VisionMLP(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def apply_rotary_pos_emb_vision(
|
||||
t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
rotary_emb_function = dispatch_rotary_emb_function(
|
||||
default=partial(apply_rotary_emb_torch, is_neox_style=True)
|
||||
)
|
||||
output = rotary_emb_function(t, cos, sin).type_as(t)
|
||||
return output
|
||||
|
||||
|
||||
class Qwen2VisionAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -341,6 +330,8 @@ class Qwen2VisionAttention(nn.Module):
|
||||
multimodal_config=multimodal_config,
|
||||
)
|
||||
|
||||
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
|
||||
|
||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||
# [s, b, 3 * head * head_dim]
|
||||
seq_len, bs, _ = qkv.shape
|
||||
@@ -387,8 +378,10 @@ class Qwen2VisionAttention(nn.Module):
|
||||
|
||||
# [2 * b, s, heads, head_dim]
|
||||
qk_concat = torch.cat([q, k], dim=0)
|
||||
qk_rotated = apply_rotary_pos_emb_vision(
|
||||
qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
|
||||
qk_rotated = self.apply_rotary_emb(
|
||||
qk_concat,
|
||||
rotary_pos_emb_cos,
|
||||
rotary_pos_emb_sin,
|
||||
)
|
||||
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
||||
|
||||
|
||||
@@ -713,17 +713,13 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> int:
|
||||
target_width, target_height = self.get_image_size_with_most_features()
|
||||
video_soft_tokens = self.get_num_video_tokens(
|
||||
num_video_soft_tokens = self.get_num_video_tokens(
|
||||
image_width=target_width,
|
||||
image_height=target_height,
|
||||
num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
|
||||
image_processor=None,
|
||||
)
|
||||
|
||||
# NOTE: By default in Qwen3-VL, one video token is converted to
|
||||
# "<{timestamp} seconds>" (on average 9.5 tokens) + vision_start_token + video_token + vision_end_token # noqa: E501
|
||||
formatted_video_soft_tokens = video_soft_tokens * 12.5
|
||||
return int(formatted_video_soft_tokens)
|
||||
return num_video_soft_tokens
|
||||
|
||||
def _calculate_timestamps(
|
||||
self, indices: list[int] | torch.Tensor, video_fps: float, merge_size: int
|
||||
|
||||
@@ -6,7 +6,6 @@ within a vision language model."""
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from transformers import Siglip2VisionConfig
|
||||
@@ -26,6 +25,9 @@ from vllm.model_executor.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||
ApplyRotaryEmb,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@@ -146,40 +148,6 @@ class Siglip2VisionEmbeddings(nn.Module):
|
||||
return patch_embeds
|
||||
|
||||
|
||||
# copy from flash_attn/layers/rotary.py
|
||||
def rotate_half(x, interleaved=False):
|
||||
if not interleaved:
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
else:
|
||||
x1, x2 = x[..., ::2], x[..., 1::2]
|
||||
return rearrange(
|
||||
torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
|
||||
)
|
||||
|
||||
|
||||
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
|
||||
"""
|
||||
x: (batch_size, seqlen, nheads, headdim)
|
||||
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
|
||||
"""
|
||||
ro_dim = cos.shape[-1] * 2
|
||||
assert ro_dim <= x.shape[-1]
|
||||
cos = repeat(
|
||||
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
||||
)
|
||||
sin = repeat(
|
||||
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
||||
)
|
||||
return torch.cat(
|
||||
[
|
||||
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
|
||||
x[..., ro_dim:],
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
@@ -189,14 +157,20 @@ def apply_rotary_pos_emb(
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
cos = cos.chunk(2, dim=-1)[0].contiguous()
|
||||
sin = sin.chunk(2, dim=-1)[0].contiguous()
|
||||
if is_flash_attn_backend and current_platform.is_cuda():
|
||||
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
||||
|
||||
apply_rotary_emb_func = apply_rotary_emb
|
||||
apply_rotary_emb = ApplyRotaryEmb(
|
||||
enforce_enable=True,
|
||||
enable_fp32_compute=True,
|
||||
)
|
||||
|
||||
if is_flash_attn_backend and not current_platform.is_cuda():
|
||||
apply_rotary_emb_func = apply_rotary_emb.forward_cuda
|
||||
else:
|
||||
apply_rotary_emb_func = apply_rotary_emb_torch
|
||||
q_embed = apply_rotary_emb_func(q.float(), cos.float(), sin.float()).type_as(q)
|
||||
k_embed = apply_rotary_emb_func(k.float(), cos.float(), sin.float()).type_as(k)
|
||||
apply_rotary_emb_func = apply_rotary_emb.forward_native
|
||||
|
||||
q_embed = apply_rotary_emb_func(q, cos, sin)
|
||||
k_embed = apply_rotary_emb_func(k, cos, sin)
|
||||
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import torch
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
@@ -88,16 +88,10 @@ def get_vit_attn_backend(
|
||||
"""
|
||||
Get the available attention backend for Vision Transformer.
|
||||
"""
|
||||
attn_backend = attn_backend_override
|
||||
|
||||
selected_backend = get_current_vllm_config().attention_config.backend
|
||||
if attn_backend is None:
|
||||
attn_backend = selected_backend
|
||||
|
||||
return current_platform.get_vit_attn_backend(
|
||||
head_size,
|
||||
dtype,
|
||||
backend=attn_backend,
|
||||
backend=attn_backend_override,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
|
||||
from collections import UserDict, defaultdict
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from functools import cached_property, partial
|
||||
from itertools import accumulate
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -169,11 +169,42 @@ class PlaceholderRange:
|
||||
between `offset` and `offset + length` to assign embeddings to.
|
||||
"""
|
||||
|
||||
def get_num_embeds(self) -> int:
|
||||
@cached_property
|
||||
def embeds_cumsum(self) -> torch.Tensor | None:
|
||||
if self.is_embed is None:
|
||||
return None
|
||||
|
||||
return self.is_embed.cumsum(dim=0)
|
||||
|
||||
@cached_property
|
||||
def get_num_embeds(self) -> int:
|
||||
if self.embeds_cumsum is None:
|
||||
return self.length
|
||||
|
||||
return int(self.is_embed.sum().item())
|
||||
return int(self.embeds_cumsum[-1])
|
||||
|
||||
def get_embeds_indices_in_range(
|
||||
self, start_idx: int, end_idx: int
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Returns the starting and ending indices of the embeddings of encoder outputs
|
||||
in the range of [start_idx, end_idx) in the placeholders.
|
||||
|
||||
For example, given:
|
||||
PlaceholderRange(offset=2, length=5, is_embed=[False, True, False, True, True])
|
||||
|
||||
If start_idx=3 and end_idx=5, the output is (1, 3) because we want to get
|
||||
the second and the third embeddings from the encoder output.
|
||||
"""
|
||||
if self.embeds_cumsum is None:
|
||||
return start_idx, end_idx
|
||||
|
||||
embeds_start_idx = (
|
||||
int(self.embeds_cumsum[start_idx - 1]) if start_idx > 0 else 0
|
||||
)
|
||||
embeds_end_idx = int(self.embeds_cumsum[end_idx - 1])
|
||||
|
||||
return embeds_start_idx, embeds_end_idx
|
||||
|
||||
def extract_embeds_range(self) -> list[tuple[int, int]]:
|
||||
"""Extract the start and end indices of the embedded region in prompt.
|
||||
@@ -188,7 +219,7 @@ class PlaceholderRange:
|
||||
Returns full placeholder range if `is_embed` is `None`.
|
||||
"""
|
||||
if self.is_embed is None:
|
||||
return [(self.offset, self.offset + self.length)]
|
||||
return [(self.offset, self.offset + self.length - 1)]
|
||||
|
||||
mask_i = self.is_embed.int()
|
||||
starts = torch.nonzero(
|
||||
|
||||
@@ -274,15 +274,11 @@ class MultiModalProfiler(Generic[_I]):
|
||||
def _get_mm_num_tokens(
|
||||
self,
|
||||
mm_inputs: MultiModalInputs,
|
||||
mm_embeddings_only: bool = True,
|
||||
) -> Mapping[str, int]:
|
||||
placeholders_by_modality = mm_inputs["mm_placeholders"]
|
||||
|
||||
return {
|
||||
modality: sum(
|
||||
item.get_num_embeds() if mm_embeddings_only else item.length
|
||||
for item in placeholders
|
||||
)
|
||||
modality: sum(item.get_num_embeds for item in placeholders)
|
||||
for modality, placeholders in placeholders_by_modality.items()
|
||||
}
|
||||
|
||||
@@ -328,12 +324,15 @@ class MultiModalProfiler(Generic[_I]):
|
||||
multi_modal_placeholders=mm_inputs["mm_placeholders"],
|
||||
)
|
||||
|
||||
def _get_mm_max_tokens(
|
||||
def get_mm_max_tokens(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int] | None = None,
|
||||
mm_embeddings_only: bool = True,
|
||||
) -> Mapping[str, int]:
|
||||
"""
|
||||
Returns the maximum number of embeddings per item of each modality, excluding
|
||||
any break/text tokens in-between multimodal embeddings/encoder outputs.
|
||||
"""
|
||||
if mm_counts is None:
|
||||
mm_counts = self.get_mm_limits()
|
||||
|
||||
@@ -349,21 +348,4 @@ class MultiModalProfiler(Generic[_I]):
|
||||
}
|
||||
|
||||
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
|
||||
return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only)
|
||||
|
||||
def get_mm_max_contiguous_tokens(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int] | None = None,
|
||||
) -> Mapping[str, int]:
|
||||
"""
|
||||
Returns the maximum length of the multimodal (image placeholders+text)
|
||||
tokens, including any break/text tokens in-between image embeddings.
|
||||
|
||||
`<im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>`
|
||||
Returns 9, even when the number of image embeddings is 6.
|
||||
|
||||
This is important to take into account when profiling and
|
||||
initializing the encoder cache size.
|
||||
"""
|
||||
return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False)
|
||||
return self._get_mm_num_tokens(mm_inputs)
|
||||
|
||||
@@ -164,7 +164,7 @@ class MultiModalRegistry:
|
||||
profiler.get_mm_limits() if profiler_limits is None else profiler_limits
|
||||
)
|
||||
|
||||
return profiler.get_mm_max_contiguous_tokens(
|
||||
return profiler.get_mm_max_tokens(
|
||||
seq_len,
|
||||
{modality: 1 for modality, limit in profiler_limits.items() if limit > 0},
|
||||
)
|
||||
|
||||
@@ -76,6 +76,39 @@ class TritonAttentionMetadata:
|
||||
# Optional aot scheduling
|
||||
scheduler_metadata: torch.Tensor | None = None
|
||||
prefix_scheduler_metadata: torch.Tensor | None = None
|
||||
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
|
||||
|
||||
@property
|
||||
def mm_prefix_range_tensor(self) -> torch.Tensor | None:
|
||||
"""Convert mm_prefix_range dict to padded tensor for Triton kernel.
|
||||
|
||||
Returns shape: (num_seqs, max_ranges, 2) with 0-padding for empty ranges.
|
||||
Empty ranges have start==end==0, which kernel skips via is_valid check.
|
||||
"""
|
||||
# TODO(Isotr0py): Move to model runner's attention metadata
|
||||
# preparation to avoid duplicate computation.
|
||||
if self.mm_prefix_range is None:
|
||||
return None
|
||||
|
||||
num_seqs = self.seq_lens.shape[0]
|
||||
device = self.seq_lens.device
|
||||
|
||||
# Collect ranges, using [(0,0)] for empty sequences to ensure uniform dims
|
||||
range_lists = [
|
||||
self.mm_prefix_range.get(i, [(0, 0)]) or [(0, 0)] for i in range(num_seqs)
|
||||
]
|
||||
|
||||
# Return None if all ranges are trivial (only (0,0) placeholders)
|
||||
if all(r == [(0, 0)] for r in range_lists):
|
||||
return None
|
||||
|
||||
# Create 2D tensors with shape (num_ranges, 2) for each sequence
|
||||
range_tensors = [
|
||||
torch.tensor(r, dtype=torch.int32, device=device).view(-1, 2)
|
||||
for r in range_lists
|
||||
]
|
||||
|
||||
return torch.nested.nested_tensor(range_tensors).to_padded_tensor(0)
|
||||
|
||||
|
||||
class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
|
||||
@@ -268,6 +301,10 @@ class TritonAttentionBackend(AttentionBackend):
|
||||
def supports_head_size(cls, head_size: int) -> bool:
|
||||
return head_size >= 32
|
||||
|
||||
@classmethod
|
||||
def supports_mm_prefix(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def supports_sink(cls) -> bool:
|
||||
return True
|
||||
@@ -427,6 +464,7 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
softmax_segm_expsum = attn_metadata.softmax_segm_expsum
|
||||
|
||||
descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
|
||||
mm_prefix_range_tensor = attn_metadata.mm_prefix_range_tensor
|
||||
|
||||
unified_attention(
|
||||
q=query[:num_actual_tokens],
|
||||
@@ -453,6 +491,7 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
softmax_segm_expsum=softmax_segm_expsum,
|
||||
sinks=self.sinks,
|
||||
output_scale=output_scale,
|
||||
mm_prefix_range=mm_prefix_range_tensor,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
@@ -39,20 +39,26 @@ class EncoderCacheManager:
|
||||
space for new embeddings.
|
||||
Oldest cached embeddings with no request referenced will be first evicted.
|
||||
|
||||
NOTE: The EncoderCacheManager operates on the level of multimodal embeddings
|
||||
instead of encoder tokens (i.e. all tokens that represent the multimodal data
|
||||
in the input sequence). This means all break/text tokens in-between multimodal
|
||||
embeddings are not considered with respect to the cache size and the number
|
||||
of free slots.
|
||||
|
||||
Args:
|
||||
cache_size: Limit the size of the cache, measured by the number of
|
||||
tokens from the input sequence.
|
||||
encoder embeddings from the input sequence.
|
||||
|
||||
Attributes:
|
||||
cache_size: Total cache capacity in encoder tokens.
|
||||
num_free_slots: Current available cache capacity in encoder tokens.
|
||||
cache_size: Total cache capacity in encoder embeddings.
|
||||
num_free_slots: Current available cache capacity in encoder embeddings.
|
||||
num_freeable_slots: Capacity that can be immediately reclaimed by
|
||||
evicting entries with zero references (in encoder tokens).
|
||||
evicting entries with zero references (in encoder embeddings).
|
||||
cached: Mapping from mm_hash to a set of request IDs that currently
|
||||
reference the cached entry. If the set is empty, the entry exists
|
||||
but is not referenced by any request and is eligible for
|
||||
reclamation.
|
||||
freeable: List of tuples (mm_hash, num_tokens) representing entries
|
||||
freeable: List of tuples (mm_hash, num_encoder_embeds) representing entries
|
||||
whose no current running request is needed and that can be freed to
|
||||
make space when needed.
|
||||
freed: List of mm_hash strings that were actually evicted since the
|
||||
@@ -67,7 +73,7 @@ class EncoderCacheManager:
|
||||
# mm_hash of mm_data => ids of requests that reference the mm_data
|
||||
self.cached: dict[str, set[str]] = {}
|
||||
|
||||
# mm_hash of mm_data => num_encoder_tokens of the mm_data
|
||||
# mm_hash of mm_data => num_encoder_embeds of the mm_data
|
||||
self.freeable: OrderedDict[str, int] = OrderedDict()
|
||||
self.freed: list[str] = []
|
||||
|
||||
@@ -93,8 +99,8 @@ class EncoderCacheManager:
|
||||
|
||||
# Cached but currently not referenced by any request
|
||||
if not self.cached[mm_hash]:
|
||||
num_tokens = self.freeable.pop(mm_hash)
|
||||
self.num_freeable_slots -= num_tokens
|
||||
num_encoder_embeds = self.freeable.pop(mm_hash)
|
||||
self.num_freeable_slots -= num_encoder_embeds
|
||||
|
||||
self.cached[mm_hash].add(request.request_id)
|
||||
return True
|
||||
@@ -104,7 +110,7 @@ class EncoderCacheManager:
|
||||
request: Request,
|
||||
input_id: int,
|
||||
encoder_compute_budget: int,
|
||||
num_tokens_to_schedule: int,
|
||||
num_embeds_to_schedule: int,
|
||||
) -> bool:
|
||||
"""Check if there's sufficient cache space for a multimodal input.
|
||||
If there is, return True and update EncoderCacheManager state.
|
||||
@@ -121,9 +127,9 @@ class EncoderCacheManager:
|
||||
Args:
|
||||
request: The request containing the multimodal input.
|
||||
input_id: Index of the multimodal input within the request.
|
||||
encoder_compute_budget: Number of encoder tokens allowed to be
|
||||
encoder_compute_budget: Number of encoder embeddings allowed to be
|
||||
computed when this method is invoked.
|
||||
num_tokens_to_schedule: Number of tokens already scheduled to be
|
||||
num_embeds_to_schedule: Number of encoder embeddings already scheduled to be
|
||||
allocated with cache space when this method is invoked.
|
||||
|
||||
Returns:
|
||||
@@ -134,30 +140,30 @@ class EncoderCacheManager:
|
||||
Note: This method does not allocate physical memory for the encoder
|
||||
output but only the state of EncoderCacheManager.
|
||||
"""
|
||||
num_tokens = request.get_num_encoder_tokens(input_id)
|
||||
num_embeds = request.get_num_encoder_embeds(input_id)
|
||||
|
||||
# Not enough compute budget
|
||||
if num_tokens > encoder_compute_budget:
|
||||
if num_embeds > encoder_compute_budget:
|
||||
return False
|
||||
|
||||
num_tokens += num_tokens_to_schedule
|
||||
num_embeds += num_embeds_to_schedule
|
||||
|
||||
# Enough free slots
|
||||
if num_tokens <= self.num_free_slots:
|
||||
if num_embeds <= self.num_free_slots:
|
||||
return True
|
||||
|
||||
# Not enough reclaimable slots
|
||||
if num_tokens > self.num_freeable_slots:
|
||||
if num_embeds > self.num_freeable_slots:
|
||||
return False
|
||||
|
||||
# Not enough free slots but enough reclaimable slots
|
||||
# NOTE: Eviction takes place here, but physical memory is not freed
|
||||
# until model runner is notified by the scheduler output.
|
||||
while num_tokens > self.num_free_slots:
|
||||
mm_hash, num_free_token = self.freeable.popitem(last=False)
|
||||
while num_embeds > self.num_free_slots:
|
||||
mm_hash, num_free_embeds = self.freeable.popitem(last=False)
|
||||
del self.cached[mm_hash]
|
||||
self.freed.append(mm_hash)
|
||||
self.num_free_slots += num_free_token
|
||||
self.num_free_slots += num_free_embeds
|
||||
return True
|
||||
|
||||
def allocate(self, request: Request, input_id: int) -> None:
|
||||
@@ -176,16 +182,16 @@ class EncoderCacheManager:
|
||||
if mm_hash not in self.cached:
|
||||
self.cached[mm_hash] = set()
|
||||
|
||||
num_encoder_tokens = request.get_num_encoder_tokens(input_id)
|
||||
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
|
||||
|
||||
# NOTE: Encoder cache should always have enough space for encoder inputs
|
||||
# that are scheduled since eviction takes place at can_allocate().
|
||||
assert self.num_free_slots >= num_encoder_tokens
|
||||
assert self.num_freeable_slots >= num_encoder_tokens
|
||||
assert self.num_free_slots >= num_encoder_embeds
|
||||
assert self.num_freeable_slots >= num_encoder_embeds
|
||||
|
||||
self.cached[mm_hash].add(request_id)
|
||||
self.num_free_slots -= num_encoder_tokens
|
||||
self.num_freeable_slots -= num_encoder_tokens
|
||||
self.num_free_slots -= num_encoder_embeds
|
||||
self.num_freeable_slots -= num_encoder_embeds
|
||||
|
||||
def get_cached_input_ids(self, request: Request) -> set[int]:
|
||||
"""Get all cached multimodal input IDs for a request.
|
||||
@@ -206,7 +212,7 @@ class EncoderCacheManager:
|
||||
|
||||
When the reference set for the corresponding `mm_hash` becomes empty,
|
||||
the entry is appended to `freeable` and `num_freeable_slots` is
|
||||
increased by the number of encoder tokens for that input.
|
||||
increased by the number of encoder embeddings for that input.
|
||||
|
||||
The entry is NOT physically freed until capacity is needed (e.g., by
|
||||
`can_allocate`).
|
||||
@@ -218,9 +224,9 @@ class EncoderCacheManager:
|
||||
return
|
||||
self.cached[mm_hash].discard(req_id)
|
||||
if not self.cached[mm_hash]:
|
||||
num_tokens = request.get_num_encoder_tokens(input_id)
|
||||
self.freeable[mm_hash] = num_tokens
|
||||
self.num_freeable_slots += num_tokens
|
||||
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
|
||||
self.freeable[mm_hash] = num_encoder_embeds
|
||||
self.num_freeable_slots += num_encoder_embeds
|
||||
|
||||
def free(self, request: Request) -> None:
|
||||
"""Free all encoder input cache reference held by *request*.
|
||||
@@ -361,20 +367,20 @@ class EncoderDecoderCacheManager(EncoderCacheManager):
|
||||
request: Request,
|
||||
input_id: int,
|
||||
encoder_compute_budget: int,
|
||||
num_tokens_to_schedule: int,
|
||||
num_embeds_to_schedule: int,
|
||||
) -> bool:
|
||||
num_tokens = request.get_num_encoder_tokens(input_id)
|
||||
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
|
||||
# Not enough compute budget
|
||||
if num_tokens > encoder_compute_budget:
|
||||
if num_encoder_embeds > encoder_compute_budget:
|
||||
return False
|
||||
|
||||
num_tokens += num_tokens_to_schedule
|
||||
num_encoder_embeds += num_embeds_to_schedule
|
||||
# Enough free slots
|
||||
return num_tokens <= self.num_free_slots
|
||||
return num_encoder_embeds <= self.num_free_slots
|
||||
|
||||
def allocate(self, request: Request, input_id: int) -> None:
|
||||
num_encoder_tokens = request.get_num_encoder_tokens(input_id)
|
||||
self.num_free_slots -= num_encoder_tokens
|
||||
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
|
||||
self.num_free_slots -= num_encoder_embeds
|
||||
|
||||
mm_hash = request.mm_features[input_id].identifier
|
||||
self.freed.append(mm_hash)
|
||||
@@ -392,5 +398,5 @@ class EncoderDecoderCacheManager(EncoderCacheManager):
|
||||
return freed
|
||||
|
||||
def free_encoder_input(self, request: Request, input_id: int) -> None:
|
||||
num_tokens = request.get_num_encoder_tokens(input_id)
|
||||
self.num_free_slots += num_tokens
|
||||
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
|
||||
self.num_free_slots += num_encoder_embeds
|
||||
|
||||
@@ -349,11 +349,11 @@ class Scheduler(SchedulerInterface):
|
||||
if preempted_encoder_inputs:
|
||||
# Restore encoder compute budget if the preempted
|
||||
# request had encoder inputs scheduled in this step.
|
||||
num_tokens_to_restore = sum(
|
||||
preempted_req.get_num_encoder_tokens(i)
|
||||
num_embeds_to_restore = sum(
|
||||
preempted_req.get_num_encoder_embeds(i)
|
||||
for i in preempted_encoder_inputs
|
||||
)
|
||||
encoder_compute_budget += num_tokens_to_restore
|
||||
encoder_compute_budget += num_embeds_to_restore
|
||||
req_index -= 1
|
||||
else:
|
||||
preempted_req = self.running.pop()
|
||||
@@ -911,10 +911,11 @@ class Scheduler(SchedulerInterface):
|
||||
# multiple encoder inputs per request), we need to create temporary
|
||||
# trackers for accounting at the encoder input level.
|
||||
mm_hashes_to_schedule = set()
|
||||
num_tokens_to_schedule = 0
|
||||
num_embeds_to_schedule = 0
|
||||
for i, mm_feature in enumerate(mm_features):
|
||||
start_pos = mm_feature.mm_position.offset
|
||||
num_encoder_tokens = mm_feature.mm_position.length
|
||||
num_encoder_embeds = mm_feature.mm_position.get_num_embeds
|
||||
|
||||
# The encoder output is needed if the two ranges overlap:
|
||||
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
|
||||
@@ -970,9 +971,8 @@ class Scheduler(SchedulerInterface):
|
||||
):
|
||||
num_new_tokens = start_pos - num_computed_tokens
|
||||
break
|
||||
|
||||
if not self.encoder_cache_manager.can_allocate(
|
||||
request, i, encoder_compute_budget, num_tokens_to_schedule
|
||||
request, i, encoder_compute_budget, num_embeds_to_schedule
|
||||
):
|
||||
# The encoder cache is full or the encoder budget is exhausted.
|
||||
# NOTE(woosuk): We assume that the encoder input tokens should
|
||||
@@ -992,14 +992,31 @@ class Scheduler(SchedulerInterface):
|
||||
num_new_tokens = 0
|
||||
break
|
||||
|
||||
# Calculate the number of embeddings to schedule in the current range
|
||||
# of scheduled encoder placholder tokens.
|
||||
start_idx_rel = max(0, num_computed_tokens - start_pos)
|
||||
end_idx_rel = min(
|
||||
num_encoder_tokens, num_computed_tokens + num_new_tokens - start_pos
|
||||
)
|
||||
curr_embeds_start, curr_embeds_end = (
|
||||
mm_feature.mm_position.get_embeds_indices_in_range(
|
||||
start_idx_rel,
|
||||
end_idx_rel,
|
||||
)
|
||||
)
|
||||
# There's no embeddings in the current range of encoder placeholder tokens
|
||||
# so we can skip the encoder input.
|
||||
if curr_embeds_end - curr_embeds_start == 0:
|
||||
continue
|
||||
|
||||
if self.ec_connector is not None and remote_cache_has_item[i]:
|
||||
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
|
||||
external_load_encoder_input.append(i)
|
||||
num_tokens_to_schedule += num_encoder_tokens
|
||||
num_embeds_to_schedule += num_encoder_embeds
|
||||
continue
|
||||
|
||||
num_tokens_to_schedule += num_encoder_tokens
|
||||
encoder_compute_budget -= num_encoder_tokens
|
||||
num_embeds_to_schedule += num_encoder_embeds
|
||||
encoder_compute_budget -= num_encoder_embeds
|
||||
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
|
||||
encoder_inputs_to_schedule.append(i)
|
||||
|
||||
|
||||
@@ -209,10 +209,10 @@ class Request:
|
||||
def get_finished_reason(self) -> FinishReason | None:
|
||||
return RequestStatus.get_finished_reason(self.status)
|
||||
|
||||
def get_num_encoder_tokens(self, input_id: int) -> int:
|
||||
def get_num_encoder_embeds(self, input_id: int) -> int:
|
||||
assert input_id < len(self.mm_features)
|
||||
num_tokens = self.mm_features[input_id].mm_position.length
|
||||
return num_tokens
|
||||
num_embeds = self.mm_features[input_id].mm_position.get_num_embeds
|
||||
return num_embeds
|
||||
|
||||
def record_event(
|
||||
self,
|
||||
|
||||
@@ -169,9 +169,7 @@ from .utils import (
|
||||
MultiModalBudget,
|
||||
add_kv_sharing_layers_to_kv_cache_groups,
|
||||
bind_kv_cache,
|
||||
gather_mm_placeholders,
|
||||
sanity_check_mm_encoder_outputs,
|
||||
scatter_mm_placeholders,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -2187,10 +2185,7 @@ class GPUModelRunner(
|
||||
|
||||
# Cache the encoder outputs by mm_hash
|
||||
for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
|
||||
self.encoder_cache[mm_hash] = scatter_mm_placeholders(
|
||||
output,
|
||||
is_embed=pos_info.is_embed,
|
||||
)
|
||||
self.encoder_cache[mm_hash] = output
|
||||
logger.debug("Finish execute for mm hash %s", mm_hash)
|
||||
self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)
|
||||
|
||||
@@ -2241,6 +2236,13 @@ class GPUModelRunner(
|
||||
num_encoder_tokens,
|
||||
)
|
||||
assert start_idx < end_idx
|
||||
curr_embeds_start, curr_embeds_end = (
|
||||
pos_info.get_embeds_indices_in_range(start_idx, end_idx)
|
||||
)
|
||||
# If there are no embeddings in the current range, we skip
|
||||
# gathering the embeddings.
|
||||
if curr_embeds_start == curr_embeds_end:
|
||||
continue
|
||||
|
||||
mm_hash = mm_feature.identifier
|
||||
encoder_output = self.encoder_cache.get(mm_hash, None)
|
||||
@@ -2248,16 +2250,14 @@ class GPUModelRunner(
|
||||
|
||||
if (is_embed := pos_info.is_embed) is not None:
|
||||
is_embed = is_embed[start_idx:end_idx]
|
||||
mm_embeds_item = encoder_output[curr_embeds_start:curr_embeds_end]
|
||||
else:
|
||||
mm_embeds_item = encoder_output[start_idx:end_idx]
|
||||
|
||||
req_start_pos = req_start_idx + start_pos - num_computed_tokens
|
||||
is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
|
||||
True if is_embed is None else is_embed
|
||||
)
|
||||
|
||||
mm_embeds_item = gather_mm_placeholders(
|
||||
encoder_output[start_idx:end_idx],
|
||||
is_embed=is_embed,
|
||||
)
|
||||
mm_embeds_req.append(mm_embeds_item)
|
||||
|
||||
if self.is_multimodal_pruning_enabled and self.uses_mrope:
|
||||
@@ -4467,31 +4467,8 @@ class GPUModelRunner(
|
||||
dummy_encoder_outputs,
|
||||
expected_num_items=max_mm_items_per_batch,
|
||||
)
|
||||
|
||||
# NOTE: This happens when encoder cache needs to store
|
||||
# the embeddings that encoder outputs are scattered onto.
|
||||
# In this case we create dummy embeddings of size
|
||||
# (max_tokens_for_modality, hidden_size) and scatter
|
||||
# encoder output into it.
|
||||
encoder_output_shape = dummy_encoder_outputs[0].shape
|
||||
max_mm_tokens_per_item = mm_budget.max_tokens_by_modality[
|
||||
dummy_modality
|
||||
]
|
||||
if encoder_output_shape[0] < max_mm_tokens_per_item:
|
||||
encoder_hidden_size = encoder_output_shape[-1]
|
||||
expanded_outputs = []
|
||||
for output in dummy_encoder_outputs:
|
||||
expanded = output.new_zeros(
|
||||
(max_mm_tokens_per_item, encoder_hidden_size)
|
||||
)
|
||||
num_tokens = output.shape[0]
|
||||
expanded[:num_tokens].copy_(output)
|
||||
expanded_outputs.append(expanded)
|
||||
|
||||
dummy_encoder_outputs = expanded_outputs
|
||||
|
||||
# Cache the dummy encoder outputs.
|
||||
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
|
||||
for i, output in enumerate(dummy_encoder_outputs):
|
||||
self.encoder_cache[f"tmp_{i}"] = output
|
||||
|
||||
# Add `is_profile` here to pre-allocate communication buffers
|
||||
hidden_states, last_hidden_states = self._dummy_run(
|
||||
|
||||
@@ -4,10 +4,12 @@ from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
from vllm.multimodal.cache import processor_only_cache_from_config
|
||||
@@ -17,6 +19,8 @@ from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
||||
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
|
||||
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MultiModalBudget:
|
||||
"""Helper class to calculate budget information for multi-modal models."""
|
||||
@@ -198,6 +202,7 @@ def sanity_check_mm_encoder_outputs(
|
||||
)
|
||||
|
||||
|
||||
@deprecated("`scatter_mm_placeholders` is deprecated and will be removed in v0.15.0.")
|
||||
def scatter_mm_placeholders(
|
||||
embeds: torch.Tensor,
|
||||
is_embed: torch.Tensor | None,
|
||||
@@ -226,6 +231,7 @@ def scatter_mm_placeholders(
|
||||
return placeholders
|
||||
|
||||
|
||||
@deprecated("`gather_mm_placeholders` is deprecated and will be removed in v0.15.0.")
|
||||
def gather_mm_placeholders(
|
||||
placeholders: torch.Tensor,
|
||||
is_embed: torch.Tensor | None,
|
||||
|
||||
@@ -145,12 +145,20 @@ class WorkspaceManager:
|
||||
|
||||
for ubatch_id in range(self._num_ubatches):
|
||||
current_workspace = self._current_workspaces[ubatch_id]
|
||||
if current_workspace is None:
|
||||
if (
|
||||
current_workspace is None
|
||||
or self._workspace_size_bytes(current_workspace) < required_bytes
|
||||
):
|
||||
# Delete old tensor before allocating new one to avoid
|
||||
# memory spike from resize_(). resize_() allocates new
|
||||
# memory before freeing old, which can cause OOM.
|
||||
# Must clear the list reference first since local var
|
||||
# is just a copy of the reference.
|
||||
self._current_workspaces[ubatch_id] = None
|
||||
del current_workspace
|
||||
self._current_workspaces[ubatch_id] = torch.empty(
|
||||
(required_bytes,), dtype=torch.uint8, device=self._device
|
||||
)
|
||||
elif self._workspace_size_bytes(current_workspace) < required_bytes:
|
||||
current_workspace.resize_(required_bytes)
|
||||
|
||||
if envs.VLLM_DEBUG_WORKSPACE:
|
||||
logger.info(
|
||||
|
||||
Reference in New Issue
Block a user