[Kernel][Performance] Add Triton kernel for Qwen3-VL interleaved MRoPE (#25055)

commit cea91a32f2 (parent a684c0124c)
Author: Isotr0py <mozf@mail2.sysu.edu.cn>
Date: 2025-09-19 18:27:49 +08:00 (committed via GitHub)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>

2 changed files with 85 additions and 43 deletions

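Only the test-file half of the diff appears below; the Triton kernel itself presumably lives in the other changed file. For orientation, interleaved MRoPE differs from the chunked Qwen2-VL layout in how rotary frequencies are assigned to the temporal/height/width axes. A minimal sketch of the idea, with names assumed rather than taken from this commit:

    # Sketch: chunked MRoPE assigns contiguous frequency blocks per axis
    # ([t]*t_sec + [h]*h_sec + [w]*w_sec); interleaved MRoPE cycles
    # t, h, w, t, h, w, ... and lets the temporal axis fill the tail.
    import torch

    def apply_interleaved_mrope(freqs: torch.Tensor,
                                mrope_section: list[int]) -> torch.Tensor:
        """freqs: (3, num_tokens, head_dim // 2) rotary angles per axis."""
        freqs_t = freqs[0].clone()  # temporal axis as the base
        for axis, offset in enumerate((1, 2), start=1):  # 1 = h, 2 = w
            idx = slice(offset, mrope_section[axis] * 3, 3)
            freqs_t[..., idx] = freqs[axis][..., idx]
        return freqs_t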

@@ -1,9 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import NamedTuple
 import pytest
 import torch
+from packaging.version import Version
 from transformers import AutoConfig
+from transformers import __version__ as TRANSFORMERS_VERSION
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.platforms import current_platform
@@ -15,6 +18,7 @@ def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int,
                        head_size: int, max_position_embeddings: int,
                        dtype: torch.dtype, device: torch.device):
     """Generate test data for given configuration."""
+    current_platform.seed_everything(42)
     # Create 2D positions (3, num_tokens) for multimodal case
     positions = torch.randint(0,
                               max_position_embeddings // 4, (3, num_tokens),
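
Seeding here makes the random inputs reproducible across the rope implementations being compared. An illustrative call of the helper (parameter values arbitrary, not from this commit):

    positions, query, key = generate_test_data(
        num_tokens=11, num_q_heads=28, num_kv_heads=4, head_size=128,
        max_position_embeddings=32768, dtype=torch.bfloat16,
        device=torch.device("cuda"))
    assert positions.shape == (3, 11)  # one position row per t/h/w axis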
@@ -33,22 +37,37 @@ def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int,
     return positions, query, key
 
-def unroll_model_tp_dict(model_tp_dict):
-    return [(model_name, tp_size)
-            for model_name, tp_sizes in model_tp_dict.items()
-            for tp_size in tp_sizes]
+class MRoPETestInfo(NamedTuple):
+    model_name: str
+    # https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
+    atol: float = 1e-2
+    rtol: float = 1.6e-2
+    marks: list[pytest.MarkDecorator] = []
 
-model_tp_dict = {
-    "Qwen/Qwen2-VL-7B-Instruct": [1, 2],
-    "Qwen/Qwen2-VL-72B-Instruct": [1, 2],
-    "Qwen/Qwen2.5-VL-72B-Instruct": [1, 2],
-    "zai-org/GLM-4.1V-9B-Thinking": [1, 2],
-}
+TRANSFORMERS_BASE_VERSION = Version(TRANSFORMERS_VERSION).base_version
 
-# https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
-dtype_atol_rtol_list = [
-    [torch.bfloat16, 1e-2, 1.6e-2],
+MODELS_TO_TEST = [
+    MRoPETestInfo(model_name="zai-org/GLM-4.1V-9B-Thinking"),
+    MRoPETestInfo(model_name="Qwen/Qwen2-VL-7B-Instruct"),
+    MRoPETestInfo(model_name="Qwen/Qwen2-VL-72B-Instruct"),
+    MRoPETestInfo(model_name="Qwen/Qwen2.5-VL-72B-Instruct"),
+    MRoPETestInfo(
+        model_name="Qwen/Qwen3-VL-4B-Instruct",
+        marks=[
+            pytest.mark.skipif(
+                Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
+                reason="Qwen3-VL only available after Transformers v4.57",
+            )
+        ]),
+    MRoPETestInfo(
+        model_name="Qwen/Qwen3-VL-30B-A3B-Instruct",
+        marks=[
+            pytest.mark.skipif(
+                Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
+                reason="Qwen3-VL only available after Transformers v4.57",
+            )
+        ]),
 ]
 
 num_tokens_list = [11, 8192]
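
Each MODELS_TO_TEST entry carries its own tolerances and skip marks, and pytest crosses it with the tp_size/dtype/num_tokens axes in the next hunk. Roughly the case grid that parametrize stack expands to:

    # 6 models x 2 tp sizes x 1 dtype x 2 token counts = 24 cases,
    # with the Qwen3-VL entries skipped on transformers < 4.57.
    cases = [(info, info.model_name, tp, torch.bfloat16, n)
             for info in MODELS_TO_TEST
             for tp in (1, 2)
             for n in num_tokens_list]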
@@ -56,20 +75,29 @@ num_tokens_list = [11, 8192]
 @pytest.mark.skipif(not current_platform.is_cuda_alike(),
                     reason="Skipping CUDA/ROCm only tests.")
-@pytest.mark.parametrize("model_name, tp_size",
-                         unroll_model_tp_dict(model_tp_dict))
-@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list)
+@pytest.mark.parametrize("model_info, model_name", [
+    pytest.param(test_config, test_config.model_name, marks=test_config.marks)
+    for test_config in MODELS_TO_TEST
+])
+@pytest.mark.parametrize("tp_size", [1, 2])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("num_tokens", num_tokens_list)
-def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
+def test_mrope(model_name: str, model_info: MRoPETestInfo, tp_size: int,
+               dtype: torch.dtype, num_tokens: int):
+    atol = model_info.atol
+    rtol = model_info.rtol
+
     config = AutoConfig.from_pretrained(model_name)
     config = config.get_text_config()
 
     # get the model config
     total_num_kv_heads = config.num_key_value_heads
     total_num_heads = config.num_attention_heads
     num_heads = total_num_heads // tp_size
     num_kv_heads = max(1, total_num_kv_heads // tp_size)
-    head_dim = config.hidden_size // total_num_heads
+    head_dim = (config.head_dim if hasattr(config, "head_dim") else
+                config.hidden_size // total_num_heads)
     is_neox_style = True
     rope_theta = config.rope_theta
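
The head_dim fallback matters for configs that set head_dim explicitly (e.g. the MoE variant Qwen3-VL-30B-A3B) rather than deriving it from hidden_size. The layer under test is presumably built from these values via get_rope, along these lines (a sketch assuming vLLM's get_rope keywords; this part of the test body is outside the hunk):

    rope = get_rope(head_size=head_dim,
                    rotary_dim=head_dim,
                    max_position=config.max_position_embeddings,
                    base=rope_theta,
                    is_neox_style=is_neox_style,
                    rope_scaling=config.rope_scaling,  # carries mrope_section
                    dtype=dtype)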
@@ -111,24 +139,30 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
 @pytest.mark.skipif(not current_platform.is_cuda_alike(),
                     reason="Skipping CUDA/ROCm only tests.")
-@pytest.mark.parametrize(
-    "model_name, tp_size",
-    unroll_model_tp_dict({
-        "Qwen/Qwen2-VL-7B-Instruct": [1, 2],
-        "zai-org/GLM-4.1V-9B-Thinking": [1, 2]
-    }))
-@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list)
-@pytest.mark.parametrize("num_tokens", [4])
-def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol,
-                                     num_tokens):
+@pytest.mark.parametrize("model_info, model_name", [
+    pytest.param(test_config, test_config.model_name, marks=test_config.marks)
+    for test_config in MODELS_TO_TEST
+])
+@pytest.mark.parametrize("tp_size", [1, 2])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("num_tokens", num_tokens_list)
+def test_mrope_torch_compile_tracing(model_name: str,
+                                     model_info: MRoPETestInfo, tp_size: int,
+                                     dtype: torch.dtype, num_tokens: int):
+    atol = model_info.atol
+    rtol = model_info.rtol
+
     config = AutoConfig.from_pretrained(model_name)
     config = config.get_text_config()
 
     # get the model config
     total_num_kv_heads = config.num_key_value_heads
     total_num_heads = config.num_attention_heads
     num_heads = total_num_heads // tp_size
     num_kv_heads = max(1, total_num_kv_heads // tp_size)
-    head_dim = config.hidden_size // total_num_heads
+    head_dim = (config.head_dim if hasattr(config, "head_dim") else
+                config.hidden_size // total_num_heads)
     is_neox_style = True
     rope_theta = config.rope_theta
     max_position = config.max_position_embeddings
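
The tracing test presumably compiles the PyTorch-native MRoPE forward and checks it against the eager output within the per-model tolerances; a minimal sketch of that check, with method names assumed from vLLM's CustomOp convention (the actual test body is truncated here):

    ref_q, ref_k = rope.forward_native(positions, query.clone(), key.clone())
    compiled = torch.compile(rope.forward_native, fullgraph=True, dynamic=True)
    out_q, out_k = compiled(positions, query.clone(), key.clone())
    torch.testing.assert_close(out_q, ref_q, atol=atol, rtol=rtol)
    torch.testing.assert_close(out_k, ref_k, atol=atol, rtol=rtol)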