Compare commits
12 Commits
v0.17.2rc0
...
v0.13.0rc4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
55f1fc1b1b | ||
|
|
17f3988094 | ||
|
|
682c38583c | ||
|
|
f124b56786 | ||
|
|
d78e128b8b | ||
|
|
761b730dcb | ||
|
|
f34eca5f01 | ||
|
|
4cd332f3cf | ||
|
|
16484d394c | ||
|
|
e397bd6592 | ||
|
|
6a88d590bb | ||
|
|
ad8c073131 |
@@ -1223,6 +1223,8 @@ steps:
|
|||||||
# FIXIT: find out which code initialize cuda before running the test
|
# FIXIT: find out which code initialize cuda before running the test
|
||||||
# before the fix, we need to use spawn to test it
|
# before the fix, we need to use spawn to test it
|
||||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||||
|
# Alot of these tests are on the edge of OOMing
|
||||||
|
- export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
||||||
# There is some Tensor Parallelism related processing logic in LoRA that
|
# There is some Tensor Parallelism related processing logic in LoRA that
|
||||||
# requires multi-GPU testing for validation.
|
# requires multi-GPU testing for validation.
|
||||||
- pytest -v -s -x lora/test_chatglm3_tp.py
|
- pytest -v -s -x lora/test_chatglm3_tp.py
|
||||||
|
|||||||
@@ -523,6 +523,8 @@ CUSTOM_OPS_QUANT_RMS_NORM = ["+quant_fp8,+rms_norm"]
|
|||||||
list[tuple[Any, ...]](flat_product(MODELS_GROUP_FP8, CUSTOM_OPS_QUANT_RMS_NORM)),
|
list[tuple[Any, ...]](flat_product(MODELS_GROUP_FP8, CUSTOM_OPS_QUANT_RMS_NORM)),
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
|
@pytest.mark.parametrize("inductor_graph_partition", [True, False])
|
||||||
|
# TODO: remove skip after we fix the fusion thoroughly
|
||||||
|
@pytest.mark.skipif(is_blackwell(), reason="Temporarily disabled on Blackwell")
|
||||||
def test_rms_group_quant(
|
def test_rms_group_quant(
|
||||||
model_name: str,
|
model_name: str,
|
||||||
model_kwargs: dict[str, Any],
|
model_kwargs: dict[str, Any],
|
||||||
@@ -562,7 +564,7 @@ def test_rms_group_quant(
|
|||||||
splitting_ops=splitting_ops,
|
splitting_ops=splitting_ops,
|
||||||
# Common
|
# Common
|
||||||
mode=CompilationMode.VLLM_COMPILE,
|
mode=CompilationMode.VLLM_COMPILE,
|
||||||
pass_config=PassConfig(eliminate_noops=True, enable_fusion=True),
|
pass_config=PassConfig(eliminate_noops=True, fuse_norm_quant=True),
|
||||||
# Inductor caches custom passes by default as well via uuid
|
# Inductor caches custom passes by default as well via uuid
|
||||||
inductor_compile_config={"force_disable_caches": True},
|
inductor_compile_config={"force_disable_caches": True},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -254,7 +254,9 @@ async def test_single_chat_session_input_audio(
|
|||||||
async def test_chat_streaming_audio(
|
async def test_chat_streaming_audio(
|
||||||
client: openai.AsyncOpenAI, model_name: str, audio_url: str
|
client: openai.AsyncOpenAI, model_name: str, audio_url: str
|
||||||
):
|
):
|
||||||
messages = dummy_messages_from_audio_url(audio_url)
|
messages = dummy_messages_from_audio_url(
|
||||||
|
audio_url, "What's a short title for this audio?"
|
||||||
|
)
|
||||||
|
|
||||||
# test single completion
|
# test single completion
|
||||||
chat_completion = await client.chat.completions.create(
|
chat_completion = await client.chat.completions.create(
|
||||||
|
|||||||
203
tests/kernels/core/test_apply_rotary_emb.py
Normal file
203
tests/kernels/core/test_apply_rotary_emb.py
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
"""
|
||||||
|
Tests for ApplyRotaryEmb CustomOp dispatch behavior.
|
||||||
|
|
||||||
|
This test ensures that RotaryEmbedding classes correctly call the appropriate
|
||||||
|
ApplyRotaryEmb methods based on the calling context:
|
||||||
|
|
||||||
|
1. RotaryEmbedding.forward_native() -> ApplyRotaryEmb.forward_native()
|
||||||
|
2. RotaryEmbedding.forward_cuda() -> ApplyRotaryEmb.forward() (auto-dispatch)
|
||||||
|
3. RotaryEmbedding.forward_hip() -> ApplyRotaryEmb.forward() (auto-dispatch)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.config import (
|
||||||
|
CompilationConfig,
|
||||||
|
VllmConfig,
|
||||||
|
get_cached_compilation_config,
|
||||||
|
set_current_vllm_config,
|
||||||
|
)
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
CUDA_DEVICES = ["cuda:0"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RotaryEmbeddingTestCase:
|
||||||
|
"""Test case configuration for RotaryEmbedding dispatch tests."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
rope_class: type
|
||||||
|
rope_kwargs: dict
|
||||||
|
method_name: str # forward_native, forward_cuda, forward
|
||||||
|
positions_shape: tuple # (num_tokens,) or (3, num_tokens) or (4, num_tokens)
|
||||||
|
expect_forward_native: bool # Should call ApplyRotaryEmb.forward_native()
|
||||||
|
expect_forward: bool # Should call ApplyRotaryEmb.forward()
|
||||||
|
|
||||||
|
|
||||||
|
def get_test_cases() -> list[RotaryEmbeddingTestCase]:
|
||||||
|
"""Generate test cases for all RotaryEmbedding classes."""
|
||||||
|
from vllm.model_executor.layers.rotary_embedding.ernie45_vl_rope import (
|
||||||
|
Ernie4_5_VLRotaryEmbedding,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.rotary_embedding.mrope import MRotaryEmbedding
|
||||||
|
from vllm.model_executor.layers.rotary_embedding.xdrope import XDRotaryEmbedding
|
||||||
|
|
||||||
|
common_kwargs = {
|
||||||
|
"head_size": 128,
|
||||||
|
"rotary_dim": 128,
|
||||||
|
"max_position_embeddings": 4096,
|
||||||
|
"base": 10000,
|
||||||
|
"is_neox_style": True,
|
||||||
|
"dtype": torch.bfloat16,
|
||||||
|
}
|
||||||
|
|
||||||
|
return [
|
||||||
|
# MRotaryEmbedding tests
|
||||||
|
RotaryEmbeddingTestCase(
|
||||||
|
name="MRotaryEmbedding.forward_native",
|
||||||
|
rope_class=MRotaryEmbedding,
|
||||||
|
rope_kwargs={**common_kwargs, "mrope_section": [16, 24, 24]},
|
||||||
|
method_name="forward_native",
|
||||||
|
positions_shape=(3, 32), # 2D for multimodal
|
||||||
|
expect_forward_native=True,
|
||||||
|
expect_forward=False,
|
||||||
|
),
|
||||||
|
RotaryEmbeddingTestCase(
|
||||||
|
name="MRotaryEmbedding.forward_cuda_1d",
|
||||||
|
rope_class=MRotaryEmbedding,
|
||||||
|
rope_kwargs={**common_kwargs, "mrope_section": [16, 24, 24]},
|
||||||
|
method_name="forward_cuda",
|
||||||
|
positions_shape=(32,), # 1D triggers apply_rotary_emb path
|
||||||
|
expect_forward_native=False,
|
||||||
|
expect_forward=True,
|
||||||
|
),
|
||||||
|
# XDRotaryEmbedding tests
|
||||||
|
RotaryEmbeddingTestCase(
|
||||||
|
name="XDRotaryEmbedding.forward",
|
||||||
|
rope_class=XDRotaryEmbedding,
|
||||||
|
rope_kwargs={
|
||||||
|
**common_kwargs,
|
||||||
|
"scaling_alpha": 1.0,
|
||||||
|
"xdrope_section": [16, 16, 16, 16],
|
||||||
|
},
|
||||||
|
method_name="forward",
|
||||||
|
positions_shape=(4, 32), # 4D for P/W/H/T
|
||||||
|
expect_forward_native=False,
|
||||||
|
expect_forward=True,
|
||||||
|
),
|
||||||
|
# Ernie4_5_VLRotaryEmbedding tests
|
||||||
|
RotaryEmbeddingTestCase(
|
||||||
|
name="Ernie4_5_VLRotaryEmbedding.forward_native",
|
||||||
|
rope_class=Ernie4_5_VLRotaryEmbedding,
|
||||||
|
rope_kwargs={**common_kwargs, "mrope_section": [22, 22, 20]},
|
||||||
|
method_name="forward_native",
|
||||||
|
positions_shape=(3, 32), # 2D for multimodal
|
||||||
|
expect_forward_native=True,
|
||||||
|
expect_forward=False,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def run_dispatch_test(
|
||||||
|
test_case: RotaryEmbeddingTestCase,
|
||||||
|
device: str,
|
||||||
|
):
|
||||||
|
"""Run a dispatch test for a RotaryEmbedding class."""
|
||||||
|
vllm_config = VllmConfig(
|
||||||
|
compilation_config=CompilationConfig(custom_ops=["all", "+apply_rotary_emb"])
|
||||||
|
)
|
||||||
|
get_cached_compilation_config.cache_clear()
|
||||||
|
|
||||||
|
with set_current_vllm_config(vllm_config):
|
||||||
|
rope = test_case.rope_class(**test_case.rope_kwargs).to(device=device)
|
||||||
|
|
||||||
|
apply_rotary_emb = rope.apply_rotary_emb
|
||||||
|
|
||||||
|
# Verify custom op is enabled
|
||||||
|
if test_case.expect_forward_native:
|
||||||
|
assert (
|
||||||
|
apply_rotary_emb._forward_method != apply_rotary_emb.forward_native
|
||||||
|
), "Test setup error: ApplyRotaryEmb custom op should be enabled"
|
||||||
|
|
||||||
|
# Setup call tracking
|
||||||
|
call_tracker = {"forward_native_called": False, "forward_called": False}
|
||||||
|
original_forward_native = apply_rotary_emb.forward_native
|
||||||
|
original_forward = apply_rotary_emb.forward
|
||||||
|
|
||||||
|
def tracked_forward_native(*args, **kwargs):
|
||||||
|
call_tracker["forward_native_called"] = True
|
||||||
|
return original_forward_native(*args, **kwargs)
|
||||||
|
|
||||||
|
def tracked_forward(*args, **kwargs):
|
||||||
|
call_tracker["forward_called"] = True
|
||||||
|
return original_forward(*args, **kwargs)
|
||||||
|
|
||||||
|
apply_rotary_emb.forward_native = tracked_forward_native
|
||||||
|
apply_rotary_emb.forward = tracked_forward
|
||||||
|
|
||||||
|
try:
|
||||||
|
num_tokens = test_case.positions_shape[-1]
|
||||||
|
num_q_heads = 8
|
||||||
|
num_kv_heads = 2
|
||||||
|
head_size = test_case.rope_kwargs["head_size"]
|
||||||
|
max_position = test_case.rope_kwargs["max_position_embeddings"]
|
||||||
|
|
||||||
|
positions = torch.randint(
|
||||||
|
0, max_position // 4, test_case.positions_shape, device=device
|
||||||
|
)
|
||||||
|
query = torch.randn(
|
||||||
|
num_tokens, num_q_heads * head_size, dtype=torch.bfloat16, device=device
|
||||||
|
)
|
||||||
|
key = torch.randn(
|
||||||
|
num_tokens,
|
||||||
|
num_kv_heads * head_size,
|
||||||
|
dtype=torch.bfloat16,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call the method under test
|
||||||
|
method = getattr(rope, test_case.method_name)
|
||||||
|
method(positions, query.clone(), key.clone())
|
||||||
|
|
||||||
|
# Verify expectations
|
||||||
|
if test_case.expect_forward_native:
|
||||||
|
assert call_tracker["forward_native_called"], (
|
||||||
|
f"{test_case.name} should call ApplyRotaryEmb.forward_native()"
|
||||||
|
)
|
||||||
|
if not test_case.expect_forward:
|
||||||
|
assert not call_tracker["forward_called"], (
|
||||||
|
f"{test_case.name} should NOT call ApplyRotaryEmb.forward(). "
|
||||||
|
"Bug: when +apply_rotary_emb is enabled, forward_native() "
|
||||||
|
"incorrectly dispatches to CUDA/HIP kernels."
|
||||||
|
)
|
||||||
|
if test_case.expect_forward:
|
||||||
|
assert call_tracker["forward_called"], (
|
||||||
|
f"{test_case.name} should call ApplyRotaryEmb.forward()"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
apply_rotary_emb.forward_native = original_forward_native
|
||||||
|
apply_rotary_emb.forward = original_forward
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not current_platform.is_cuda_alike(), reason="Skipping CUDA/ROCm only tests."
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize("test_case", get_test_cases(), ids=lambda tc: tc.name)
|
||||||
|
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||||
|
def test_rotary_embedding_dispatch(
|
||||||
|
test_case: RotaryEmbeddingTestCase,
|
||||||
|
device: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test that RotaryEmbedding classes dispatch to the correct ApplyRotaryEmb method.
|
||||||
|
|
||||||
|
- forward_native methods should call ApplyRotaryEmb.forward_native()
|
||||||
|
- forward_cuda/forward methods should call ApplyRotaryEmb.forward()
|
||||||
|
"""
|
||||||
|
run_dispatch_test(test_case, device)
|
||||||
@@ -1,17 +1,23 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from typing import Literal, NamedTuple
|
import os
|
||||||
|
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||||
|
|
||||||
|
from typing import Any, NamedTuple
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from pytest import MarkDecorator
|
from pytest import MarkDecorator
|
||||||
|
from transformers import AutoModelForImageTextToText
|
||||||
|
|
||||||
from tests.quantization.utils import is_quant_method_supported
|
from tests.quantization.utils import is_quant_method_supported
|
||||||
from vllm.assets.image import ImageAsset
|
from vllm.assets.image import ImageAsset
|
||||||
|
from vllm.multimodal.image import rescale_image_size
|
||||||
from vllm.utils.torch_utils import set_default_torch_num_threads
|
from vllm.utils.torch_utils import set_default_torch_num_threads
|
||||||
|
|
||||||
from ....conftest import PromptImageInput, VllmRunner
|
from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner
|
||||||
from ...utils import check_logprobs_close
|
from ...utils import check_logprobs_close
|
||||||
|
|
||||||
|
|
||||||
@@ -21,9 +27,10 @@ class GGUFMMTestConfig(NamedTuple):
|
|||||||
gguf_backbone: str
|
gguf_backbone: str
|
||||||
gguf_mmproj: str
|
gguf_mmproj: str
|
||||||
prompt: list[str]
|
prompt: list[str]
|
||||||
mm_data: dict[Literal["images"], PromptImageInput]
|
image_names: list[str] # Store names, load PIL images at runtime
|
||||||
max_model_len: int = 4096
|
max_model_len: int = 4096
|
||||||
marks: list[MarkDecorator] = []
|
marks: list[MarkDecorator] = []
|
||||||
|
mm_processor_kwargs: dict[str, Any] = {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def gguf_model(self):
|
def gguf_model(self):
|
||||||
@@ -31,27 +38,75 @@ class GGUFMMTestConfig(NamedTuple):
|
|||||||
return hf_hub_download(self.gguf_repo, filename=self.gguf_backbone)
|
return hf_hub_download(self.gguf_repo, filename=self.gguf_backbone)
|
||||||
|
|
||||||
|
|
||||||
|
# Common prompts aligned with test_common.py "gemma3" entry format
|
||||||
|
_GEMMA3_PROMPTS = IMAGE_ASSETS.prompts(
|
||||||
|
{
|
||||||
|
"stop_sign": (
|
||||||
|
"<bos><start_of_turn>user\n"
|
||||||
|
"<start_of_image>What's the content in the center of the image?"
|
||||||
|
"<end_of_turn>\n<start_of_turn>model\n"
|
||||||
|
),
|
||||||
|
"cherry_blossom": (
|
||||||
|
"<bos><start_of_turn>user\n"
|
||||||
|
"<start_of_image>What is the season?"
|
||||||
|
"<end_of_turn>\n<start_of_turn>model\n"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Image asset names - load at runtime to avoid pickle issues with subprocess
|
||||||
|
_GEMMA3_IMAGE_NAMES = ["stop_sign", "cherry_blossom"]
|
||||||
|
|
||||||
|
# Regular multimodal (no pan-and-scan) - uses QAT Q4_0 GGUF
|
||||||
GEMMA3_CONFIG = GGUFMMTestConfig(
|
GEMMA3_CONFIG = GGUFMMTestConfig(
|
||||||
original_model="google/gemma-3-4b-it",
|
original_model="google/gemma-3-4b-it",
|
||||||
gguf_repo="google/gemma-3-4b-it-qat-q4_0-gguf",
|
gguf_repo="google/gemma-3-4b-it-qat-q4_0-gguf",
|
||||||
gguf_backbone="gemma-3-4b-it-q4_0.gguf",
|
gguf_backbone="gemma-3-4b-it-q4_0.gguf",
|
||||||
gguf_mmproj="mmproj-model-f16-4B.gguf",
|
gguf_mmproj="mmproj-model-f16-4B.gguf",
|
||||||
prompt=["<start_of_image>Describe this image in detail:"],
|
prompt=_GEMMA3_PROMPTS,
|
||||||
mm_data={"images": [ImageAsset("stop_sign").pil_image]},
|
image_names=_GEMMA3_IMAGE_NAMES,
|
||||||
|
max_model_len=4096,
|
||||||
marks=[pytest.mark.core_model],
|
marks=[pytest.mark.core_model],
|
||||||
|
mm_processor_kwargs={},
|
||||||
)
|
)
|
||||||
|
|
||||||
MODELS_TO_TEST = [GEMMA3_CONFIG]
|
# Pan-and-scan multimodal - uses unquantized BF16 GGUF
|
||||||
|
GEMMA3_CONFIG_PAN_AND_SCAN = GGUFMMTestConfig(
|
||||||
|
original_model="google/gemma-3-4b-it",
|
||||||
|
gguf_repo="unsloth/gemma-3-4b-it-GGUF",
|
||||||
|
gguf_backbone="gemma-3-4b-it-BF16.gguf",
|
||||||
|
gguf_mmproj="mmproj-BF16.gguf",
|
||||||
|
prompt=_GEMMA3_PROMPTS,
|
||||||
|
image_names=_GEMMA3_IMAGE_NAMES,
|
||||||
|
max_model_len=4096,
|
||||||
|
marks=[pytest.mark.core_model],
|
||||||
|
mm_processor_kwargs={"do_pan_and_scan": True},
|
||||||
|
)
|
||||||
|
|
||||||
|
MODELS_TO_TEST = [GEMMA3_CONFIG, GEMMA3_CONFIG_PAN_AND_SCAN]
|
||||||
|
|
||||||
|
|
||||||
def run_multimodal_gguf_test(
|
def run_multimodal_gguf_test(
|
||||||
|
hf_runner: type[HfRunner],
|
||||||
vllm_runner: type[VllmRunner],
|
vllm_runner: type[VllmRunner],
|
||||||
model: GGUFMMTestConfig,
|
model: GGUFMMTestConfig,
|
||||||
dtype: str,
|
dtype: str,
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
num_logprobs: int,
|
num_logprobs: int,
|
||||||
):
|
):
|
||||||
# Run gguf model.
|
# Load images at runtime (inside subprocess) to avoid pickle issues
|
||||||
|
images = [ImageAsset(name).pil_image for name in model.image_names]
|
||||||
|
size_factors = [0.25, 0.5, 1.0]
|
||||||
|
inputs_per_image = [
|
||||||
|
(
|
||||||
|
[prompt for _ in size_factors],
|
||||||
|
[rescale_image_size(image, factor) for factor in size_factors],
|
||||||
|
)
|
||||||
|
for image, prompt in zip(images, model.prompt)
|
||||||
|
]
|
||||||
|
|
||||||
|
# NOTE: Run vLLM first to avoid CUDA init issues with multiprocessing fork.
|
||||||
|
# Run GGUF model via vLLM.
|
||||||
with (
|
with (
|
||||||
set_default_torch_num_threads(1),
|
set_default_torch_num_threads(1),
|
||||||
vllm_runner(
|
vllm_runner(
|
||||||
@@ -60,33 +115,40 @@ def run_multimodal_gguf_test(
|
|||||||
tokenizer_name=model.original_model,
|
tokenizer_name=model.original_model,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
max_model_len=model.max_model_len,
|
max_model_len=model.max_model_len,
|
||||||
|
mm_processor_kwargs=model.mm_processor_kwargs,
|
||||||
) as gguf_model,
|
) as gguf_model,
|
||||||
):
|
):
|
||||||
gguf_outputs = gguf_model.generate_greedy_logprobs(
|
gguf_outputs_per_case = [
|
||||||
prompts=model.prompt,
|
gguf_model.generate_greedy_logprobs(
|
||||||
max_tokens=max_tokens,
|
prompts,
|
||||||
|
max_tokens,
|
||||||
num_logprobs=num_logprobs,
|
num_logprobs=num_logprobs,
|
||||||
**model.mm_data,
|
images=images,
|
||||||
)
|
)
|
||||||
|
for prompts, images in inputs_per_image
|
||||||
|
]
|
||||||
|
|
||||||
# Run unquantized model.
|
# Then run HfRunner for HuggingFace baseline comparison.
|
||||||
with vllm_runner(
|
with hf_runner(
|
||||||
model_name=model.original_model,
|
model.original_model,
|
||||||
enforce_eager=True, # faster tests
|
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
max_model_len=model.max_model_len,
|
auto_cls=AutoModelForImageTextToText,
|
||||||
) as original_model:
|
) as hf_model:
|
||||||
original_outputs = original_model.generate_greedy_logprobs(
|
hf_outputs_per_case = [
|
||||||
prompts=model.prompt,
|
hf_model.generate_greedy_logprobs_limit(
|
||||||
max_tokens=max_tokens,
|
prompts,
|
||||||
|
max_tokens,
|
||||||
num_logprobs=num_logprobs,
|
num_logprobs=num_logprobs,
|
||||||
**model.mm_data,
|
images=images,
|
||||||
)
|
)
|
||||||
|
for prompts, images in inputs_per_image
|
||||||
|
]
|
||||||
|
|
||||||
|
for hf_outputs, gguf_outputs in zip(hf_outputs_per_case, gguf_outputs_per_case):
|
||||||
check_logprobs_close(
|
check_logprobs_close(
|
||||||
outputs_0_lst=original_outputs,
|
outputs_0_lst=hf_outputs,
|
||||||
outputs_1_lst=gguf_outputs,
|
outputs_1_lst=gguf_outputs,
|
||||||
name_0="original",
|
name_0="hf",
|
||||||
name_1="gguf",
|
name_1="gguf",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -105,11 +167,14 @@ def run_multimodal_gguf_test(
|
|||||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||||
@pytest.mark.parametrize("max_tokens", [32])
|
@pytest.mark.parametrize("max_tokens", [32])
|
||||||
@pytest.mark.parametrize("num_logprobs", [10])
|
@pytest.mark.parametrize("num_logprobs", [10])
|
||||||
def test_models(
|
def test_gemma3_mm_gguf(
|
||||||
|
hf_runner: type[HfRunner],
|
||||||
vllm_runner: type[VllmRunner],
|
vllm_runner: type[VllmRunner],
|
||||||
model: GGUFMMTestConfig,
|
model: GGUFMMTestConfig,
|
||||||
dtype: str,
|
dtype: str,
|
||||||
max_tokens: int,
|
max_tokens: int,
|
||||||
num_logprobs: int,
|
num_logprobs: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
run_multimodal_gguf_test(vllm_runner, model, dtype, max_tokens, num_logprobs)
|
run_multimodal_gguf_test(
|
||||||
|
hf_runner, vllm_runner, model, dtype, max_tokens, num_logprobs
|
||||||
|
)
|
||||||
|
|||||||
@@ -388,6 +388,7 @@ def run_video_test(config, mm_encoder_attn_backend, video_assets, vllm_runner):
|
|||||||
"mm_encoder_attn_backend",
|
"mm_encoder_attn_backend",
|
||||||
[None] + current_platform.get_supported_vit_attn_backends(),
|
[None] + current_platform.get_supported_vit_attn_backends(),
|
||||||
)
|
)
|
||||||
|
@pytest.mark.skip(reason="Broken test due to memory segmentation fault")
|
||||||
@create_new_process_for_each_test()
|
@create_new_process_for_each_test()
|
||||||
def test_vit_backend_functionality(
|
def test_vit_backend_functionality(
|
||||||
model_key: str,
|
model_key: str,
|
||||||
|
|||||||
@@ -60,12 +60,12 @@ def test_profiling(model_id: str, max_model_len: int):
|
|||||||
total_num_patches.item() + num_tiles.item() + 3
|
total_num_patches.item() + num_tiles.item() + 3
|
||||||
) # image start, image, image end
|
) # image start, image, image end
|
||||||
|
|
||||||
profiled_tokens = profiler.get_mm_max_contiguous_tokens(
|
profiled_tokens = profiler.get_mm_max_tokens(
|
||||||
max_model_len,
|
max_model_len,
|
||||||
mm_counts=mm_counts,
|
mm_counts=mm_counts,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert total_tokens == profiled_tokens["image"]
|
assert total_num_patches == profiled_tokens["image"]
|
||||||
assert total_tokens == sum(
|
assert total_tokens == sum(
|
||||||
placeholder.length
|
placeholder.length
|
||||||
for placeholder in decoder_dummy_data.multi_modal_placeholders["image"]
|
for placeholder in decoder_dummy_data.multi_modal_placeholders["image"]
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from tempfile import NamedTemporaryFile, TemporaryDirectory
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch
|
||||||
from PIL import Image, ImageChops
|
from PIL import Image, ImageChops
|
||||||
|
|
||||||
from vllm.multimodal.image import convert_image_mode
|
from vllm.multimodal.image import convert_image_mode
|
||||||
@@ -410,6 +411,97 @@ def test_argsort_mm_positions(case):
|
|||||||
assert modality_idxs == expected_modality_idxs
|
assert modality_idxs == expected_modality_idxs
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"is_embed,expected",
|
||||||
|
[
|
||||||
|
(None, 5),
|
||||||
|
(torch.tensor([True, True, True, True, True]), 5),
|
||||||
|
(torch.tensor([False, False, False, False, False]), 0),
|
||||||
|
(torch.tensor([True, False, True, False, True]), 3),
|
||||||
|
(torch.tensor([True]), 1),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_placeholder_range_get_num_embeds(is_embed, expected):
|
||||||
|
length = len(is_embed) if is_embed is not None else 5
|
||||||
|
pr = PlaceholderRange(offset=0, length=length, is_embed=is_embed)
|
||||||
|
assert pr.get_num_embeds == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"is_embed,expected",
|
||||||
|
[
|
||||||
|
(None, None),
|
||||||
|
(
|
||||||
|
torch.tensor([False, True, False, True, True]),
|
||||||
|
torch.tensor([0, 1, 1, 2, 3]),
|
||||||
|
),
|
||||||
|
(torch.tensor([True, True, True]), torch.tensor([1, 2, 3])),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_placeholder_range_embeds_cumsum(is_embed, expected):
|
||||||
|
length = len(is_embed) if is_embed is not None else 5
|
||||||
|
pr = PlaceholderRange(offset=0, length=length, is_embed=is_embed)
|
||||||
|
|
||||||
|
if expected is None:
|
||||||
|
assert pr.embeds_cumsum is None
|
||||||
|
return
|
||||||
|
|
||||||
|
assert torch.equal(pr.embeds_cumsum, expected)
|
||||||
|
# cached_property should return the same object on repeated access
|
||||||
|
assert pr.embeds_cumsum is pr.embeds_cumsum
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"is_embed,start_idx,end_idx,expected",
|
||||||
|
[
|
||||||
|
(None, 2, 4, (2, 4)),
|
||||||
|
(
|
||||||
|
torch.tensor([False, True, False, True, True]),
|
||||||
|
3,
|
||||||
|
5,
|
||||||
|
(1, 3),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
torch.tensor([False, True, False, True, True]),
|
||||||
|
0,
|
||||||
|
2,
|
||||||
|
(0, 1),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
torch.tensor([True, False, True, False]),
|
||||||
|
2,
|
||||||
|
2,
|
||||||
|
(1, 1),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_placeholder_range_get_embeds_indices_in_range(
|
||||||
|
is_embed, start_idx, end_idx, expected
|
||||||
|
):
|
||||||
|
length = len(is_embed) if is_embed is not None else 5
|
||||||
|
pr = PlaceholderRange(offset=0, length=length, is_embed=is_embed)
|
||||||
|
assert pr.get_embeds_indices_in_range(start_idx, end_idx) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"offset,is_embed,expected",
|
||||||
|
[
|
||||||
|
(0, None, [(0, 4)]),
|
||||||
|
(
|
||||||
|
2,
|
||||||
|
torch.tensor([False, True, False, True, True]),
|
||||||
|
[(3, 3), (5, 6)],
|
||||||
|
),
|
||||||
|
(0, torch.tensor([True, True, True, True]), [(0, 3)]),
|
||||||
|
(0, torch.tensor([False, False, False, False]), []),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_placeholder_range_extract_embeds_range(offset, is_embed, expected):
|
||||||
|
length = len(is_embed) if is_embed is not None else 5
|
||||||
|
pr = PlaceholderRange(offset=offset, length=length, is_embed=is_embed)
|
||||||
|
assert pr.extract_embeds_range() == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
|
@pytest.mark.parametrize("video_url", TEST_VIDEO_URLS)
|
||||||
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
|
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange
|
from vllm.multimodal.inputs import MultiModalFeatureSpec, PlaceholderRange
|
||||||
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
|
from vllm.v1.core.encoder_cache_manager import EncoderCacheManager
|
||||||
@@ -23,7 +24,7 @@ class MockRequest:
|
|||||||
)
|
)
|
||||||
self.mm_features.append(feature)
|
self.mm_features.append(feature)
|
||||||
|
|
||||||
def get_num_encoder_tokens(self, input_id: int) -> int:
|
def get_num_encoder_embeds(self, input_id: int) -> int:
|
||||||
return self._token_counts[input_id]
|
return self._token_counts[input_id]
|
||||||
|
|
||||||
|
|
||||||
@@ -162,8 +163,8 @@ def test_schedule_request_multi_images_respect_space_limit():
|
|||||||
|
|
||||||
num_tokens_to_schedule = 0
|
num_tokens_to_schedule = 0
|
||||||
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
|
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
|
||||||
num_tokens_to_schedule += req.get_num_encoder_tokens(0)
|
num_tokens_to_schedule += req.get_num_encoder_embeds(0)
|
||||||
compute_budget -= req.get_num_encoder_tokens(0)
|
compute_budget -= req.get_num_encoder_embeds(0)
|
||||||
|
|
||||||
assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule)
|
assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule)
|
||||||
|
|
||||||
@@ -174,7 +175,75 @@ def test_schedule_request_multi_images_respect_compute_limit():
|
|||||||
compute_budget = 10
|
compute_budget = 10
|
||||||
num_tokens_to_schedule = 0
|
num_tokens_to_schedule = 0
|
||||||
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
|
assert manager.can_allocate(req, 0, compute_budget, num_tokens_to_schedule)
|
||||||
num_tokens_to_schedule += req.get_num_encoder_tokens(0)
|
num_tokens_to_schedule += req.get_num_encoder_embeds(0)
|
||||||
compute_budget -= req.get_num_encoder_tokens(0)
|
compute_budget -= req.get_num_encoder_embeds(0)
|
||||||
|
|
||||||
assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule)
|
assert not manager.can_allocate(req, 1, compute_budget, num_tokens_to_schedule)
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoder_cache_with_is_embed_mask():
|
||||||
|
class MockRequestWithMask(MockRequest):
|
||||||
|
def get_num_encoder_embeds(self, input_id: int) -> int:
|
||||||
|
return self.mm_features[input_id].mm_position.get_num_embeds
|
||||||
|
|
||||||
|
is_embed = torch.zeros(100, dtype=torch.bool)
|
||||||
|
is_embed[torch.tensor([5, 15, 25, 35, 45, 55, 65, 75])] = True
|
||||||
|
|
||||||
|
request = MockRequestWithMask("r1", ["img1"], [100])
|
||||||
|
request.mm_features[0] = MultiModalFeatureSpec(
|
||||||
|
data=None,
|
||||||
|
modality="image",
|
||||||
|
identifier="img1",
|
||||||
|
mm_position=PlaceholderRange(offset=0, length=100, is_embed=is_embed),
|
||||||
|
)
|
||||||
|
|
||||||
|
manager = EncoderCacheManager(cache_size=100)
|
||||||
|
manager.allocate(request, 0)
|
||||||
|
|
||||||
|
assert manager.num_free_slots == 92
|
||||||
|
assert "img1" in manager.cached
|
||||||
|
|
||||||
|
old_size = 100
|
||||||
|
new_size = request.mm_features[0].mm_position.get_num_embeds
|
||||||
|
assert new_size == 8
|
||||||
|
savings_ratio = old_size / new_size
|
||||||
|
assert savings_ratio == 12.5
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoder_cache_mask_based_retrieval():
|
||||||
|
class MockRequestWithMask(MockRequest):
|
||||||
|
def get_num_encoder_embeds(self, input_id: int) -> int:
|
||||||
|
return self.mm_features[input_id].mm_position.get_num_embeds
|
||||||
|
|
||||||
|
is_embed = torch.tensor(
|
||||||
|
[False, False, True, True, False, True, True, True, False, False]
|
||||||
|
)
|
||||||
|
|
||||||
|
request = MockRequestWithMask("r1", ["img1"], [10])
|
||||||
|
request.mm_features[0] = MultiModalFeatureSpec(
|
||||||
|
data=None,
|
||||||
|
modality="image",
|
||||||
|
identifier="img1",
|
||||||
|
mm_position=PlaceholderRange(offset=0, length=10, is_embed=is_embed),
|
||||||
|
)
|
||||||
|
|
||||||
|
manager = EncoderCacheManager(cache_size=50)
|
||||||
|
manager.allocate(request, 0)
|
||||||
|
|
||||||
|
assert request.mm_features[0].mm_position.get_num_embeds == 5
|
||||||
|
|
||||||
|
start_idx = 2
|
||||||
|
end_idx = 8
|
||||||
|
num_embeds_before = is_embed[:start_idx].sum().item()
|
||||||
|
num_embeds_in_range = is_embed[start_idx:end_idx].sum().item()
|
||||||
|
|
||||||
|
assert num_embeds_before == 0
|
||||||
|
assert num_embeds_in_range == 5
|
||||||
|
|
||||||
|
start_idx = 0
|
||||||
|
end_idx = 5
|
||||||
|
num_embeds_before = is_embed[:start_idx].sum().item() if start_idx > 0 else 0
|
||||||
|
num_embeds_in_range = is_embed[start_idx:end_idx].sum().item()
|
||||||
|
|
||||||
|
assert num_embeds_before == 0
|
||||||
|
assert num_embeds_in_range == 2
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ class MockRequest:
|
|||||||
)
|
)
|
||||||
self.mm_features.append(feature)
|
self.mm_features.append(feature)
|
||||||
|
|
||||||
def get_num_encoder_tokens(self, input_id: int) -> int:
|
def get_num_encoder_embeds(self, input_id: int) -> int:
|
||||||
assert input_id < len(self._token_counts)
|
assert input_id < len(self._token_counts)
|
||||||
return self._token_counts[input_id]
|
return self._token_counts[input_id]
|
||||||
|
|
||||||
|
|||||||
@@ -86,6 +86,9 @@ def kernel_unified_attention_2d(
|
|||||||
USE_SOFTCAP: tl.constexpr, # bool
|
USE_SOFTCAP: tl.constexpr, # bool
|
||||||
USE_SINKS: tl.constexpr, # bool
|
USE_SINKS: tl.constexpr, # bool
|
||||||
SLIDING_WINDOW: tl.constexpr, # int
|
SLIDING_WINDOW: tl.constexpr, # int
|
||||||
|
USE_MM_PREFIX: tl.constexpr, # bool
|
||||||
|
MAX_MM_RANGES: tl.constexpr, # int
|
||||||
|
mm_prefix_range_ptr, # [num_seqs, MAX_MM_RANGES, 2] - (start, end) bidirectional ranges per sequence
|
||||||
stride_k_cache_0: tl.int64, # int
|
stride_k_cache_0: tl.int64, # int
|
||||||
stride_k_cache_1: tl.int64, # int
|
stride_k_cache_1: tl.int64, # int
|
||||||
stride_k_cache_2: tl.int64, # int
|
stride_k_cache_2: tl.int64, # int
|
||||||
@@ -270,7 +273,38 @@ def kernel_unified_attention_2d(
|
|||||||
else:
|
else:
|
||||||
V = V_load
|
V = V_load
|
||||||
|
|
||||||
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
|
# Compute attention mask: causal by default (key <= query)
|
||||||
|
query_abs_pos = context_len + query_pos[:, None]
|
||||||
|
seq_mask = seq_offset[None, :] <= query_abs_pos
|
||||||
|
|
||||||
|
# Apply sliding window to base mask BEFORE mm_prefix OR.
|
||||||
|
# Order must match FlexAttention: (causal AND sliding_window) OR mm_prefix
|
||||||
|
if SLIDING_WINDOW > 0:
|
||||||
|
seq_mask = seq_mask & ((query_abs_pos - seq_offset) < SLIDING_WINDOW)
|
||||||
|
|
||||||
|
# PrefixLM: extend mask with bidirectional ranges for multimodal tokens.
|
||||||
|
# Applied AFTER sliding window so mm_prefix ranges override SW restriction.
|
||||||
|
if USE_MM_PREFIX:
|
||||||
|
for i in range(MAX_MM_RANGES):
|
||||||
|
range_start = tl.load(
|
||||||
|
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2
|
||||||
|
)
|
||||||
|
range_end = tl.load(
|
||||||
|
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2 + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
is_valid = range_start < range_end
|
||||||
|
q_in_range = (
|
||||||
|
(query_abs_pos >= range_start)
|
||||||
|
& (query_abs_pos <= range_end)
|
||||||
|
& is_valid
|
||||||
|
)
|
||||||
|
k_in_range = (
|
||||||
|
(seq_offset[None, :] >= range_start)
|
||||||
|
& (seq_offset[None, :] <= range_end)
|
||||||
|
& is_valid
|
||||||
|
)
|
||||||
|
seq_mask |= q_in_range & k_in_range
|
||||||
|
|
||||||
# S : (BLOCK_M, TILE_SIZE)
|
# S : (BLOCK_M, TILE_SIZE)
|
||||||
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
|
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
|
||||||
@@ -284,13 +318,6 @@ def kernel_unified_attention_2d(
|
|||||||
query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
|
query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
|
||||||
)
|
)
|
||||||
|
|
||||||
if SLIDING_WINDOW > 0:
|
|
||||||
S = tl.where(
|
|
||||||
(context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW,
|
|
||||||
S,
|
|
||||||
float("-inf"),
|
|
||||||
)
|
|
||||||
|
|
||||||
if USE_ALIBI_SLOPES:
|
if USE_ALIBI_SLOPES:
|
||||||
S += alibi_slope[:, None] * (seq_offset - context_len)
|
S += alibi_slope[:, None] * (seq_offset - context_len)
|
||||||
|
|
||||||
@@ -398,6 +425,9 @@ def kernel_unified_attention_3d(
|
|||||||
num_seqs: tl.int32,
|
num_seqs: tl.int32,
|
||||||
BLOCK_M: tl.constexpr, # int
|
BLOCK_M: tl.constexpr, # int
|
||||||
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int
|
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int
|
||||||
|
USE_MM_PREFIX: tl.constexpr, # bool
|
||||||
|
MAX_MM_RANGES: tl.constexpr, # int
|
||||||
|
mm_prefix_range_ptr, # [num_seqs, MAX_MM_RANGES, 2] - (start, end) bidirectional ranges per sequence
|
||||||
):
|
):
|
||||||
q_block_global_idx = tl.program_id(0)
|
q_block_global_idx = tl.program_id(0)
|
||||||
kv_head_idx = tl.program_id(1)
|
kv_head_idx = tl.program_id(1)
|
||||||
@@ -559,7 +589,38 @@ def kernel_unified_attention_3d(
|
|||||||
else:
|
else:
|
||||||
V = V_load
|
V = V_load
|
||||||
|
|
||||||
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
|
# Compute attention mask: causal by default (key <= query)
|
||||||
|
query_abs_pos = context_len + query_pos[:, None]
|
||||||
|
seq_mask = seq_offset[None, :] <= query_abs_pos
|
||||||
|
|
||||||
|
# Apply sliding window to base mask BEFORE mm_prefix OR.
|
||||||
|
# Order must match FlexAttention: (causal AND sliding_window) OR mm_prefix
|
||||||
|
if SLIDING_WINDOW > 0:
|
||||||
|
seq_mask = seq_mask & ((query_abs_pos - seq_offset) < SLIDING_WINDOW)
|
||||||
|
|
||||||
|
# PrefixLM: extend mask with bidirectional ranges for multimodal tokens.
|
||||||
|
# Applied AFTER sliding window so mm_prefix ranges override SW restriction.
|
||||||
|
if USE_MM_PREFIX:
|
||||||
|
for i in range(MAX_MM_RANGES):
|
||||||
|
range_start = tl.load(
|
||||||
|
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2
|
||||||
|
)
|
||||||
|
range_end = tl.load(
|
||||||
|
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2 + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
is_valid = range_start < range_end
|
||||||
|
q_in_range = (
|
||||||
|
(query_abs_pos >= range_start)
|
||||||
|
& (query_abs_pos <= range_end)
|
||||||
|
& is_valid
|
||||||
|
)
|
||||||
|
k_in_range = (
|
||||||
|
(seq_offset[None, :] >= range_start)
|
||||||
|
& (seq_offset[None, :] <= range_end)
|
||||||
|
& is_valid
|
||||||
|
)
|
||||||
|
seq_mask |= q_in_range & k_in_range
|
||||||
|
|
||||||
# S : (BLOCK_M, TILE_SIZE)
|
# S : (BLOCK_M, TILE_SIZE)
|
||||||
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
|
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
|
||||||
@@ -572,13 +633,6 @@ def kernel_unified_attention_3d(
|
|||||||
query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
|
query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
|
||||||
)
|
)
|
||||||
|
|
||||||
if SLIDING_WINDOW > 0:
|
|
||||||
S = tl.where(
|
|
||||||
(context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW,
|
|
||||||
S,
|
|
||||||
float("-inf"),
|
|
||||||
)
|
|
||||||
|
|
||||||
if USE_ALIBI_SLOPES:
|
if USE_ALIBI_SLOPES:
|
||||||
S += alibi_slope[:, None] * (seq_offset - context_len)
|
S += alibi_slope[:, None] * (seq_offset - context_len)
|
||||||
|
|
||||||
@@ -732,6 +786,43 @@ def reduce_segments(
|
|||||||
tl.store(output_ptr + output_offset, acc, mask=dim_mask)
|
tl.store(output_ptr + output_offset, acc, mask=dim_mask)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_gemma3_attention(head_size: int, sliding_window: int) -> bool:
|
||||||
|
"""Detect Gemma3 models via unique (head_size, sliding_window) signature.
|
||||||
|
|
||||||
|
Gemma3 models are the only ones using sliding_window=1024 with
|
||||||
|
head_size 128 (27B) or 256 (1B, 4B, 12B). Other SWA models use
|
||||||
|
different window sizes (Mistral=4096, Phi-3=2047).
|
||||||
|
"""
|
||||||
|
return sliding_window == 1024 and head_size in (128, 256)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_tile_size(
|
||||||
|
head_size: int,
|
||||||
|
sliding_window: int,
|
||||||
|
element_size: int,
|
||||||
|
is_mm_prefix: bool,
|
||||||
|
is_prefill: bool,
|
||||||
|
) -> int:
|
||||||
|
"""Select tile size with Gemma3-specific optimization.
|
||||||
|
|
||||||
|
For Gemma3, use 32 for both prefill and decode to better utilize
|
||||||
|
the larger head dimension (128/256). For other models, use
|
||||||
|
the default vLLM behavior.
|
||||||
|
"""
|
||||||
|
if is_mm_prefix:
|
||||||
|
# Multimodal bidirectional attention needs a larger tile size
|
||||||
|
return 64
|
||||||
|
|
||||||
|
if _is_gemma3_attention(head_size, sliding_window):
|
||||||
|
# Gemma3: use 32 for decode (default is 16)
|
||||||
|
return 32
|
||||||
|
|
||||||
|
# Default behavior
|
||||||
|
if is_prefill:
|
||||||
|
return 32
|
||||||
|
return 16 if element_size >= 2 else 32
|
||||||
|
|
||||||
|
|
||||||
def unified_attention(
|
def unified_attention(
|
||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
@@ -759,6 +850,8 @@ def unified_attention(
|
|||||||
qq_bias=None,
|
qq_bias=None,
|
||||||
# Optional tensor for sinks
|
# Optional tensor for sinks
|
||||||
sinks=None,
|
sinks=None,
|
||||||
|
# Optional tensor of per-sequence multimodal prefix ranges (PrefixLM support)
|
||||||
|
mm_prefix_range=None,
|
||||||
):
|
):
|
||||||
assert causal, "Only causal attention is supported"
|
assert causal, "Only causal attention is supported"
|
||||||
assert q_descale is None, "Q scales not supported"
|
assert q_descale is None, "Q scales not supported"
|
||||||
@@ -766,6 +859,17 @@ def unified_attention(
|
|||||||
if sinks is not None:
|
if sinks is not None:
|
||||||
assert sinks.shape[0] == q.shape[1], "Sinks must be num_query_heads size"
|
assert sinks.shape[0] == q.shape[1], "Sinks must be num_query_heads size"
|
||||||
|
|
||||||
|
use_mm_prefix = False
|
||||||
|
max_mm_ranges = 0
|
||||||
|
if mm_prefix_range is not None:
|
||||||
|
if mm_prefix_range.ndim == 3:
|
||||||
|
use_mm_prefix = True
|
||||||
|
max_mm_ranges = mm_prefix_range.shape[1]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported mm_prefix_range shape: {mm_prefix_range.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
use_alibi_slopes = alibi_slopes is not None
|
use_alibi_slopes = alibi_slopes is not None
|
||||||
use_qq_bias = qq_bias is not None
|
use_qq_bias = qq_bias is not None
|
||||||
|
|
||||||
@@ -792,11 +896,23 @@ def unified_attention(
|
|||||||
# = floor(q.shape[0] / BLOCK_Q) + num_seqs
|
# = floor(q.shape[0] / BLOCK_Q) + num_seqs
|
||||||
total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
|
total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
|
||||||
|
|
||||||
# Assigning default tile sizes for prefill and decode.
|
# Tile sizes for prefill and decode. Gemma3 models use optimized values.
|
||||||
# Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1)
|
# Note: tile size must be at least 32 for fp8 (element_size == 1).
|
||||||
# and at least 16 for all other data types.
|
sliding_window_val = 1 + window_size[0] if window_size[0] >= 0 else 0
|
||||||
TILE_SIZE_PREFILL = 32
|
TILE_SIZE_PREFILL = _get_tile_size(
|
||||||
TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32
|
head_size,
|
||||||
|
sliding_window_val,
|
||||||
|
q.element_size(),
|
||||||
|
is_mm_prefix=use_mm_prefix,
|
||||||
|
is_prefill=True,
|
||||||
|
)
|
||||||
|
TILE_SIZE_DECODE = _get_tile_size(
|
||||||
|
head_size,
|
||||||
|
sliding_window_val,
|
||||||
|
q.element_size(),
|
||||||
|
is_mm_prefix=use_mm_prefix,
|
||||||
|
is_prefill=False,
|
||||||
|
)
|
||||||
|
|
||||||
# Launch the 2D kernel if
|
# Launch the 2D kernel if
|
||||||
# 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or
|
# 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or
|
||||||
@@ -847,6 +963,9 @@ def unified_attention(
|
|||||||
USE_QQ_BIAS=use_qq_bias,
|
USE_QQ_BIAS=use_qq_bias,
|
||||||
USE_SOFTCAP=(softcap > 0),
|
USE_SOFTCAP=(softcap > 0),
|
||||||
USE_SINKS=(sinks is not None),
|
USE_SINKS=(sinks is not None),
|
||||||
|
USE_MM_PREFIX=use_mm_prefix,
|
||||||
|
MAX_MM_RANGES=max_mm_ranges,
|
||||||
|
mm_prefix_range_ptr=mm_prefix_range,
|
||||||
SLIDING_WINDOW=(1 + window_size[0]),
|
SLIDING_WINDOW=(1 + window_size[0]),
|
||||||
stride_k_cache_0=k.stride(0),
|
stride_k_cache_0=k.stride(0),
|
||||||
stride_k_cache_1=k.stride(1),
|
stride_k_cache_1=k.stride(1),
|
||||||
@@ -895,6 +1014,9 @@ def unified_attention(
|
|||||||
USE_QQ_BIAS=use_qq_bias,
|
USE_QQ_BIAS=use_qq_bias,
|
||||||
USE_SOFTCAP=(softcap > 0),
|
USE_SOFTCAP=(softcap > 0),
|
||||||
USE_SINKS=(sinks is not None),
|
USE_SINKS=(sinks is not None),
|
||||||
|
USE_MM_PREFIX=use_mm_prefix,
|
||||||
|
MAX_MM_RANGES=max_mm_ranges,
|
||||||
|
mm_prefix_range_ptr=mm_prefix_range,
|
||||||
SLIDING_WINDOW=(1 + window_size[0]),
|
SLIDING_WINDOW=(1 + window_size[0]),
|
||||||
stride_k_cache_0=k.stride(0),
|
stride_k_cache_0=k.stride(0),
|
||||||
stride_k_cache_1=k.stride(1),
|
stride_k_cache_1=k.stride(1),
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import einops
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.torch_utils import direct_register_custom_op
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
|
|
||||||
|
|
||||||
@@ -89,6 +90,13 @@ def torch_sdpa_wrapper(
|
|||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
cu_seqlens: torch.Tensor,
|
cu_seqlens: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
# Never remove the contiguous logic for ROCm
|
||||||
|
# Without it, hallucinations occur with the backend
|
||||||
|
if current_platform.is_rocm():
|
||||||
|
q = q.contiguous()
|
||||||
|
k = k.contiguous()
|
||||||
|
v = v.contiguous()
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
|
|
||||||
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||||
|
|||||||
@@ -144,7 +144,7 @@ class ECExampleConnector(ECConnectorBase):
|
|||||||
Update ECConnector state after encoder cache allocation.
|
Update ECConnector state after encoder cache allocation.
|
||||||
"""
|
"""
|
||||||
mm_hash = request.mm_features[index].identifier
|
mm_hash = request.mm_features[index].identifier
|
||||||
num_encoder_token = request.get_num_encoder_tokens(index)
|
num_encoder_token = request.get_num_encoder_embeds(index)
|
||||||
# Insert mm_hash only if this block has not been recorded yet.
|
# Insert mm_hash only if this block has not been recorded yet.
|
||||||
self._mm_datas_need_loads[mm_hash] = num_encoder_token
|
self._mm_datas_need_loads[mm_hash] = num_encoder_token
|
||||||
|
|
||||||
|
|||||||
@@ -795,7 +795,10 @@ class FusedMoEModularKernel(torch.nn.Module):
|
|||||||
top_k,
|
top_k,
|
||||||
global_num_experts,
|
global_num_experts,
|
||||||
local_num_experts,
|
local_num_experts,
|
||||||
expert_tokens_meta,
|
# expert_tokens_meta help in allocating optimal/minimal
|
||||||
|
# amount of workspace. Mark it None, so we allocate for
|
||||||
|
# the worst-case scenario.
|
||||||
|
expert_tokens_meta=None,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -27,6 +27,10 @@ from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
|
|||||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8LinearMethod
|
from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8LinearMethod
|
||||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||||
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
|
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
|
||||||
|
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||||
|
maybe_create_device_identity,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.parameter import ModelWeightParameter
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
@@ -305,6 +309,37 @@ class XPUFp8LinearMethod(Fp8LinearMethod):
|
|||||||
def __init__(self, quant_config: Fp8Config):
|
def __init__(self, quant_config: Fp8Config):
|
||||||
super().__init__(quant_config)
|
super().__init__(quant_config)
|
||||||
|
|
||||||
|
def create_weights(
|
||||||
|
self,
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
input_size_per_partition: int,
|
||||||
|
output_partition_sizes: list[int],
|
||||||
|
input_size: int,
|
||||||
|
output_size: int,
|
||||||
|
params_dtype: torch.dtype,
|
||||||
|
**extra_weight_attrs,
|
||||||
|
):
|
||||||
|
maybe_create_device_identity()
|
||||||
|
|
||||||
|
output_size_per_partition = sum(output_partition_sizes)
|
||||||
|
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||||
|
layer.logical_widths = output_partition_sizes
|
||||||
|
layer.input_size_per_partition = input_size_per_partition
|
||||||
|
layer.output_size_per_partition = output_size_per_partition
|
||||||
|
layer.orig_dtype = params_dtype
|
||||||
|
layer.weight_block_size = None
|
||||||
|
weight = ModelWeightParameter(
|
||||||
|
data=torch.empty(
|
||||||
|
output_size_per_partition,
|
||||||
|
input_size_per_partition,
|
||||||
|
dtype=params_dtype,
|
||||||
|
),
|
||||||
|
input_dim=1,
|
||||||
|
output_dim=0,
|
||||||
|
weight_loader=weight_loader,
|
||||||
|
)
|
||||||
|
layer.register_parameter("weight", weight)
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: Module) -> None:
|
def process_weights_after_loading(self, layer: Module) -> None:
|
||||||
# If checkpoint not serialized fp8, quantize the weights.
|
# If checkpoint not serialized fp8, quantize the weights.
|
||||||
if not self.quant_config.is_checkpoint_fp8_serialized:
|
if not self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import torch
|
|||||||
from vllm._aiter_ops import rocm_aiter_ops
|
from vllm._aiter_ops import rocm_aiter_ops
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
|
|
||||||
from .common import apply_rotary_emb_torch
|
from .common import ApplyRotaryEmb
|
||||||
|
|
||||||
|
|
||||||
@CustomOp.register("rotary_embedding")
|
@CustomOp.register("rotary_embedding")
|
||||||
@@ -49,6 +49,10 @@ class RotaryEmbeddingBase(CustomOp):
|
|||||||
rocm_aiter_ops.is_triton_rotary_embed_enabled()
|
rocm_aiter_ops.is_triton_rotary_embed_enabled()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.apply_rotary_emb = ApplyRotaryEmb(
|
||||||
|
is_neox_style=self.is_neox_style,
|
||||||
|
)
|
||||||
|
|
||||||
def _compute_inv_freq(self, base: float) -> torch.Tensor:
|
def _compute_inv_freq(self, base: float) -> torch.Tensor:
|
||||||
"""Compute the inverse frequency."""
|
"""Compute the inverse frequency."""
|
||||||
# NOTE(woosuk): To exactly match the HF implementation, we need to
|
# NOTE(woosuk): To exactly match the HF implementation, we need to
|
||||||
@@ -123,7 +127,12 @@ class RotaryEmbedding(RotaryEmbeddingBase):
|
|||||||
query = query.view(num_tokens, -1, head_size)
|
query = query.view(num_tokens, -1, head_size)
|
||||||
query_rot = query[..., :rotary_dim]
|
query_rot = query[..., :rotary_dim]
|
||||||
query_pass = query[..., rotary_dim:]
|
query_pass = query[..., rotary_dim:]
|
||||||
query_rot = apply_rotary_emb_torch(query_rot, cos, sin, is_neox_style)
|
query_rot = ApplyRotaryEmb.forward_static(
|
||||||
|
query_rot,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
is_neox_style,
|
||||||
|
)
|
||||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||||
|
|
||||||
# key may be None in some cases, e.g. cross-layer KV sharing
|
# key may be None in some cases, e.g. cross-layer KV sharing
|
||||||
@@ -132,7 +141,12 @@ class RotaryEmbedding(RotaryEmbeddingBase):
|
|||||||
key = key.view(num_tokens, -1, head_size)
|
key = key.view(num_tokens, -1, head_size)
|
||||||
key_rot = key[..., :rotary_dim]
|
key_rot = key[..., :rotary_dim]
|
||||||
key_pass = key[..., rotary_dim:]
|
key_pass = key[..., rotary_dim:]
|
||||||
key_rot = apply_rotary_emb_torch(key_rot, cos, sin, is_neox_style)
|
key_rot = ApplyRotaryEmb.forward_static(
|
||||||
|
key_rot,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
is_neox_style,
|
||||||
|
)
|
||||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
|
|||||||
@@ -2,19 +2,14 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from collections.abc import Callable
|
|
||||||
from functools import cache
|
|
||||||
from importlib.util import find_spec
|
from importlib.util import find_spec
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
from vllm.utils.torch_utils import direct_register_custom_op
|
from vllm.utils.torch_utils import direct_register_custom_op
|
||||||
|
|
||||||
if current_platform.is_cuda():
|
|
||||||
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -32,71 +27,6 @@ def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
|
|||||||
return x.flatten(-2)
|
return x.flatten(-2)
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb_torch(
|
|
||||||
x: torch.Tensor,
|
|
||||||
cos: torch.Tensor,
|
|
||||||
sin: torch.Tensor,
|
|
||||||
is_neox_style: bool,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
cos = cos.unsqueeze(-2).to(x.dtype)
|
|
||||||
sin = sin.unsqueeze(-2).to(x.dtype)
|
|
||||||
if is_neox_style:
|
|
||||||
x1, x2 = torch.chunk(x, 2, dim=-1)
|
|
||||||
else:
|
|
||||||
x1 = x[..., ::2]
|
|
||||||
x2 = x[..., 1::2]
|
|
||||||
o1 = x1 * cos - x2 * sin
|
|
||||||
o2 = x2 * cos + x1 * sin
|
|
||||||
if is_neox_style:
|
|
||||||
return torch.cat((o1, o2), dim=-1)
|
|
||||||
else:
|
|
||||||
return torch.stack((o1, o2), dim=-1).flatten(-2)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb_dispatch(
|
|
||||||
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, is_neox_style: bool
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
x: [num_tokens, num_heads, head_size]
|
|
||||||
cos: [num_tokens, head_size // 2]
|
|
||||||
sin: [num_tokens, head_size // 2]
|
|
||||||
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
|
|
||||||
positional embeddings.
|
|
||||||
"""
|
|
||||||
if current_platform.is_cuda():
|
|
||||||
return apply_rotary_emb(x.unsqueeze(0), cos, sin, not is_neox_style).squeeze(0)
|
|
||||||
else:
|
|
||||||
return apply_rotary_emb_torch(x, cos, sin, is_neox_style)
|
|
||||||
|
|
||||||
|
|
||||||
@cache
|
|
||||||
def dispatch_rotary_emb_function(
|
|
||||||
default: Callable[..., torch.Tensor] | None = None,
|
|
||||||
) -> Callable[..., torch.Tensor]:
|
|
||||||
if current_platform.is_cuda():
|
|
||||||
return apply_rotary_emb
|
|
||||||
|
|
||||||
# if torch compile is not enabled
|
|
||||||
# use rotary embedding function from flash_attn package
|
|
||||||
# otherwise use the naive pytorch embedding implementation
|
|
||||||
# is faster when torch compile is enabled.
|
|
||||||
if current_platform.is_rocm() and not torch.compiler.is_compiling():
|
|
||||||
if find_spec("flash_attn") is not None:
|
|
||||||
from flash_attn.ops.triton.rotary import apply_rotary
|
|
||||||
|
|
||||||
return apply_rotary
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"flash_attn is not installed. Falling back to PyTorch "
|
|
||||||
"implementation for rotary embeddings."
|
|
||||||
)
|
|
||||||
if default is not None:
|
|
||||||
return default
|
|
||||||
|
|
||||||
return apply_rotary_emb_torch
|
|
||||||
|
|
||||||
|
|
||||||
# yarn functions
|
# yarn functions
|
||||||
# Inverse dim formula to find dim based on number of rotations
|
# Inverse dim formula to find dim based on number of rotations
|
||||||
def yarn_find_correction_dim(
|
def yarn_find_correction_dim(
|
||||||
@@ -186,3 +116,164 @@ direct_register_custom_op(
|
|||||||
mutates_args=["query", "key"], # These tensors are modified in-place
|
mutates_args=["query", "key"], # These tensors are modified in-place
|
||||||
fake_impl=_flashinfer_rotary_embedding_fake,
|
fake_impl=_flashinfer_rotary_embedding_fake,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@CustomOp.register("apply_rotary_emb")
|
||||||
|
class ApplyRotaryEmb(CustomOp):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
enforce_enable: bool = False,
|
||||||
|
is_neox_style: bool = True,
|
||||||
|
enable_fp32_compute: bool = False,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(enforce_enable)
|
||||||
|
self.is_neox_style = is_neox_style
|
||||||
|
self.enable_fp32_compute = enable_fp32_compute
|
||||||
|
|
||||||
|
self.apply_rotary_emb_flash_attn = None
|
||||||
|
if find_spec("flash_attn") is not None:
|
||||||
|
from flash_attn.ops.triton.rotary import apply_rotary
|
||||||
|
|
||||||
|
self.apply_rotary_emb_flash_attn = apply_rotary
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward_static(
|
||||||
|
x: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
is_neox_style: bool = True,
|
||||||
|
enable_fp32_compute: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: [batch_size (optional), seq_len, num_heads, head_size]
|
||||||
|
cos: [seq_len, head_size // 2]
|
||||||
|
sin: [seq_len, head_size // 2]
|
||||||
|
is_neox_style: Whether to use the Neox-style or GPT-J-style.
|
||||||
|
enable_fp32_compute: Temporarily convert x, cos, sin to FP32 dtype
|
||||||
|
for higher accuracy.
|
||||||
|
"""
|
||||||
|
origin_dtype = x.dtype
|
||||||
|
if enable_fp32_compute:
|
||||||
|
x = x.float()
|
||||||
|
|
||||||
|
cos = cos.unsqueeze(-2).to(x.dtype)
|
||||||
|
sin = sin.unsqueeze(-2).to(x.dtype)
|
||||||
|
|
||||||
|
if is_neox_style:
|
||||||
|
x1, x2 = torch.chunk(x, 2, dim=-1)
|
||||||
|
else:
|
||||||
|
x1 = x[..., ::2]
|
||||||
|
x2 = x[..., 1::2]
|
||||||
|
|
||||||
|
o1 = x1 * cos - x2 * sin
|
||||||
|
o2 = x2 * cos + x1 * sin
|
||||||
|
|
||||||
|
if is_neox_style:
|
||||||
|
output = torch.cat((o1, o2), dim=-1)
|
||||||
|
else:
|
||||||
|
output = torch.stack((o1, o2), dim=-1).flatten(-2)
|
||||||
|
|
||||||
|
if enable_fp32_compute:
|
||||||
|
output = output.to(origin_dtype)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def forward_native(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
output = self.forward_static(
|
||||||
|
x, cos, sin, self.is_neox_style, self.enable_fp32_compute
|
||||||
|
)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def forward_cuda(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
||||||
|
|
||||||
|
origin_dtype = x.dtype
|
||||||
|
if self.enable_fp32_compute:
|
||||||
|
x = x.float()
|
||||||
|
cos = cos.float()
|
||||||
|
sin = sin.float()
|
||||||
|
|
||||||
|
origin_shape = x.shape
|
||||||
|
if len(origin_shape) == 3:
|
||||||
|
# x: [seq_len, num_heads, head_size]
|
||||||
|
x = x.unsqueeze(0)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Arguments of apply_rotary_emb() in vllm_flash_attn:
|
||||||
|
x: [batch_size, seq_len, nheads, headdim]
|
||||||
|
cos, sin: [seqlen_rotary, rotary_dim / 2]
|
||||||
|
interleaved: defaults to False (Neox-style).
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
interleaved = not self.is_neox_style
|
||||||
|
output = apply_rotary_emb(x, cos, sin, interleaved)
|
||||||
|
|
||||||
|
if len(origin_shape) == 3:
|
||||||
|
output = output.squeeze(0)
|
||||||
|
if self.enable_fp32_compute:
|
||||||
|
output = output.to(origin_dtype)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def forward_hip(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
if self.apply_rotary_emb_flash_attn is not None:
|
||||||
|
origin_dtype = x.dtype
|
||||||
|
if self.enable_fp32_compute:
|
||||||
|
x = x.float()
|
||||||
|
cos = cos.float()
|
||||||
|
sin = sin.float()
|
||||||
|
|
||||||
|
origin_shape = x.shape
|
||||||
|
if len(origin_shape) == 3:
|
||||||
|
# x: [seq_len, num_heads, head_size]
|
||||||
|
x = x.unsqueeze(0)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Arguments of apply_rotary() in flash_attn:
|
||||||
|
x: [batch_size, seq_len, nheads, headdim]
|
||||||
|
cos, sin: [seqlen_rotary, rotary_dim / 2]
|
||||||
|
interleaved: defaults to False (Neox-style).
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
interleaved = not self.is_neox_style
|
||||||
|
output = self.apply_rotary_emb_flash_attn(
|
||||||
|
x, cos, sin, interleaved=interleaved
|
||||||
|
).type_as(x)
|
||||||
|
|
||||||
|
if len(origin_shape) == 3:
|
||||||
|
output = output.squeeze(0)
|
||||||
|
if self.enable_fp32_compute:
|
||||||
|
output = output.to(origin_dtype)
|
||||||
|
else:
|
||||||
|
# Falling back to PyTorch native implementation.
|
||||||
|
output = self.forward_native(x, cos, sin)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def forward_cpu(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
cos: torch.Tensor,
|
||||||
|
sin: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# TODO (bigPYJ1151): need to enable fused CPU ROPE here
|
||||||
|
return self.forward_native(x, cos, sin)
|
||||||
|
|
||||||
|
def extra_repr(self) -> str:
|
||||||
|
s = f"is_neox_style={self.is_neox_style}"
|
||||||
|
s += f"enable_fp32_compute={self.enable_fp32_compute}"
|
||||||
|
return s
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .common import apply_rotary_emb_dispatch
|
|
||||||
from .mrope import MRotaryEmbedding
|
from .mrope import MRotaryEmbedding
|
||||||
|
|
||||||
|
|
||||||
@@ -55,14 +54,22 @@ class Ernie4_5_VLRotaryEmbedding(MRotaryEmbedding):
|
|||||||
query = query.view(num_tokens, -1, self.head_size)
|
query = query.view(num_tokens, -1, self.head_size)
|
||||||
query_rot = query[..., : self.rotary_dim]
|
query_rot = query[..., : self.rotary_dim]
|
||||||
query_pass = query[..., self.rotary_dim :]
|
query_pass = query[..., self.rotary_dim :]
|
||||||
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
|
query_rot = self.apply_rotary_emb.forward_native(
|
||||||
|
query_rot,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
)
|
||||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||||
|
|
||||||
key_shape = key.shape
|
key_shape = key.shape
|
||||||
key = key.view(num_tokens, -1, self.head_size)
|
key = key.view(num_tokens, -1, self.head_size)
|
||||||
key_rot = key[..., : self.rotary_dim]
|
key_rot = key[..., : self.rotary_dim]
|
||||||
key_pass = key[..., self.rotary_dim :]
|
key_pass = key[..., self.rotary_dim :]
|
||||||
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
|
key_rot = self.apply_rotary_emb.forward_native(
|
||||||
|
key_rot,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
)
|
||||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import torch
|
|||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
from .base import RotaryEmbeddingBase
|
from .base import RotaryEmbeddingBase
|
||||||
from .common import apply_rotary_emb_dispatch
|
|
||||||
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale
|
from .yarn_scaling_rope import YaRNScalingRotaryEmbedding, yarn_get_mscale
|
||||||
|
|
||||||
|
|
||||||
@@ -301,14 +300,22 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
|
|||||||
query = query.view(num_tokens, -1, self.head_size)
|
query = query.view(num_tokens, -1, self.head_size)
|
||||||
query_rot = query[..., : self.rotary_dim]
|
query_rot = query[..., : self.rotary_dim]
|
||||||
query_pass = query[..., self.rotary_dim :]
|
query_pass = query[..., self.rotary_dim :]
|
||||||
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
|
query_rot = self.apply_rotary_emb.forward_native(
|
||||||
|
query_rot,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
)
|
||||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||||
|
|
||||||
key_shape = key.shape
|
key_shape = key.shape
|
||||||
key = key.view(num_tokens, -1, self.head_size)
|
key = key.view(num_tokens, -1, self.head_size)
|
||||||
key_rot = key[..., : self.rotary_dim]
|
key_rot = key[..., : self.rotary_dim]
|
||||||
key_pass = key[..., self.rotary_dim :]
|
key_pass = key[..., self.rotary_dim :]
|
||||||
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
|
key_rot = self.apply_rotary_emb.forward_native(
|
||||||
|
key_rot,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
)
|
||||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
@@ -347,13 +354,21 @@ class MRotaryEmbedding(RotaryEmbeddingBase):
|
|||||||
query = query.view(num_tokens, -1, self.head_size)
|
query = query.view(num_tokens, -1, self.head_size)
|
||||||
query_rot = query[..., : self.rotary_dim]
|
query_rot = query[..., : self.rotary_dim]
|
||||||
query_pass = query[..., self.rotary_dim :]
|
query_pass = query[..., self.rotary_dim :]
|
||||||
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
|
query_rot = self.apply_rotary_emb(
|
||||||
|
query_rot,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
)
|
||||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||||
|
|
||||||
key = key.view(num_tokens, -1, self.head_size)
|
key = key.view(num_tokens, -1, self.head_size)
|
||||||
key_rot = key[..., : self.rotary_dim]
|
key_rot = key[..., : self.rotary_dim]
|
||||||
key_pass = key[..., self.rotary_dim :]
|
key_pass = key[..., self.rotary_dim :]
|
||||||
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
|
key_rot = self.apply_rotary_emb(
|
||||||
|
key_rot,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
)
|
||||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .common import apply_rotary_emb_dispatch
|
|
||||||
from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
|
from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
|
||||||
|
|
||||||
|
|
||||||
@@ -36,7 +35,7 @@ class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding):
|
|||||||
dtype,
|
dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward_native(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
@@ -68,14 +67,73 @@ class XDRotaryEmbedding(DynamicNTKAlphaRotaryEmbedding):
|
|||||||
query = query.view(num_tokens, -1, self.head_size)
|
query = query.view(num_tokens, -1, self.head_size)
|
||||||
query_rot = query[..., : self.rotary_dim]
|
query_rot = query[..., : self.rotary_dim]
|
||||||
query_pass = query[..., self.rotary_dim :]
|
query_pass = query[..., self.rotary_dim :]
|
||||||
query_rot = apply_rotary_emb_dispatch(query_rot, cos, sin, self.is_neox_style)
|
query_rot = self.apply_rotary_emb.forward_native(
|
||||||
|
query_rot,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
)
|
||||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||||
|
|
||||||
key_shape = key.shape
|
key_shape = key.shape
|
||||||
key = key.view(num_tokens, -1, self.head_size)
|
key = key.view(num_tokens, -1, self.head_size)
|
||||||
key_rot = key[..., : self.rotary_dim]
|
key_rot = key[..., : self.rotary_dim]
|
||||||
key_pass = key[..., self.rotary_dim :]
|
key_pass = key[..., self.rotary_dim :]
|
||||||
key_rot = apply_rotary_emb_dispatch(key_rot, cos, sin, self.is_neox_style)
|
key_rot = self.apply_rotary_emb.forward_native(
|
||||||
|
key_rot,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
)
|
||||||
|
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||||
|
return query, key
|
||||||
|
|
||||||
|
def forward_cuda(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor | None = None,
|
||||||
|
offsets: torch.Tensor | None = None,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||||
|
"""PyTorch-native implementation equivalent to forward().
|
||||||
|
|
||||||
|
Args:
|
||||||
|
positions:
|
||||||
|
[4, num_tokens] (P/W/H/T positions with multimodal inputs)
|
||||||
|
query: [num_tokens, num_heads * head_size]
|
||||||
|
key: [num_tokens, num_kv_heads * head_size]
|
||||||
|
"""
|
||||||
|
assert positions.ndim == 2
|
||||||
|
assert key is not None
|
||||||
|
|
||||||
|
num_tokens = positions.shape[-1]
|
||||||
|
cos_sin = self.cos_sin_cache[positions]
|
||||||
|
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||||
|
cos = torch.cat(
|
||||||
|
[m[i] for i, m in enumerate(cos.split(self.xdrope_section, dim=-1))], dim=-1
|
||||||
|
)
|
||||||
|
sin = torch.cat(
|
||||||
|
[m[i] for i, m in enumerate(sin.split(self.xdrope_section, dim=-1))], dim=-1
|
||||||
|
)
|
||||||
|
|
||||||
|
query_shape = query.shape
|
||||||
|
query = query.view(num_tokens, -1, self.head_size)
|
||||||
|
query_rot = query[..., : self.rotary_dim]
|
||||||
|
query_pass = query[..., self.rotary_dim :]
|
||||||
|
query_rot = self.apply_rotary_emb(
|
||||||
|
query_rot,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
)
|
||||||
|
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||||
|
|
||||||
|
key_shape = key.shape
|
||||||
|
key = key.view(num_tokens, -1, self.head_size)
|
||||||
|
key_rot = key[..., : self.rotary_dim]
|
||||||
|
key_pass = key[..., self.rotary_dim :]
|
||||||
|
key_rot = self.apply_rotary_emb(
|
||||||
|
key_rot,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
)
|
||||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||||
return query, key
|
return query, key
|
||||||
|
|
||||||
|
|||||||
@@ -29,6 +29,9 @@ from vllm.model_executor.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||||
|
ApplyRotaryEmb,
|
||||||
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.interfaces import (
|
from vllm.model_executor.models.interfaces import (
|
||||||
MultiModalEmbeddings,
|
MultiModalEmbeddings,
|
||||||
@@ -158,32 +161,6 @@ class DotsOCRProcessingInfo(Qwen2VLProcessingInfo):
|
|||||||
return processor
|
return processor
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
|
|
||||||
"""Rotates half the hidden dims of the input."""
|
|
||||||
x1 = x[..., : x.shape[-1] // 2]
|
|
||||||
x2 = x[..., x.shape[-1] // 2 :]
|
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb_vision(
|
|
||||||
tensor: torch.Tensor, freqs: torch.Tensor
|
|
||||||
) -> torch.Tensor:
|
|
||||||
orig_dtype = tensor.dtype
|
|
||||||
tensor = tensor.float()
|
|
||||||
|
|
||||||
cos = freqs.cos()
|
|
||||||
sin = freqs.sin()
|
|
||||||
|
|
||||||
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
|
||||||
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
|
||||||
|
|
||||||
output = (tensor * cos) + (rotate_half(tensor) * sin)
|
|
||||||
|
|
||||||
output = output.to(orig_dtype)
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class VisionRotaryEmbedding(nn.Module):
|
class VisionRotaryEmbedding(nn.Module):
|
||||||
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -298,6 +275,11 @@ class DotsVisionAttention(nn.Module):
|
|||||||
prefix=f"{prefix}.attn",
|
prefix=f"{prefix}.attn",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.apply_rotary_emb = ApplyRotaryEmb(
|
||||||
|
enforce_enable=True,
|
||||||
|
enable_fp32_compute=True,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -318,7 +300,11 @@ class DotsVisionAttention(nn.Module):
|
|||||||
|
|
||||||
if rotary_pos_emb is not None:
|
if rotary_pos_emb is not None:
|
||||||
qk_concat = torch.cat([q, k], dim=0)
|
qk_concat = torch.cat([q, k], dim=0)
|
||||||
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
|
qk_rotated = self.apply_rotary_emb(
|
||||||
|
qk_concat,
|
||||||
|
rotary_pos_emb.cos(),
|
||||||
|
rotary_pos_emb.sin(),
|
||||||
|
)
|
||||||
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
||||||
|
|
||||||
context_layer = self.attn(
|
context_layer = self.attn(
|
||||||
|
|||||||
@@ -33,7 +33,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange
|
||||||
from transformers import BatchFeature
|
from transformers import BatchFeature
|
||||||
|
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
@@ -53,6 +53,9 @@ from vllm.model_executor.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||||
|
ApplyRotaryEmb,
|
||||||
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.inputs import (
|
from vllm.multimodal.inputs import (
|
||||||
@@ -69,7 +72,6 @@ from vllm.multimodal.processing import (
|
|||||||
PromptUpdate,
|
PromptUpdate,
|
||||||
)
|
)
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
@@ -89,52 +91,6 @@ logger = init_logger(__name__)
|
|||||||
# === Vision Transformer === #
|
# === Vision Transformer === #
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
|
|
||||||
if not interleaved:
|
|
||||||
x1, x2 = x.chunk(2, dim=-1)
|
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
|
||||||
else:
|
|
||||||
x1, x2 = x[..., ::2], x[..., 1::2]
|
|
||||||
return rearrange(
|
|
||||||
torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb_torch(
|
|
||||||
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
x: (batch_size, seqlen, nheads, headdim)
|
|
||||||
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
|
|
||||||
"""
|
|
||||||
ro_dim = cos.shape[-1] * 2
|
|
||||||
assert ro_dim <= x.shape[-1]
|
|
||||||
cos = repeat(
|
|
||||||
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
|
||||||
)
|
|
||||||
sin = repeat(
|
|
||||||
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
|
||||||
)
|
|
||||||
return torch.cat(
|
|
||||||
[
|
|
||||||
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
|
|
||||||
x[..., ro_dim:],
|
|
||||||
],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
|
|
||||||
t_ = t.float()
|
|
||||||
cos = freqs.cos()
|
|
||||||
sin = freqs.sin()
|
|
||||||
apply_rotary_emb = apply_rotary_emb_torch
|
|
||||||
if current_platform.is_cuda():
|
|
||||||
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
|
||||||
output = apply_rotary_emb(t_, cos, sin).type_as(t)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
|
def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
|
||||||
"""All-gather the input tensor interleavely across model parallel group."""
|
"""All-gather the input tensor interleavely across model parallel group."""
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@@ -200,6 +156,11 @@ class Ernie4_5_VisionAttention(nn.Module):
|
|||||||
prefix=f"{prefix}.attn",
|
prefix=f"{prefix}.attn",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.apply_rotary_emb = ApplyRotaryEmb(
|
||||||
|
enforce_enable=True,
|
||||||
|
enable_fp32_compute=True,
|
||||||
|
)
|
||||||
|
|
||||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||||
# [s, b, 3 * head * head_dim]
|
# [s, b, 3 * head * head_dim]
|
||||||
seq_len, bs, _ = qkv.shape
|
seq_len, bs, _ = qkv.shape
|
||||||
@@ -244,7 +205,11 @@ class Ernie4_5_VisionAttention(nn.Module):
|
|||||||
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
|
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
|
||||||
if rotary_pos_emb is not None:
|
if rotary_pos_emb is not None:
|
||||||
qk_concat = torch.cat([q, k], dim=0)
|
qk_concat = torch.cat([q, k], dim=0)
|
||||||
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
|
qk_rotated = self.apply_rotary_emb(
|
||||||
|
qk_concat,
|
||||||
|
rotary_pos_emb.cos(),
|
||||||
|
rotary_pos_emb.sin(),
|
||||||
|
)
|
||||||
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
||||||
|
|
||||||
output = self.attn(
|
output = self.attn(
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ from collections.abc import Iterable
|
|||||||
from itertools import islice
|
from itertools import islice
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import Gemma3TextConfig
|
from transformers import Gemma3TextConfig
|
||||||
|
|
||||||
@@ -223,77 +222,9 @@ class Gemma3Attention(nn.Module):
|
|||||||
|
|
||||||
q, k = self.rotary_emb(positions, q, k)
|
q, k = self.rotary_emb(positions, q, k)
|
||||||
attn_output = self.attn(q, k, v)
|
attn_output = self.attn(q, k, v)
|
||||||
|
|
||||||
if not kwargs.get("has_images", False):
|
|
||||||
# Fast path for text-only inputs. The performance for the text-only
|
|
||||||
# inputs are not affected by the naive attention below.
|
|
||||||
output, _ = self.o_proj(attn_output)
|
output, _ = self.o_proj(attn_output)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
# NOTE(woosuk): Gemma3 uses bidirectional attention between image tokens
|
|
||||||
# that correspond to the same image while using causal attention
|
|
||||||
# otherwise. Current attention backends cannot handle this pattern, so
|
|
||||||
# we temporarily use a naive attention implementation with mask tensors.
|
|
||||||
|
|
||||||
# We intentionally keep the attention backend as-is and only override
|
|
||||||
# `attn_output` with the naive implementation's output. This minimizes
|
|
||||||
# changes to existing model runners and attention backends. The call to
|
|
||||||
# `self.attn(q, k, v)` is only used to populate the KV cache - its
|
|
||||||
# output is discarded and overwritten below. While this duplicates
|
|
||||||
# computation, it maintains compatibility.
|
|
||||||
# TODO(woosuk): Optimize by implementing custom attention kernels.
|
|
||||||
attn_output = self.naive_attn_with_masks(q, k, v, out=attn_output, **kwargs)
|
|
||||||
output, _ = self.o_proj(attn_output)
|
|
||||||
return output
|
|
||||||
|
|
||||||
def naive_attn_with_masks(
|
|
||||||
self,
|
|
||||||
q: torch.Tensor,
|
|
||||||
k: torch.Tensor,
|
|
||||||
v: torch.Tensor,
|
|
||||||
out: torch.Tensor,
|
|
||||||
**kwargs,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
# NOTE(woosuk): As described in the comment above, this code is not
|
|
||||||
# meant to be performant. It is only meant to be correct.
|
|
||||||
q = q.view(-1, self.num_heads, self.head_dim)
|
|
||||||
# Expand the key and value to handle GQA.
|
|
||||||
num_queries_per_kv = self.num_heads // self.num_kv_heads
|
|
||||||
k = k.view(-1, self.num_kv_heads, self.head_dim)
|
|
||||||
k = k.repeat_interleave(num_queries_per_kv, dim=-2)
|
|
||||||
v = v.view(-1, self.num_kv_heads, self.head_dim)
|
|
||||||
v = v.repeat_interleave(num_queries_per_kv, dim=-2)
|
|
||||||
|
|
||||||
if self.is_sliding:
|
|
||||||
attn_masks = kwargs["local_attn_masks"]
|
|
||||||
else:
|
|
||||||
attn_masks = kwargs["global_attn_masks"]
|
|
||||||
|
|
||||||
seq_lens = kwargs["seq_lens"]
|
|
||||||
start_idx = 0
|
|
||||||
for seq_len, attn_mask in zip(seq_lens, attn_masks):
|
|
||||||
end_idx = start_idx + seq_len
|
|
||||||
query = q[start_idx:end_idx].unsqueeze(0)
|
|
||||||
key = k[start_idx:end_idx].unsqueeze(0)
|
|
||||||
value = v[start_idx:end_idx].unsqueeze(0)
|
|
||||||
|
|
||||||
# Transpose.
|
|
||||||
query = query.transpose(1, 2)
|
|
||||||
key = key.transpose(1, 2)
|
|
||||||
value = value.transpose(1, 2)
|
|
||||||
|
|
||||||
output = F.scaled_dot_product_attention(
|
|
||||||
query,
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
attn_mask,
|
|
||||||
self.scaling,
|
|
||||||
)
|
|
||||||
output = output.transpose(1, 2).flatten(-2, -1)
|
|
||||||
out[start_idx:end_idx] = output
|
|
||||||
start_idx = end_idx
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class Gemma3DecoderLayer(nn.Module):
|
class Gemma3DecoderLayer(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@@ -65,6 +65,9 @@ from vllm.model_executor.layers.linear import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||||
|
ApplyRotaryEmb,
|
||||||
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
@@ -95,7 +98,7 @@ from .interfaces import (
|
|||||||
SupportsMultiModal,
|
SupportsMultiModal,
|
||||||
SupportsPP,
|
SupportsPP,
|
||||||
)
|
)
|
||||||
from .qwen2_vl import _create_qwen2vl_field_factory, apply_rotary_pos_emb_vision
|
from .qwen2_vl import _create_qwen2vl_field_factory
|
||||||
from .utils import (
|
from .utils import (
|
||||||
AutoWeightsLoader,
|
AutoWeightsLoader,
|
||||||
WeightsMapper,
|
WeightsMapper,
|
||||||
@@ -304,6 +307,8 @@ class Glm4vVisionAttention(nn.Module):
|
|||||||
multimodal_config=multimodal_config,
|
multimodal_config=multimodal_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
|
||||||
|
|
||||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||||
# [s, b, 3 * head * head_dim]
|
# [s, b, 3 * head * head_dim]
|
||||||
seq_len, bs, _ = qkv.shape
|
seq_len, bs, _ = qkv.shape
|
||||||
@@ -339,8 +344,10 @@ class Glm4vVisionAttention(nn.Module):
|
|||||||
if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
|
if rotary_pos_emb_cos is not None and rotary_pos_emb_sin is not None:
|
||||||
# [2 * b, s, heads, head_dim]
|
# [2 * b, s, heads, head_dim]
|
||||||
qk_concat = torch.cat([q, k], dim=0)
|
qk_concat = torch.cat([q, k], dim=0)
|
||||||
qk_rotated = apply_rotary_pos_emb_vision(
|
qk_rotated = self.apply_rotary_emb(
|
||||||
qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
|
qk_concat,
|
||||||
|
rotary_pos_emb_cos,
|
||||||
|
rotary_pos_emb_sin,
|
||||||
)
|
)
|
||||||
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
||||||
|
|
||||||
|
|||||||
@@ -30,6 +30,9 @@ from vllm.model_executor.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||||
|
ApplyRotaryEmb,
|
||||||
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import (
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
default_weight_loader,
|
default_weight_loader,
|
||||||
maybe_remap_kv_scale_name,
|
maybe_remap_kv_scale_name,
|
||||||
@@ -59,7 +62,6 @@ from vllm.multimodal.processing import (
|
|||||||
PromptUpdate,
|
PromptUpdate,
|
||||||
)
|
)
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
from vllm.platforms import current_platform
|
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||||
|
|
||||||
@@ -341,20 +343,14 @@ def apply_rotary_pos_emb_flashatt(
|
|||||||
cos = cos.chunk(2, dim=-1)[0].contiguous()
|
cos = cos.chunk(2, dim=-1)[0].contiguous()
|
||||||
sin = sin.chunk(2, dim=-1)[0].contiguous()
|
sin = sin.chunk(2, dim=-1)[0].contiguous()
|
||||||
|
|
||||||
if current_platform.is_cuda():
|
apply_rotary_emb = ApplyRotaryEmb(
|
||||||
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
enforce_enable=True,
|
||||||
elif current_platform.is_rocm():
|
enable_fp32_compute=True,
|
||||||
from flash_attn.ops.triton.rotary import apply_rotary as apply_rotary_emb
|
|
||||||
else:
|
|
||||||
# For other platforms, use PyTorch fallback
|
|
||||||
from vllm.model_executor.layers.rotary_embedding.common import (
|
|
||||||
apply_rotary_emb_torch,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
apply_rotary_emb = partial(apply_rotary_emb_torch, is_neox_style=True)
|
q_embed = apply_rotary_emb(q, cos, sin)
|
||||||
|
k_embed = apply_rotary_emb(k, cos, sin)
|
||||||
|
|
||||||
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
|
|
||||||
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
|
|
||||||
return q_embed, k_embed
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from typing import Annotated, Literal
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange
|
||||||
from transformers import BatchFeature, PretrainedConfig
|
from transformers import BatchFeature, PretrainedConfig
|
||||||
from transformers.activations import GELUActivation
|
from transformers.activations import GELUActivation
|
||||||
from transformers.modeling_outputs import (
|
from transformers.modeling_outputs import (
|
||||||
@@ -47,7 +47,7 @@ from vllm.model_executor.layers.linear import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.rotary_embedding.common import (
|
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||||
dispatch_rotary_emb_function,
|
ApplyRotaryEmb,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import (
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
default_weight_loader,
|
default_weight_loader,
|
||||||
@@ -130,47 +130,6 @@ def smart_resize(
|
|||||||
return h_bar, w_bar
|
return h_bar, w_bar
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
|
|
||||||
if not interleaved:
|
|
||||||
x1, x2 = x.chunk(2, dim=-1)
|
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
|
||||||
x1, x2 = x[..., ::2], x[..., 1::2]
|
|
||||||
return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb_torch(
|
|
||||||
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
x: (batch_size, seqlen, nheads, headdim)
|
|
||||||
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
|
|
||||||
"""
|
|
||||||
ro_dim = cos.shape[-1] * 2
|
|
||||||
assert ro_dim <= x.shape[-1]
|
|
||||||
cos = repeat(
|
|
||||||
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
|
||||||
)
|
|
||||||
sin = repeat(
|
|
||||||
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
|
||||||
)
|
|
||||||
return torch.cat(
|
|
||||||
[
|
|
||||||
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
|
|
||||||
x[..., ro_dim:],
|
|
||||||
],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
|
|
||||||
rotary_emb_function = dispatch_rotary_emb_function(default=apply_rotary_emb_torch)
|
|
||||||
t_ = t.float()
|
|
||||||
cos = freqs.cos()
|
|
||||||
sin = freqs.sin()
|
|
||||||
output = rotary_emb_function(t_, cos, sin).type_as(t)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class PaddleOCRVLProcessingInfo(BaseProcessingInfo):
|
class PaddleOCRVLProcessingInfo(BaseProcessingInfo):
|
||||||
def get_hf_config(self):
|
def get_hf_config(self):
|
||||||
return self.ctx.get_hf_config()
|
return self.ctx.get_hf_config()
|
||||||
@@ -609,6 +568,10 @@ class SiglipAttention(nn.Module):
|
|||||||
multimodal_config=multimodal_config,
|
multimodal_config=multimodal_config,
|
||||||
prefix=f"{prefix}.attn",
|
prefix=f"{prefix}.attn",
|
||||||
)
|
)
|
||||||
|
self.apply_rotary_emb = ApplyRotaryEmb(
|
||||||
|
enforce_enable=True,
|
||||||
|
enable_fp32_compute=True,
|
||||||
|
)
|
||||||
|
|
||||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||||
seq_len, bs, _ = qkv.shape
|
seq_len, bs, _ = qkv.shape
|
||||||
@@ -651,7 +614,11 @@ class SiglipAttention(nn.Module):
|
|||||||
|
|
||||||
if rotary_pos_emb is not None:
|
if rotary_pos_emb is not None:
|
||||||
qk_concat = torch.cat([q, k], dim=0)
|
qk_concat = torch.cat([q, k], dim=0)
|
||||||
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
|
qk_rotated = self.apply_rotary_emb(
|
||||||
|
qk_concat,
|
||||||
|
rotary_pos_emb.cos(),
|
||||||
|
rotary_pos_emb.sin(),
|
||||||
|
)
|
||||||
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
||||||
|
|
||||||
context_layer = self.attn(
|
context_layer = self.attn(
|
||||||
|
|||||||
@@ -60,6 +60,9 @@ from vllm.model_executor.layers.linear import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||||
|
ApplyRotaryEmb,
|
||||||
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||||
from vllm.model_executor.models.vision import should_torch_compile_mm_vit
|
from vllm.model_executor.models.vision import should_torch_compile_mm_vit
|
||||||
@@ -95,7 +98,6 @@ from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
|
|||||||
from .qwen2_vl import (
|
from .qwen2_vl import (
|
||||||
Qwen2VLMultiModalProcessor,
|
Qwen2VLMultiModalProcessor,
|
||||||
Qwen2VLProcessingInfo,
|
Qwen2VLProcessingInfo,
|
||||||
apply_rotary_pos_emb_vision,
|
|
||||||
)
|
)
|
||||||
from .utils import (
|
from .utils import (
|
||||||
AutoWeightsLoader,
|
AutoWeightsLoader,
|
||||||
@@ -353,6 +355,8 @@ class Qwen2_5_VisionAttention(nn.Module):
|
|||||||
multimodal_config=multimodal_config,
|
multimodal_config=multimodal_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@@ -378,8 +382,10 @@ class Qwen2_5_VisionAttention(nn.Module):
|
|||||||
qk_reshaped = einops.rearrange(
|
qk_reshaped = einops.rearrange(
|
||||||
qk, "b s two head head_dim -> (two b) s head head_dim", two=2
|
qk, "b s two head head_dim -> (two b) s head head_dim", two=2
|
||||||
)
|
)
|
||||||
qk_rotated = apply_rotary_pos_emb_vision(
|
qk_rotated = self.apply_rotary_emb(
|
||||||
qk_reshaped, cos=rotary_pos_emb_cos, sin=rotary_pos_emb_sin
|
qk_reshaped,
|
||||||
|
rotary_pos_emb_cos,
|
||||||
|
rotary_pos_emb_sin,
|
||||||
)
|
)
|
||||||
qk_rotated = qk_rotated.view(
|
qk_rotated = qk_rotated.view(
|
||||||
2,
|
2,
|
||||||
|
|||||||
@@ -59,8 +59,7 @@ from vllm.model_executor.layers.linear import (
|
|||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.model_executor.layers.rotary_embedding.common import (
|
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||||
apply_rotary_emb_torch,
|
ApplyRotaryEmb,
|
||||||
dispatch_rotary_emb_function,
|
|
||||||
)
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||||
@@ -280,16 +279,6 @@ class Qwen2VisionMLP(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb_vision(
|
|
||||||
t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
|
|
||||||
) -> torch.Tensor:
|
|
||||||
rotary_emb_function = dispatch_rotary_emb_function(
|
|
||||||
default=partial(apply_rotary_emb_torch, is_neox_style=True)
|
|
||||||
)
|
|
||||||
output = rotary_emb_function(t, cos, sin).type_as(t)
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen2VisionAttention(nn.Module):
|
class Qwen2VisionAttention(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -341,6 +330,8 @@ class Qwen2VisionAttention(nn.Module):
|
|||||||
multimodal_config=multimodal_config,
|
multimodal_config=multimodal_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
|
||||||
|
|
||||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||||
# [s, b, 3 * head * head_dim]
|
# [s, b, 3 * head * head_dim]
|
||||||
seq_len, bs, _ = qkv.shape
|
seq_len, bs, _ = qkv.shape
|
||||||
@@ -387,8 +378,10 @@ class Qwen2VisionAttention(nn.Module):
|
|||||||
|
|
||||||
# [2 * b, s, heads, head_dim]
|
# [2 * b, s, heads, head_dim]
|
||||||
qk_concat = torch.cat([q, k], dim=0)
|
qk_concat = torch.cat([q, k], dim=0)
|
||||||
qk_rotated = apply_rotary_pos_emb_vision(
|
qk_rotated = self.apply_rotary_emb(
|
||||||
qk_concat, rotary_pos_emb_cos, rotary_pos_emb_sin
|
qk_concat,
|
||||||
|
rotary_pos_emb_cos,
|
||||||
|
rotary_pos_emb_sin,
|
||||||
)
|
)
|
||||||
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
||||||
|
|
||||||
|
|||||||
@@ -713,17 +713,13 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
|
|||||||
mm_counts: Mapping[str, int],
|
mm_counts: Mapping[str, int],
|
||||||
) -> int:
|
) -> int:
|
||||||
target_width, target_height = self.get_image_size_with_most_features()
|
target_width, target_height = self.get_image_size_with_most_features()
|
||||||
video_soft_tokens = self.get_num_video_tokens(
|
num_video_soft_tokens = self.get_num_video_tokens(
|
||||||
image_width=target_width,
|
image_width=target_width,
|
||||||
image_height=target_height,
|
image_height=target_height,
|
||||||
num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
|
num_frames=self.get_num_frames_with_most_features(seq_len, mm_counts),
|
||||||
image_processor=None,
|
image_processor=None,
|
||||||
)
|
)
|
||||||
|
return num_video_soft_tokens
|
||||||
# NOTE: By default in Qwen3-VL, one video token is converted to
|
|
||||||
# "<{timestamp} seconds>" (on average 9.5 tokens) + vision_start_token + video_token + vision_end_token # noqa: E501
|
|
||||||
formatted_video_soft_tokens = video_soft_tokens * 12.5
|
|
||||||
return int(formatted_video_soft_tokens)
|
|
||||||
|
|
||||||
def _calculate_timestamps(
|
def _calculate_timestamps(
|
||||||
self, indices: list[int] | torch.Tensor, video_fps: float, merge_size: int
|
self, indices: list[int] | torch.Tensor, video_fps: float, merge_size: int
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ within a vision language model."""
|
|||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from einops import rearrange, repeat
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from transformers import Siglip2VisionConfig
|
from transformers import Siglip2VisionConfig
|
||||||
@@ -26,6 +25,9 @@ from vllm.model_executor.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
from vllm.model_executor.layers.rotary_embedding.common import (
|
||||||
|
ApplyRotaryEmb,
|
||||||
|
)
|
||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
@@ -146,40 +148,6 @@ class Siglip2VisionEmbeddings(nn.Module):
|
|||||||
return patch_embeds
|
return patch_embeds
|
||||||
|
|
||||||
|
|
||||||
# copy from flash_attn/layers/rotary.py
|
|
||||||
def rotate_half(x, interleaved=False):
|
|
||||||
if not interleaved:
|
|
||||||
x1, x2 = x.chunk(2, dim=-1)
|
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
|
||||||
else:
|
|
||||||
x1, x2 = x[..., ::2], x[..., 1::2]
|
|
||||||
return rearrange(
|
|
||||||
torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
|
|
||||||
"""
|
|
||||||
x: (batch_size, seqlen, nheads, headdim)
|
|
||||||
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
|
|
||||||
"""
|
|
||||||
ro_dim = cos.shape[-1] * 2
|
|
||||||
assert ro_dim <= x.shape[-1]
|
|
||||||
cos = repeat(
|
|
||||||
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
|
||||||
)
|
|
||||||
sin = repeat(
|
|
||||||
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
|
||||||
)
|
|
||||||
return torch.cat(
|
|
||||||
[
|
|
||||||
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
|
|
||||||
x[..., ro_dim:],
|
|
||||||
],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb(
|
def apply_rotary_pos_emb(
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
k: torch.Tensor,
|
k: torch.Tensor,
|
||||||
@@ -189,14 +157,20 @@ def apply_rotary_pos_emb(
|
|||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
cos = cos.chunk(2, dim=-1)[0].contiguous()
|
cos = cos.chunk(2, dim=-1)[0].contiguous()
|
||||||
sin = sin.chunk(2, dim=-1)[0].contiguous()
|
sin = sin.chunk(2, dim=-1)[0].contiguous()
|
||||||
if is_flash_attn_backend and current_platform.is_cuda():
|
|
||||||
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
|
|
||||||
|
|
||||||
apply_rotary_emb_func = apply_rotary_emb
|
apply_rotary_emb = ApplyRotaryEmb(
|
||||||
|
enforce_enable=True,
|
||||||
|
enable_fp32_compute=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_flash_attn_backend and not current_platform.is_cuda():
|
||||||
|
apply_rotary_emb_func = apply_rotary_emb.forward_cuda
|
||||||
else:
|
else:
|
||||||
apply_rotary_emb_func = apply_rotary_emb_torch
|
apply_rotary_emb_func = apply_rotary_emb.forward_native
|
||||||
q_embed = apply_rotary_emb_func(q.float(), cos.float(), sin.float()).type_as(q)
|
|
||||||
k_embed = apply_rotary_emb_func(k.float(), cos.float(), sin.float()).type_as(k)
|
q_embed = apply_rotary_emb_func(q, cos, sin)
|
||||||
|
k_embed = apply_rotary_emb_func(k, cos, sin)
|
||||||
|
|
||||||
return q_embed, k_embed
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import torch
|
|||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||||
from vllm.config import VllmConfig, get_current_vllm_config
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed import (
|
from vllm.distributed import (
|
||||||
get_tensor_model_parallel_rank,
|
get_tensor_model_parallel_rank,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
@@ -88,16 +88,10 @@ def get_vit_attn_backend(
|
|||||||
"""
|
"""
|
||||||
Get the available attention backend for Vision Transformer.
|
Get the available attention backend for Vision Transformer.
|
||||||
"""
|
"""
|
||||||
attn_backend = attn_backend_override
|
|
||||||
|
|
||||||
selected_backend = get_current_vllm_config().attention_config.backend
|
|
||||||
if attn_backend is None:
|
|
||||||
attn_backend = selected_backend
|
|
||||||
|
|
||||||
return current_platform.get_vit_attn_backend(
|
return current_platform.get_vit_attn_backend(
|
||||||
head_size,
|
head_size,
|
||||||
dtype,
|
dtype,
|
||||||
backend=attn_backend,
|
backend=attn_backend_override,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
|
|||||||
from collections import UserDict, defaultdict
|
from collections import UserDict, defaultdict
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import partial
|
from functools import cached_property, partial
|
||||||
from itertools import accumulate
|
from itertools import accumulate
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
@@ -169,11 +169,42 @@ class PlaceholderRange:
|
|||||||
between `offset` and `offset + length` to assign embeddings to.
|
between `offset` and `offset + length` to assign embeddings to.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_num_embeds(self) -> int:
|
@cached_property
|
||||||
|
def embeds_cumsum(self) -> torch.Tensor | None:
|
||||||
if self.is_embed is None:
|
if self.is_embed is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self.is_embed.cumsum(dim=0)
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def get_num_embeds(self) -> int:
|
||||||
|
if self.embeds_cumsum is None:
|
||||||
return self.length
|
return self.length
|
||||||
|
|
||||||
return int(self.is_embed.sum().item())
|
return int(self.embeds_cumsum[-1])
|
||||||
|
|
||||||
|
def get_embeds_indices_in_range(
|
||||||
|
self, start_idx: int, end_idx: int
|
||||||
|
) -> tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Returns the starting and ending indices of the embeddings of encoder outputs
|
||||||
|
in the range of [start_idx, end_idx) in the placeholders.
|
||||||
|
|
||||||
|
For example, given:
|
||||||
|
PlaceholderRange(offset=2, length=5, is_embed=[False, True, False, True, True])
|
||||||
|
|
||||||
|
If start_idx=3 and end_idx=5, the output is (1, 3) because we want to get
|
||||||
|
the second and the third embeddings from the encoder output.
|
||||||
|
"""
|
||||||
|
if self.embeds_cumsum is None:
|
||||||
|
return start_idx, end_idx
|
||||||
|
|
||||||
|
embeds_start_idx = (
|
||||||
|
int(self.embeds_cumsum[start_idx - 1]) if start_idx > 0 else 0
|
||||||
|
)
|
||||||
|
embeds_end_idx = int(self.embeds_cumsum[end_idx - 1])
|
||||||
|
|
||||||
|
return embeds_start_idx, embeds_end_idx
|
||||||
|
|
||||||
def extract_embeds_range(self) -> list[tuple[int, int]]:
|
def extract_embeds_range(self) -> list[tuple[int, int]]:
|
||||||
"""Extract the start and end indices of the embedded region in prompt.
|
"""Extract the start and end indices of the embedded region in prompt.
|
||||||
@@ -188,7 +219,7 @@ class PlaceholderRange:
|
|||||||
Returns full placeholder range if `is_embed` is `None`.
|
Returns full placeholder range if `is_embed` is `None`.
|
||||||
"""
|
"""
|
||||||
if self.is_embed is None:
|
if self.is_embed is None:
|
||||||
return [(self.offset, self.offset + self.length)]
|
return [(self.offset, self.offset + self.length - 1)]
|
||||||
|
|
||||||
mask_i = self.is_embed.int()
|
mask_i = self.is_embed.int()
|
||||||
starts = torch.nonzero(
|
starts = torch.nonzero(
|
||||||
|
|||||||
@@ -274,15 +274,11 @@ class MultiModalProfiler(Generic[_I]):
|
|||||||
def _get_mm_num_tokens(
|
def _get_mm_num_tokens(
|
||||||
self,
|
self,
|
||||||
mm_inputs: MultiModalInputs,
|
mm_inputs: MultiModalInputs,
|
||||||
mm_embeddings_only: bool = True,
|
|
||||||
) -> Mapping[str, int]:
|
) -> Mapping[str, int]:
|
||||||
placeholders_by_modality = mm_inputs["mm_placeholders"]
|
placeholders_by_modality = mm_inputs["mm_placeholders"]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
modality: sum(
|
modality: sum(item.get_num_embeds for item in placeholders)
|
||||||
item.get_num_embeds() if mm_embeddings_only else item.length
|
|
||||||
for item in placeholders
|
|
||||||
)
|
|
||||||
for modality, placeholders in placeholders_by_modality.items()
|
for modality, placeholders in placeholders_by_modality.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -328,12 +324,15 @@ class MultiModalProfiler(Generic[_I]):
|
|||||||
multi_modal_placeholders=mm_inputs["mm_placeholders"],
|
multi_modal_placeholders=mm_inputs["mm_placeholders"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_mm_max_tokens(
|
def get_mm_max_tokens(
|
||||||
self,
|
self,
|
||||||
seq_len: int,
|
seq_len: int,
|
||||||
mm_counts: Mapping[str, int] | None = None,
|
mm_counts: Mapping[str, int] | None = None,
|
||||||
mm_embeddings_only: bool = True,
|
|
||||||
) -> Mapping[str, int]:
|
) -> Mapping[str, int]:
|
||||||
|
"""
|
||||||
|
Returns the maximum number of embeddings per item of each modality, excluding
|
||||||
|
any break/text tokens in-between multimodal embeddings/encoder outputs.
|
||||||
|
"""
|
||||||
if mm_counts is None:
|
if mm_counts is None:
|
||||||
mm_counts = self.get_mm_limits()
|
mm_counts = self.get_mm_limits()
|
||||||
|
|
||||||
@@ -349,21 +348,4 @@ class MultiModalProfiler(Generic[_I]):
|
|||||||
}
|
}
|
||||||
|
|
||||||
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
|
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
|
||||||
return self._get_mm_num_tokens(mm_inputs, mm_embeddings_only=mm_embeddings_only)
|
return self._get_mm_num_tokens(mm_inputs)
|
||||||
|
|
||||||
def get_mm_max_contiguous_tokens(
|
|
||||||
self,
|
|
||||||
seq_len: int,
|
|
||||||
mm_counts: Mapping[str, int] | None = None,
|
|
||||||
) -> Mapping[str, int]:
|
|
||||||
"""
|
|
||||||
Returns the maximum length of the multimodal (image placeholders+text)
|
|
||||||
tokens, including any break/text tokens in-between image embeddings.
|
|
||||||
|
|
||||||
`<im_start> [IMG] [IMG] [IMG] <row_break> [IMG] [IMG] [IMG] <im_end>`
|
|
||||||
Returns 9, even when the number of image embeddings is 6.
|
|
||||||
|
|
||||||
This is important to take into account when profiling and
|
|
||||||
initializing the encoder cache size.
|
|
||||||
"""
|
|
||||||
return self._get_mm_max_tokens(seq_len, mm_counts, mm_embeddings_only=False)
|
|
||||||
|
|||||||
@@ -164,7 +164,7 @@ class MultiModalRegistry:
|
|||||||
profiler.get_mm_limits() if profiler_limits is None else profiler_limits
|
profiler.get_mm_limits() if profiler_limits is None else profiler_limits
|
||||||
)
|
)
|
||||||
|
|
||||||
return profiler.get_mm_max_contiguous_tokens(
|
return profiler.get_mm_max_tokens(
|
||||||
seq_len,
|
seq_len,
|
||||||
{modality: 1 for modality, limit in profiler_limits.items() if limit > 0},
|
{modality: 1 for modality, limit in profiler_limits.items() if limit > 0},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -76,6 +76,39 @@ class TritonAttentionMetadata:
|
|||||||
# Optional aot scheduling
|
# Optional aot scheduling
|
||||||
scheduler_metadata: torch.Tensor | None = None
|
scheduler_metadata: torch.Tensor | None = None
|
||||||
prefix_scheduler_metadata: torch.Tensor | None = None
|
prefix_scheduler_metadata: torch.Tensor | None = None
|
||||||
|
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mm_prefix_range_tensor(self) -> torch.Tensor | None:
|
||||||
|
"""Convert mm_prefix_range dict to padded tensor for Triton kernel.
|
||||||
|
|
||||||
|
Returns shape: (num_seqs, max_ranges, 2) with 0-padding for empty ranges.
|
||||||
|
Empty ranges have start==end==0, which kernel skips via is_valid check.
|
||||||
|
"""
|
||||||
|
# TODO(Isotr0py): Move to model runner's attention metadata
|
||||||
|
# preparation to avoid duplicate computation.
|
||||||
|
if self.mm_prefix_range is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
num_seqs = self.seq_lens.shape[0]
|
||||||
|
device = self.seq_lens.device
|
||||||
|
|
||||||
|
# Collect ranges, using [(0,0)] for empty sequences to ensure uniform dims
|
||||||
|
range_lists = [
|
||||||
|
self.mm_prefix_range.get(i, [(0, 0)]) or [(0, 0)] for i in range(num_seqs)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Return None if all ranges are trivial (only (0,0) placeholders)
|
||||||
|
if all(r == [(0, 0)] for r in range_lists):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Create 2D tensors with shape (num_ranges, 2) for each sequence
|
||||||
|
range_tensors = [
|
||||||
|
torch.tensor(r, dtype=torch.int32, device=device).view(-1, 2)
|
||||||
|
for r in range_lists
|
||||||
|
]
|
||||||
|
|
||||||
|
return torch.nested.nested_tensor(range_tensors).to_padded_tensor(0)
|
||||||
|
|
||||||
|
|
||||||
class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
|
class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
|
||||||
@@ -268,6 +301,10 @@ class TritonAttentionBackend(AttentionBackend):
|
|||||||
def supports_head_size(cls, head_size: int) -> bool:
|
def supports_head_size(cls, head_size: int) -> bool:
|
||||||
return head_size >= 32
|
return head_size >= 32
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def supports_mm_prefix(cls) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def supports_sink(cls) -> bool:
|
def supports_sink(cls) -> bool:
|
||||||
return True
|
return True
|
||||||
@@ -427,6 +464,7 @@ class TritonAttentionImpl(AttentionImpl):
|
|||||||
softmax_segm_expsum = attn_metadata.softmax_segm_expsum
|
softmax_segm_expsum = attn_metadata.softmax_segm_expsum
|
||||||
|
|
||||||
descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
|
descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
|
||||||
|
mm_prefix_range_tensor = attn_metadata.mm_prefix_range_tensor
|
||||||
|
|
||||||
unified_attention(
|
unified_attention(
|
||||||
q=query[:num_actual_tokens],
|
q=query[:num_actual_tokens],
|
||||||
@@ -453,6 +491,7 @@ class TritonAttentionImpl(AttentionImpl):
|
|||||||
softmax_segm_expsum=softmax_segm_expsum,
|
softmax_segm_expsum=softmax_segm_expsum,
|
||||||
sinks=self.sinks,
|
sinks=self.sinks,
|
||||||
output_scale=output_scale,
|
output_scale=output_scale,
|
||||||
|
mm_prefix_range=mm_prefix_range_tensor,
|
||||||
)
|
)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|||||||
@@ -39,20 +39,26 @@ class EncoderCacheManager:
|
|||||||
space for new embeddings.
|
space for new embeddings.
|
||||||
Oldest cached embeddings with no request referenced will be first evicted.
|
Oldest cached embeddings with no request referenced will be first evicted.
|
||||||
|
|
||||||
|
NOTE: The EncoderCacheManager operates on the level of multimodal embeddings
|
||||||
|
instead of encoder tokens (i.e. all tokens that represent the multimodal data
|
||||||
|
in the input sequence). This means all break/text tokens in-between multimodal
|
||||||
|
embeddings are not considered with respect to the cache size and the number
|
||||||
|
of free slots.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cache_size: Limit the size of the cache, measured by the number of
|
cache_size: Limit the size of the cache, measured by the number of
|
||||||
tokens from the input sequence.
|
encoder embeddings from the input sequence.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
cache_size: Total cache capacity in encoder tokens.
|
cache_size: Total cache capacity in encoder embeddings.
|
||||||
num_free_slots: Current available cache capacity in encoder tokens.
|
num_free_slots: Current available cache capacity in encoder embeddings.
|
||||||
num_freeable_slots: Capacity that can be immediately reclaimed by
|
num_freeable_slots: Capacity that can be immediately reclaimed by
|
||||||
evicting entries with zero references (in encoder tokens).
|
evicting entries with zero references (in encoder embeddings).
|
||||||
cached: Mapping from mm_hash to a set of request IDs that currently
|
cached: Mapping from mm_hash to a set of request IDs that currently
|
||||||
reference the cached entry. If the set is empty, the entry exists
|
reference the cached entry. If the set is empty, the entry exists
|
||||||
but is not referenced by any request and is eligible for
|
but is not referenced by any request and is eligible for
|
||||||
reclamation.
|
reclamation.
|
||||||
freeable: List of tuples (mm_hash, num_tokens) representing entries
|
freeable: List of tuples (mm_hash, num_encoder_embeds) representing entries
|
||||||
whose no current running request is needed and that can be freed to
|
whose no current running request is needed and that can be freed to
|
||||||
make space when needed.
|
make space when needed.
|
||||||
freed: List of mm_hash strings that were actually evicted since the
|
freed: List of mm_hash strings that were actually evicted since the
|
||||||
@@ -67,7 +73,7 @@ class EncoderCacheManager:
|
|||||||
# mm_hash of mm_data => ids of requests that reference the mm_data
|
# mm_hash of mm_data => ids of requests that reference the mm_data
|
||||||
self.cached: dict[str, set[str]] = {}
|
self.cached: dict[str, set[str]] = {}
|
||||||
|
|
||||||
# mm_hash of mm_data => num_encoder_tokens of the mm_data
|
# mm_hash of mm_data => num_encoder_embeds of the mm_data
|
||||||
self.freeable: OrderedDict[str, int] = OrderedDict()
|
self.freeable: OrderedDict[str, int] = OrderedDict()
|
||||||
self.freed: list[str] = []
|
self.freed: list[str] = []
|
||||||
|
|
||||||
@@ -93,8 +99,8 @@ class EncoderCacheManager:
|
|||||||
|
|
||||||
# Cached but currently not referenced by any request
|
# Cached but currently not referenced by any request
|
||||||
if not self.cached[mm_hash]:
|
if not self.cached[mm_hash]:
|
||||||
num_tokens = self.freeable.pop(mm_hash)
|
num_encoder_embeds = self.freeable.pop(mm_hash)
|
||||||
self.num_freeable_slots -= num_tokens
|
self.num_freeable_slots -= num_encoder_embeds
|
||||||
|
|
||||||
self.cached[mm_hash].add(request.request_id)
|
self.cached[mm_hash].add(request.request_id)
|
||||||
return True
|
return True
|
||||||
@@ -104,7 +110,7 @@ class EncoderCacheManager:
|
|||||||
request: Request,
|
request: Request,
|
||||||
input_id: int,
|
input_id: int,
|
||||||
encoder_compute_budget: int,
|
encoder_compute_budget: int,
|
||||||
num_tokens_to_schedule: int,
|
num_embeds_to_schedule: int,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Check if there's sufficient cache space for a multimodal input.
|
"""Check if there's sufficient cache space for a multimodal input.
|
||||||
If there is, return True and update EncoderCacheManager state.
|
If there is, return True and update EncoderCacheManager state.
|
||||||
@@ -121,9 +127,9 @@ class EncoderCacheManager:
|
|||||||
Args:
|
Args:
|
||||||
request: The request containing the multimodal input.
|
request: The request containing the multimodal input.
|
||||||
input_id: Index of the multimodal input within the request.
|
input_id: Index of the multimodal input within the request.
|
||||||
encoder_compute_budget: Number of encoder tokens allowed to be
|
encoder_compute_budget: Number of encoder embeddings allowed to be
|
||||||
computed when this method is invoked.
|
computed when this method is invoked.
|
||||||
num_tokens_to_schedule: Number of tokens already scheduled to be
|
num_embeds_to_schedule: Number of encoder embeddings already scheduled to be
|
||||||
allocated with cache space when this method is invoked.
|
allocated with cache space when this method is invoked.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -134,30 +140,30 @@ class EncoderCacheManager:
|
|||||||
Note: This method does not allocate physical memory for the encoder
|
Note: This method does not allocate physical memory for the encoder
|
||||||
output but only the state of EncoderCacheManager.
|
output but only the state of EncoderCacheManager.
|
||||||
"""
|
"""
|
||||||
num_tokens = request.get_num_encoder_tokens(input_id)
|
num_embeds = request.get_num_encoder_embeds(input_id)
|
||||||
|
|
||||||
# Not enough compute budget
|
# Not enough compute budget
|
||||||
if num_tokens > encoder_compute_budget:
|
if num_embeds > encoder_compute_budget:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
num_tokens += num_tokens_to_schedule
|
num_embeds += num_embeds_to_schedule
|
||||||
|
|
||||||
# Enough free slots
|
# Enough free slots
|
||||||
if num_tokens <= self.num_free_slots:
|
if num_embeds <= self.num_free_slots:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Not enough reclaimable slots
|
# Not enough reclaimable slots
|
||||||
if num_tokens > self.num_freeable_slots:
|
if num_embeds > self.num_freeable_slots:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# Not enough free slots but enough reclaimable slots
|
# Not enough free slots but enough reclaimable slots
|
||||||
# NOTE: Eviction takes place here, but physical memory is not freed
|
# NOTE: Eviction takes place here, but physical memory is not freed
|
||||||
# until model runner is notified by the scheduler output.
|
# until model runner is notified by the scheduler output.
|
||||||
while num_tokens > self.num_free_slots:
|
while num_embeds > self.num_free_slots:
|
||||||
mm_hash, num_free_token = self.freeable.popitem(last=False)
|
mm_hash, num_free_embeds = self.freeable.popitem(last=False)
|
||||||
del self.cached[mm_hash]
|
del self.cached[mm_hash]
|
||||||
self.freed.append(mm_hash)
|
self.freed.append(mm_hash)
|
||||||
self.num_free_slots += num_free_token
|
self.num_free_slots += num_free_embeds
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def allocate(self, request: Request, input_id: int) -> None:
|
def allocate(self, request: Request, input_id: int) -> None:
|
||||||
@@ -176,16 +182,16 @@ class EncoderCacheManager:
|
|||||||
if mm_hash not in self.cached:
|
if mm_hash not in self.cached:
|
||||||
self.cached[mm_hash] = set()
|
self.cached[mm_hash] = set()
|
||||||
|
|
||||||
num_encoder_tokens = request.get_num_encoder_tokens(input_id)
|
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
|
||||||
|
|
||||||
# NOTE: Encoder cache should always have enough space for encoder inputs
|
# NOTE: Encoder cache should always have enough space for encoder inputs
|
||||||
# that are scheduled since eviction takes place at can_allocate().
|
# that are scheduled since eviction takes place at can_allocate().
|
||||||
assert self.num_free_slots >= num_encoder_tokens
|
assert self.num_free_slots >= num_encoder_embeds
|
||||||
assert self.num_freeable_slots >= num_encoder_tokens
|
assert self.num_freeable_slots >= num_encoder_embeds
|
||||||
|
|
||||||
self.cached[mm_hash].add(request_id)
|
self.cached[mm_hash].add(request_id)
|
||||||
self.num_free_slots -= num_encoder_tokens
|
self.num_free_slots -= num_encoder_embeds
|
||||||
self.num_freeable_slots -= num_encoder_tokens
|
self.num_freeable_slots -= num_encoder_embeds
|
||||||
|
|
||||||
def get_cached_input_ids(self, request: Request) -> set[int]:
|
def get_cached_input_ids(self, request: Request) -> set[int]:
|
||||||
"""Get all cached multimodal input IDs for a request.
|
"""Get all cached multimodal input IDs for a request.
|
||||||
@@ -206,7 +212,7 @@ class EncoderCacheManager:
|
|||||||
|
|
||||||
When the reference set for the corresponding `mm_hash` becomes empty,
|
When the reference set for the corresponding `mm_hash` becomes empty,
|
||||||
the entry is appended to `freeable` and `num_freeable_slots` is
|
the entry is appended to `freeable` and `num_freeable_slots` is
|
||||||
increased by the number of encoder tokens for that input.
|
increased by the number of encoder embeddings for that input.
|
||||||
|
|
||||||
The entry is NOT physically freed until capacity is needed (e.g., by
|
The entry is NOT physically freed until capacity is needed (e.g., by
|
||||||
`can_allocate`).
|
`can_allocate`).
|
||||||
@@ -218,9 +224,9 @@ class EncoderCacheManager:
|
|||||||
return
|
return
|
||||||
self.cached[mm_hash].discard(req_id)
|
self.cached[mm_hash].discard(req_id)
|
||||||
if not self.cached[mm_hash]:
|
if not self.cached[mm_hash]:
|
||||||
num_tokens = request.get_num_encoder_tokens(input_id)
|
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
|
||||||
self.freeable[mm_hash] = num_tokens
|
self.freeable[mm_hash] = num_encoder_embeds
|
||||||
self.num_freeable_slots += num_tokens
|
self.num_freeable_slots += num_encoder_embeds
|
||||||
|
|
||||||
def free(self, request: Request) -> None:
|
def free(self, request: Request) -> None:
|
||||||
"""Free all encoder input cache reference held by *request*.
|
"""Free all encoder input cache reference held by *request*.
|
||||||
@@ -361,20 +367,20 @@ class EncoderDecoderCacheManager(EncoderCacheManager):
|
|||||||
request: Request,
|
request: Request,
|
||||||
input_id: int,
|
input_id: int,
|
||||||
encoder_compute_budget: int,
|
encoder_compute_budget: int,
|
||||||
num_tokens_to_schedule: int,
|
num_embeds_to_schedule: int,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
num_tokens = request.get_num_encoder_tokens(input_id)
|
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
|
||||||
# Not enough compute budget
|
# Not enough compute budget
|
||||||
if num_tokens > encoder_compute_budget:
|
if num_encoder_embeds > encoder_compute_budget:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
num_tokens += num_tokens_to_schedule
|
num_encoder_embeds += num_embeds_to_schedule
|
||||||
# Enough free slots
|
# Enough free slots
|
||||||
return num_tokens <= self.num_free_slots
|
return num_encoder_embeds <= self.num_free_slots
|
||||||
|
|
||||||
def allocate(self, request: Request, input_id: int) -> None:
|
def allocate(self, request: Request, input_id: int) -> None:
|
||||||
num_encoder_tokens = request.get_num_encoder_tokens(input_id)
|
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
|
||||||
self.num_free_slots -= num_encoder_tokens
|
self.num_free_slots -= num_encoder_embeds
|
||||||
|
|
||||||
mm_hash = request.mm_features[input_id].identifier
|
mm_hash = request.mm_features[input_id].identifier
|
||||||
self.freed.append(mm_hash)
|
self.freed.append(mm_hash)
|
||||||
@@ -392,5 +398,5 @@ class EncoderDecoderCacheManager(EncoderCacheManager):
|
|||||||
return freed
|
return freed
|
||||||
|
|
||||||
def free_encoder_input(self, request: Request, input_id: int) -> None:
|
def free_encoder_input(self, request: Request, input_id: int) -> None:
|
||||||
num_tokens = request.get_num_encoder_tokens(input_id)
|
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
|
||||||
self.num_free_slots += num_tokens
|
self.num_free_slots += num_encoder_embeds
|
||||||
|
|||||||
@@ -349,11 +349,11 @@ class Scheduler(SchedulerInterface):
|
|||||||
if preempted_encoder_inputs:
|
if preempted_encoder_inputs:
|
||||||
# Restore encoder compute budget if the preempted
|
# Restore encoder compute budget if the preempted
|
||||||
# request had encoder inputs scheduled in this step.
|
# request had encoder inputs scheduled in this step.
|
||||||
num_tokens_to_restore = sum(
|
num_embeds_to_restore = sum(
|
||||||
preempted_req.get_num_encoder_tokens(i)
|
preempted_req.get_num_encoder_embeds(i)
|
||||||
for i in preempted_encoder_inputs
|
for i in preempted_encoder_inputs
|
||||||
)
|
)
|
||||||
encoder_compute_budget += num_tokens_to_restore
|
encoder_compute_budget += num_embeds_to_restore
|
||||||
req_index -= 1
|
req_index -= 1
|
||||||
else:
|
else:
|
||||||
preempted_req = self.running.pop()
|
preempted_req = self.running.pop()
|
||||||
@@ -911,10 +911,11 @@ class Scheduler(SchedulerInterface):
|
|||||||
# multiple encoder inputs per request), we need to create temporary
|
# multiple encoder inputs per request), we need to create temporary
|
||||||
# trackers for accounting at the encoder input level.
|
# trackers for accounting at the encoder input level.
|
||||||
mm_hashes_to_schedule = set()
|
mm_hashes_to_schedule = set()
|
||||||
num_tokens_to_schedule = 0
|
num_embeds_to_schedule = 0
|
||||||
for i, mm_feature in enumerate(mm_features):
|
for i, mm_feature in enumerate(mm_features):
|
||||||
start_pos = mm_feature.mm_position.offset
|
start_pos = mm_feature.mm_position.offset
|
||||||
num_encoder_tokens = mm_feature.mm_position.length
|
num_encoder_tokens = mm_feature.mm_position.length
|
||||||
|
num_encoder_embeds = mm_feature.mm_position.get_num_embeds
|
||||||
|
|
||||||
# The encoder output is needed if the two ranges overlap:
|
# The encoder output is needed if the two ranges overlap:
|
||||||
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
|
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
|
||||||
@@ -970,9 +971,8 @@ class Scheduler(SchedulerInterface):
|
|||||||
):
|
):
|
||||||
num_new_tokens = start_pos - num_computed_tokens
|
num_new_tokens = start_pos - num_computed_tokens
|
||||||
break
|
break
|
||||||
|
|
||||||
if not self.encoder_cache_manager.can_allocate(
|
if not self.encoder_cache_manager.can_allocate(
|
||||||
request, i, encoder_compute_budget, num_tokens_to_schedule
|
request, i, encoder_compute_budget, num_embeds_to_schedule
|
||||||
):
|
):
|
||||||
# The encoder cache is full or the encoder budget is exhausted.
|
# The encoder cache is full or the encoder budget is exhausted.
|
||||||
# NOTE(woosuk): We assume that the encoder input tokens should
|
# NOTE(woosuk): We assume that the encoder input tokens should
|
||||||
@@ -992,14 +992,31 @@ class Scheduler(SchedulerInterface):
|
|||||||
num_new_tokens = 0
|
num_new_tokens = 0
|
||||||
break
|
break
|
||||||
|
|
||||||
|
# Calculate the number of embeddings to schedule in the current range
|
||||||
|
# of scheduled encoder placeholder tokens.
|
||||||
|
start_idx_rel = max(0, num_computed_tokens - start_pos)
|
||||||
|
end_idx_rel = min(
|
||||||
|
num_encoder_tokens, num_computed_tokens + num_new_tokens - start_pos
|
||||||
|
)
|
||||||
|
curr_embeds_start, curr_embeds_end = (
|
||||||
|
mm_feature.mm_position.get_embeds_indices_in_range(
|
||||||
|
start_idx_rel,
|
||||||
|
end_idx_rel,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# There are no embeddings in the current range of encoder placeholder tokens
|
||||||
|
# so we can skip the encoder input.
|
||||||
|
if curr_embeds_end - curr_embeds_start == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
if self.ec_connector is not None and remote_cache_has_item[i]:
|
if self.ec_connector is not None and remote_cache_has_item[i]:
|
||||||
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
|
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
|
||||||
external_load_encoder_input.append(i)
|
external_load_encoder_input.append(i)
|
||||||
num_tokens_to_schedule += num_encoder_tokens
|
num_embeds_to_schedule += num_encoder_embeds
|
||||||
continue
|
continue
|
||||||
|
|
||||||
num_tokens_to_schedule += num_encoder_tokens
|
num_embeds_to_schedule += num_encoder_embeds
|
||||||
encoder_compute_budget -= num_encoder_tokens
|
encoder_compute_budget -= num_encoder_embeds
|
||||||
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
|
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
|
||||||
encoder_inputs_to_schedule.append(i)
|
encoder_inputs_to_schedule.append(i)
|
||||||
|
|
||||||
|
|||||||
@@ -209,10 +209,10 @@ class Request:
|
|||||||
def get_finished_reason(self) -> FinishReason | None:
|
def get_finished_reason(self) -> FinishReason | None:
|
||||||
return RequestStatus.get_finished_reason(self.status)
|
return RequestStatus.get_finished_reason(self.status)
|
||||||
|
|
||||||
def get_num_encoder_tokens(self, input_id: int) -> int:
|
def get_num_encoder_embeds(self, input_id: int) -> int:
|
||||||
assert input_id < len(self.mm_features)
|
assert input_id < len(self.mm_features)
|
||||||
num_tokens = self.mm_features[input_id].mm_position.length
|
num_embeds = self.mm_features[input_id].mm_position.get_num_embeds
|
||||||
return num_tokens
|
return num_embeds
|
||||||
|
|
||||||
def record_event(
|
def record_event(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -169,9 +169,7 @@ from .utils import (
|
|||||||
MultiModalBudget,
|
MultiModalBudget,
|
||||||
add_kv_sharing_layers_to_kv_cache_groups,
|
add_kv_sharing_layers_to_kv_cache_groups,
|
||||||
bind_kv_cache,
|
bind_kv_cache,
|
||||||
gather_mm_placeholders,
|
|
||||||
sanity_check_mm_encoder_outputs,
|
sanity_check_mm_encoder_outputs,
|
||||||
scatter_mm_placeholders,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -2187,10 +2185,7 @@ class GPUModelRunner(
|
|||||||
|
|
||||||
# Cache the encoder outputs by mm_hash
|
# Cache the encoder outputs by mm_hash
|
||||||
for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
|
for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
|
||||||
self.encoder_cache[mm_hash] = scatter_mm_placeholders(
|
self.encoder_cache[mm_hash] = output
|
||||||
output,
|
|
||||||
is_embed=pos_info.is_embed,
|
|
||||||
)
|
|
||||||
logger.debug("Finish execute for mm hash %s", mm_hash)
|
logger.debug("Finish execute for mm hash %s", mm_hash)
|
||||||
self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)
|
self.maybe_save_ec_to_connector(self.encoder_cache, mm_hash)
|
||||||
|
|
||||||
@@ -2241,6 +2236,13 @@ class GPUModelRunner(
|
|||||||
num_encoder_tokens,
|
num_encoder_tokens,
|
||||||
)
|
)
|
||||||
assert start_idx < end_idx
|
assert start_idx < end_idx
|
||||||
|
curr_embeds_start, curr_embeds_end = (
|
||||||
|
pos_info.get_embeds_indices_in_range(start_idx, end_idx)
|
||||||
|
)
|
||||||
|
# If there are no embeddings in the current range, we skip
|
||||||
|
# gathering the embeddings.
|
||||||
|
if curr_embeds_start == curr_embeds_end:
|
||||||
|
continue
|
||||||
|
|
||||||
mm_hash = mm_feature.identifier
|
mm_hash = mm_feature.identifier
|
||||||
encoder_output = self.encoder_cache.get(mm_hash, None)
|
encoder_output = self.encoder_cache.get(mm_hash, None)
|
||||||
@@ -2248,16 +2250,14 @@ class GPUModelRunner(
|
|||||||
|
|
||||||
if (is_embed := pos_info.is_embed) is not None:
|
if (is_embed := pos_info.is_embed) is not None:
|
||||||
is_embed = is_embed[start_idx:end_idx]
|
is_embed = is_embed[start_idx:end_idx]
|
||||||
|
mm_embeds_item = encoder_output[curr_embeds_start:curr_embeds_end]
|
||||||
|
else:
|
||||||
|
mm_embeds_item = encoder_output[start_idx:end_idx]
|
||||||
|
|
||||||
req_start_pos = req_start_idx + start_pos - num_computed_tokens
|
req_start_pos = req_start_idx + start_pos - num_computed_tokens
|
||||||
is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
|
is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
|
||||||
True if is_embed is None else is_embed
|
True if is_embed is None else is_embed
|
||||||
)
|
)
|
||||||
|
|
||||||
mm_embeds_item = gather_mm_placeholders(
|
|
||||||
encoder_output[start_idx:end_idx],
|
|
||||||
is_embed=is_embed,
|
|
||||||
)
|
|
||||||
mm_embeds_req.append(mm_embeds_item)
|
mm_embeds_req.append(mm_embeds_item)
|
||||||
|
|
||||||
if self.is_multimodal_pruning_enabled and self.uses_mrope:
|
if self.is_multimodal_pruning_enabled and self.uses_mrope:
|
||||||
@@ -4467,31 +4467,8 @@ class GPUModelRunner(
|
|||||||
dummy_encoder_outputs,
|
dummy_encoder_outputs,
|
||||||
expected_num_items=max_mm_items_per_batch,
|
expected_num_items=max_mm_items_per_batch,
|
||||||
)
|
)
|
||||||
|
for i, output in enumerate(dummy_encoder_outputs):
|
||||||
# NOTE: This happens when encoder cache needs to store
|
self.encoder_cache[f"tmp_{i}"] = output
|
||||||
# the embeddings that encoder outputs are scattered onto.
|
|
||||||
# In this case we create dummy embeddings of size
|
|
||||||
# (max_tokens_for_modality, hidden_size) and scatter
|
|
||||||
# encoder output into it.
|
|
||||||
encoder_output_shape = dummy_encoder_outputs[0].shape
|
|
||||||
max_mm_tokens_per_item = mm_budget.max_tokens_by_modality[
|
|
||||||
dummy_modality
|
|
||||||
]
|
|
||||||
if encoder_output_shape[0] < max_mm_tokens_per_item:
|
|
||||||
encoder_hidden_size = encoder_output_shape[-1]
|
|
||||||
expanded_outputs = []
|
|
||||||
for output in dummy_encoder_outputs:
|
|
||||||
expanded = output.new_zeros(
|
|
||||||
(max_mm_tokens_per_item, encoder_hidden_size)
|
|
||||||
)
|
|
||||||
num_tokens = output.shape[0]
|
|
||||||
expanded[:num_tokens].copy_(output)
|
|
||||||
expanded_outputs.append(expanded)
|
|
||||||
|
|
||||||
dummy_encoder_outputs = expanded_outputs
|
|
||||||
|
|
||||||
# Cache the dummy encoder outputs.
|
|
||||||
self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
|
|
||||||
|
|
||||||
# Add `is_profile` here to pre-allocate communication buffers
|
# Add `is_profile` here to pre-allocate communication buffers
|
||||||
hidden_states, last_hidden_states = self._dummy_run(
|
hidden_states, last_hidden_states = self._dummy_run(
|
||||||
|
|||||||
@@ -4,10 +4,12 @@ from collections import defaultdict
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from typing_extensions import deprecated
|
||||||
|
|
||||||
from vllm.attention.backends.abstract import AttentionBackend
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
from vllm.attention.layer import Attention
|
from vllm.attention.layer import Attention
|
||||||
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
|
from vllm.config import ModelConfig, SchedulerConfig, VllmConfig
|
||||||
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
|
||||||
from vllm.model_executor.models.utils import extract_layer_index
|
from vllm.model_executor.models.utils import extract_layer_index
|
||||||
from vllm.multimodal.cache import processor_only_cache_from_config
|
from vllm.multimodal.cache import processor_only_cache_from_config
|
||||||
@@ -17,6 +19,8 @@ from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
|
|||||||
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
|
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
|
||||||
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
|
from vllm.v1.kv_cache_interface import KVCacheGroupSpec, KVCacheSpec
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MultiModalBudget:
|
class MultiModalBudget:
|
||||||
"""Helper class to calculate budget information for multi-modal models."""
|
"""Helper class to calculate budget information for multi-modal models."""
|
||||||
@@ -198,6 +202,7 @@ def sanity_check_mm_encoder_outputs(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@deprecated("`scatter_mm_placeholders` is deprecated and will be removed in v0.15.0.")
|
||||||
def scatter_mm_placeholders(
|
def scatter_mm_placeholders(
|
||||||
embeds: torch.Tensor,
|
embeds: torch.Tensor,
|
||||||
is_embed: torch.Tensor | None,
|
is_embed: torch.Tensor | None,
|
||||||
@@ -226,6 +231,7 @@ def scatter_mm_placeholders(
|
|||||||
return placeholders
|
return placeholders
|
||||||
|
|
||||||
|
|
||||||
|
@deprecated("`gather_mm_placeholders` is deprecated and will be removed in v0.15.0.")
|
||||||
def gather_mm_placeholders(
|
def gather_mm_placeholders(
|
||||||
placeholders: torch.Tensor,
|
placeholders: torch.Tensor,
|
||||||
is_embed: torch.Tensor | None,
|
is_embed: torch.Tensor | None,
|
||||||
|
|||||||
@@ -145,12 +145,20 @@ class WorkspaceManager:
|
|||||||
|
|
||||||
for ubatch_id in range(self._num_ubatches):
|
for ubatch_id in range(self._num_ubatches):
|
||||||
current_workspace = self._current_workspaces[ubatch_id]
|
current_workspace = self._current_workspaces[ubatch_id]
|
||||||
if current_workspace is None:
|
if (
|
||||||
|
current_workspace is None
|
||||||
|
or self._workspace_size_bytes(current_workspace) < required_bytes
|
||||||
|
):
|
||||||
|
# Delete old tensor before allocating new one to avoid
|
||||||
|
# memory spike from resize_(). resize_() allocates new
|
||||||
|
# memory before freeing old, which can cause OOM.
|
||||||
|
# Must clear the list reference first since local var
|
||||||
|
# is just a copy of the reference.
|
||||||
|
self._current_workspaces[ubatch_id] = None
|
||||||
|
del current_workspace
|
||||||
self._current_workspaces[ubatch_id] = torch.empty(
|
self._current_workspaces[ubatch_id] = torch.empty(
|
||||||
(required_bytes,), dtype=torch.uint8, device=self._device
|
(required_bytes,), dtype=torch.uint8, device=self._device
|
||||||
)
|
)
|
||||||
elif self._workspace_size_bytes(current_workspace) < required_bytes:
|
|
||||||
current_workspace.resize_(required_bytes)
|
|
||||||
|
|
||||||
if envs.VLLM_DEBUG_WORKSPACE:
|
if envs.VLLM_DEBUG_WORKSPACE:
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|||||||
Reference in New Issue
Block a user