Compare commits
8 Commits
v0.13.0rc3
...
v0.13.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
72506c9834 | ||
|
|
b2eb84de77 | ||
|
|
ac43367ced | ||
|
|
30fe765e9f | ||
|
|
2c0ee0fde8 | ||
|
|
55f1fc1b1b | ||
|
|
17f3988094 | ||
|
|
682c38583c |
@@ -406,6 +406,7 @@ th {
|
||||
| `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ |
|
||||
| `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ |
|
||||
| `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ |
|
||||
| `Jais2ForCausalLM` | Jais2 | `inceptionai/Jais-2-8B-Chat`, `inceptionai/Jais-2-70B-Chat`, etc. | | ✅︎ |
|
||||
| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ |
|
||||
| `KimiLinearForCausalLM` | Kimi-Linear-48B-A3B-Base, Kimi-Linear-48B-A3B-Instruct | `moonshotai/Kimi-Linear-48B-A3B-Base`, `moonshotai/Kimi-Linear-48B-A3B-Instruct` | | ✅︎ |
|
||||
| `Lfm2ForCausalLM` | LFM2 | `LiquidAI/LFM2-1.2B`, `LiquidAI/LFM2-700M`, `LiquidAI/LFM2-350M`, etc. | ✅︎ | ✅︎ |
|
||||
|
||||
@@ -233,24 +233,6 @@ def test_splitting_ops_dynamic():
|
||||
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
|
||||
|
||||
|
||||
def test_moe_splitting_ops_deepep_ht_piecewise():
|
||||
# Non-inductor, non-attn-fusion case: DeepEP HT with dp>1
|
||||
# should add MoE ops to splitting_ops on top of attention ops.
|
||||
config = VllmConfig(
|
||||
parallel_config=ParallelConfig(
|
||||
all2all_backend="deepep_high_throughput",
|
||||
data_parallel_size=8,
|
||||
),
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
),
|
||||
)
|
||||
splitting_ops = config.compilation_config.splitting_ops
|
||||
assert splitting_ops is not None
|
||||
assert "vllm::moe_forward" in splitting_ops
|
||||
assert "vllm::moe_forward_shared" in splitting_ops
|
||||
|
||||
|
||||
def test_moe_splitting_ops_deepep_ht_inductor_partition():
|
||||
# Inductor partition case: user-provided splitting_ops should be
|
||||
# preserved and MoE ops should be appended for DeepEP HT with dp>1.
|
||||
@@ -277,26 +259,6 @@ def test_moe_splitting_ops_deepep_ht_inductor_partition():
|
||||
]
|
||||
|
||||
|
||||
def test_moe_splitting_ops_deepep_ht_attn_fusion_no_inductor():
|
||||
# Pure attn-fusion case without inductor partition: even with
|
||||
# DeepEP HT and dp>1, we should not re-enable piecewise compilation
|
||||
# or add MoE ops into splitting_ops.
|
||||
config = VllmConfig(
|
||||
parallel_config=ParallelConfig(
|
||||
all2all_backend="deepep_high_throughput",
|
||||
data_parallel_size=8,
|
||||
),
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE,
|
||||
pass_config={"fuse_attn_quant": True, "eliminate_noops": True},
|
||||
custom_ops=["+quant_fp8"],
|
||||
cudagraph_mode=CUDAGraphMode.PIECEWISE,
|
||||
),
|
||||
)
|
||||
assert config.compilation_config.splitting_ops == []
|
||||
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
|
||||
|
||||
|
||||
def test_should_split():
|
||||
import torch
|
||||
|
||||
|
||||
@@ -254,7 +254,9 @@ async def test_single_chat_session_input_audio(
|
||||
async def test_chat_streaming_audio(
|
||||
client: openai.AsyncOpenAI, model_name: str, audio_url: str
|
||||
):
|
||||
messages = dummy_messages_from_audio_url(audio_url)
|
||||
messages = dummy_messages_from_audio_url(
|
||||
audio_url, "What's a short title for this audio?"
|
||||
)
|
||||
|
||||
# test single completion
|
||||
chat_completion = await client.chat.completions.create(
|
||||
|
||||
@@ -15,7 +15,10 @@ from tests.v1.attention.utils import (
|
||||
create_standard_kv_cache_spec,
|
||||
create_vllm_config,
|
||||
)
|
||||
from vllm.v1.attention.backends.flex_attention import FlexAttentionMetadataBuilder
|
||||
from vllm.v1.attention.backends.flex_attention import (
|
||||
FlexAttentionMetadataBuilder,
|
||||
physical_to_logical_mapping,
|
||||
)
|
||||
|
||||
from ..models.utils import check_embeddings_close, check_logprobs_close
|
||||
|
||||
@@ -212,5 +215,31 @@ def test_block_mask_direct_vs_slow_path():
|
||||
)
|
||||
|
||||
|
||||
def test_physical_to_logical_mapping_handles_reused_blocks():
|
||||
"""Regression test: reused physical blocks map to the latest logical block.
|
||||
|
||||
For sliding-window / hybrid attention layers, physical KV-cache blocks can be
|
||||
reused over time. The inverse mapping must therefore select the latest
|
||||
logical block index for a physical block id.
|
||||
"""
|
||||
# Padding should not make physical block 0 look live.
|
||||
block_table = torch.tensor([[6, 0, 0, 0]], dtype=torch.int32)
|
||||
seq_lens = torch.tensor([1 * 16], dtype=torch.int32) # only 1 block valid
|
||||
out = physical_to_logical_mapping(
|
||||
block_table=block_table, seq_lens=seq_lens, block_size=16, total_blocks=10
|
||||
)
|
||||
assert out[0, 0].item() == -1
|
||||
assert out[0, 6].item() == 0
|
||||
|
||||
# If a physical block id appears multiple times (block reuse), mapping should
|
||||
# point to the latest logical block index.
|
||||
block_table2 = torch.tensor([[2, 2, 5]], dtype=torch.int32)
|
||||
seq_lens2 = torch.tensor([3 * 16], dtype=torch.int32)
|
||||
out2 = physical_to_logical_mapping(
|
||||
block_table=block_table2, seq_lens=seq_lens2, block_size=16, total_blocks=8
|
||||
)
|
||||
assert out2[0, 2].item() == 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
|
||||
@@ -1,17 +1,23 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Literal, NamedTuple
|
||||
import os
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import pytest
|
||||
from huggingface_hub import hf_hub_download
|
||||
from pytest import MarkDecorator
|
||||
from transformers import AutoModelForImageTextToText
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
from vllm.assets.image import ImageAsset
|
||||
from vllm.multimodal.image import rescale_image_size
|
||||
from vllm.utils.torch_utils import set_default_torch_num_threads
|
||||
|
||||
from ....conftest import PromptImageInput, VllmRunner
|
||||
from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner
|
||||
from ...utils import check_logprobs_close
|
||||
|
||||
|
||||
@@ -21,9 +27,10 @@ class GGUFMMTestConfig(NamedTuple):
|
||||
gguf_backbone: str
|
||||
gguf_mmproj: str
|
||||
prompt: list[str]
|
||||
mm_data: dict[Literal["images"], PromptImageInput]
|
||||
image_names: list[str] # Store names, load PIL images at runtime
|
||||
max_model_len: int = 4096
|
||||
marks: list[MarkDecorator] = []
|
||||
mm_processor_kwargs: dict[str, Any] = {}
|
||||
|
||||
@property
|
||||
def gguf_model(self):
|
||||
@@ -31,27 +38,75 @@ class GGUFMMTestConfig(NamedTuple):
|
||||
return hf_hub_download(self.gguf_repo, filename=self.gguf_backbone)
|
||||
|
||||
|
||||
# Common prompts aligned with test_common.py "gemma3" entry format
|
||||
_GEMMA3_PROMPTS = IMAGE_ASSETS.prompts(
|
||||
{
|
||||
"stop_sign": (
|
||||
"<bos><start_of_turn>user\n"
|
||||
"<start_of_image>What's the content in the center of the image?"
|
||||
"<end_of_turn>\n<start_of_turn>model\n"
|
||||
),
|
||||
"cherry_blossom": (
|
||||
"<bos><start_of_turn>user\n"
|
||||
"<start_of_image>What is the season?"
|
||||
"<end_of_turn>\n<start_of_turn>model\n"
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
# Image asset names - load at runtime to avoid pickle issues with subprocess
|
||||
_GEMMA3_IMAGE_NAMES = ["stop_sign", "cherry_blossom"]
|
||||
|
||||
# Regular multimodal (no pan-and-scan) - uses QAT Q4_0 GGUF
|
||||
GEMMA3_CONFIG = GGUFMMTestConfig(
|
||||
original_model="google/gemma-3-4b-it",
|
||||
gguf_repo="google/gemma-3-4b-it-qat-q4_0-gguf",
|
||||
gguf_backbone="gemma-3-4b-it-q4_0.gguf",
|
||||
gguf_mmproj="mmproj-model-f16-4B.gguf",
|
||||
prompt=["<start_of_image>Describe this image in detail:"],
|
||||
mm_data={"images": [ImageAsset("stop_sign").pil_image]},
|
||||
prompt=_GEMMA3_PROMPTS,
|
||||
image_names=_GEMMA3_IMAGE_NAMES,
|
||||
max_model_len=4096,
|
||||
marks=[pytest.mark.core_model],
|
||||
mm_processor_kwargs={},
|
||||
)
|
||||
|
||||
MODELS_TO_TEST = [GEMMA3_CONFIG]
|
||||
# Pan-and-scan multimodal - uses unquantized BF16 GGUF
|
||||
GEMMA3_CONFIG_PAN_AND_SCAN = GGUFMMTestConfig(
|
||||
original_model="google/gemma-3-4b-it",
|
||||
gguf_repo="unsloth/gemma-3-4b-it-GGUF",
|
||||
gguf_backbone="gemma-3-4b-it-BF16.gguf",
|
||||
gguf_mmproj="mmproj-BF16.gguf",
|
||||
prompt=_GEMMA3_PROMPTS,
|
||||
image_names=_GEMMA3_IMAGE_NAMES,
|
||||
max_model_len=4096,
|
||||
marks=[pytest.mark.core_model],
|
||||
mm_processor_kwargs={"do_pan_and_scan": True},
|
||||
)
|
||||
|
||||
MODELS_TO_TEST = [GEMMA3_CONFIG, GEMMA3_CONFIG_PAN_AND_SCAN]
|
||||
|
||||
|
||||
def run_multimodal_gguf_test(
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
model: GGUFMMTestConfig,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
):
|
||||
# Run gguf model.
|
||||
# Load images at runtime (inside subprocess) to avoid pickle issues
|
||||
images = [ImageAsset(name).pil_image for name in model.image_names]
|
||||
size_factors = [0.25, 0.5, 1.0]
|
||||
inputs_per_image = [
|
||||
(
|
||||
[prompt for _ in size_factors],
|
||||
[rescale_image_size(image, factor) for factor in size_factors],
|
||||
)
|
||||
for image, prompt in zip(images, model.prompt)
|
||||
]
|
||||
|
||||
# NOTE: Run vLLM first to avoid CUDA init issues with multiprocessing fork.
|
||||
# Run GGUF model via vLLM.
|
||||
with (
|
||||
set_default_torch_num_threads(1),
|
||||
vllm_runner(
|
||||
@@ -60,35 +115,42 @@ def run_multimodal_gguf_test(
|
||||
tokenizer_name=model.original_model,
|
||||
dtype=dtype,
|
||||
max_model_len=model.max_model_len,
|
||||
mm_processor_kwargs=model.mm_processor_kwargs,
|
||||
) as gguf_model,
|
||||
):
|
||||
gguf_outputs = gguf_model.generate_greedy_logprobs(
|
||||
prompts=model.prompt,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
**model.mm_data,
|
||||
)
|
||||
gguf_outputs_per_case = [
|
||||
gguf_model.generate_greedy_logprobs(
|
||||
prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=images,
|
||||
)
|
||||
for prompts, images in inputs_per_image
|
||||
]
|
||||
|
||||
# Run unquantized model.
|
||||
with vllm_runner(
|
||||
model_name=model.original_model,
|
||||
enforce_eager=True, # faster tests
|
||||
# Then run HfRunner for HuggingFace baseline comparison.
|
||||
with hf_runner(
|
||||
model.original_model,
|
||||
dtype=dtype,
|
||||
max_model_len=model.max_model_len,
|
||||
) as original_model:
|
||||
original_outputs = original_model.generate_greedy_logprobs(
|
||||
prompts=model.prompt,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
**model.mm_data,
|
||||
)
|
||||
auto_cls=AutoModelForImageTextToText,
|
||||
) as hf_model:
|
||||
hf_outputs_per_case = [
|
||||
hf_model.generate_greedy_logprobs_limit(
|
||||
prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=images,
|
||||
)
|
||||
for prompts, images in inputs_per_image
|
||||
]
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=original_outputs,
|
||||
outputs_1_lst=gguf_outputs,
|
||||
name_0="original",
|
||||
name_1="gguf",
|
||||
)
|
||||
for hf_outputs, gguf_outputs in zip(hf_outputs_per_case, gguf_outputs_per_case):
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=gguf_outputs,
|
||||
name_0="hf",
|
||||
name_1="gguf",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
@@ -105,11 +167,14 @@ def run_multimodal_gguf_test(
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [32])
|
||||
@pytest.mark.parametrize("num_logprobs", [10])
|
||||
def test_models(
|
||||
def test_gemma3_mm_gguf(
|
||||
hf_runner: type[HfRunner],
|
||||
vllm_runner: type[VllmRunner],
|
||||
model: GGUFMMTestConfig,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
) -> None:
|
||||
run_multimodal_gguf_test(vllm_runner, model, dtype, max_tokens, num_logprobs)
|
||||
run_multimodal_gguf_test(
|
||||
hf_runner, vllm_runner, model, dtype, max_tokens, num_logprobs
|
||||
)
|
||||
|
||||
@@ -295,6 +295,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"internlm/internlm3-8b-instruct", trust_remote_code=True
|
||||
),
|
||||
"JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"),
|
||||
"Jais2ForCausalLM": _HfExamplesInfo(
|
||||
"inceptionai/Jais-2-8B-Chat", min_transformers_version="4.58"
|
||||
),
|
||||
"JambaForCausalLM": _HfExamplesInfo(
|
||||
"ai21labs/AI21-Jamba-1.5-Mini",
|
||||
extras={
|
||||
|
||||
@@ -86,6 +86,9 @@ def kernel_unified_attention_2d(
|
||||
USE_SOFTCAP: tl.constexpr, # bool
|
||||
USE_SINKS: tl.constexpr, # bool
|
||||
SLIDING_WINDOW: tl.constexpr, # int
|
||||
USE_MM_PREFIX: tl.constexpr, # bool
|
||||
MAX_MM_RANGES: tl.constexpr, # int
|
||||
mm_prefix_range_ptr, # [num_seqs] - prefix length for each sequence
|
||||
stride_k_cache_0: tl.int64, # int
|
||||
stride_k_cache_1: tl.int64, # int
|
||||
stride_k_cache_2: tl.int64, # int
|
||||
@@ -270,7 +273,38 @@ def kernel_unified_attention_2d(
|
||||
else:
|
||||
V = V_load
|
||||
|
||||
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
|
||||
# Compute attention mask: causal by default (key <= query)
|
||||
query_abs_pos = context_len + query_pos[:, None]
|
||||
seq_mask = seq_offset[None, :] <= query_abs_pos
|
||||
|
||||
# Apply sliding window to base mask BEFORE mm_prefix OR.
|
||||
# Order must match FlexAttention: (causal AND sliding_window) OR mm_prefix
|
||||
if SLIDING_WINDOW > 0:
|
||||
seq_mask = seq_mask & ((query_abs_pos - seq_offset) < SLIDING_WINDOW)
|
||||
|
||||
# PrefixLM: extend mask with bidirectional ranges for multimodal tokens.
|
||||
# Applied AFTER sliding window so mm_prefix ranges override SW restriction.
|
||||
if USE_MM_PREFIX:
|
||||
for i in range(MAX_MM_RANGES):
|
||||
range_start = tl.load(
|
||||
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2
|
||||
)
|
||||
range_end = tl.load(
|
||||
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2 + 1
|
||||
)
|
||||
|
||||
is_valid = range_start < range_end
|
||||
q_in_range = (
|
||||
(query_abs_pos >= range_start)
|
||||
& (query_abs_pos <= range_end)
|
||||
& is_valid
|
||||
)
|
||||
k_in_range = (
|
||||
(seq_offset[None, :] >= range_start)
|
||||
& (seq_offset[None, :] <= range_end)
|
||||
& is_valid
|
||||
)
|
||||
seq_mask |= q_in_range & k_in_range
|
||||
|
||||
# S : (BLOCK_M, TILE_SIZE)
|
||||
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
|
||||
@@ -284,13 +318,6 @@ def kernel_unified_attention_2d(
|
||||
query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
|
||||
)
|
||||
|
||||
if SLIDING_WINDOW > 0:
|
||||
S = tl.where(
|
||||
(context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW,
|
||||
S,
|
||||
float("-inf"),
|
||||
)
|
||||
|
||||
if USE_ALIBI_SLOPES:
|
||||
S += alibi_slope[:, None] * (seq_offset - context_len)
|
||||
|
||||
@@ -398,6 +425,9 @@ def kernel_unified_attention_3d(
|
||||
num_seqs: tl.int32,
|
||||
BLOCK_M: tl.constexpr, # int
|
||||
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int
|
||||
USE_MM_PREFIX: tl.constexpr, # bool
|
||||
MAX_MM_RANGES: tl.constexpr, # int
|
||||
mm_prefix_range_ptr, # [num_seqs] - prefix length for each sequence
|
||||
):
|
||||
q_block_global_idx = tl.program_id(0)
|
||||
kv_head_idx = tl.program_id(1)
|
||||
@@ -559,7 +589,38 @@ def kernel_unified_attention_3d(
|
||||
else:
|
||||
V = V_load
|
||||
|
||||
seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
|
||||
# Compute attention mask: causal by default (key <= query)
|
||||
query_abs_pos = context_len + query_pos[:, None]
|
||||
seq_mask = seq_offset[None, :] <= query_abs_pos
|
||||
|
||||
# Apply sliding window to base mask BEFORE mm_prefix OR.
|
||||
# Order must match FlexAttention: (causal AND sliding_window) OR mm_prefix
|
||||
if SLIDING_WINDOW > 0:
|
||||
seq_mask = seq_mask & ((query_abs_pos - seq_offset) < SLIDING_WINDOW)
|
||||
|
||||
# PrefixLM: extend mask with bidirectional ranges for multimodal tokens.
|
||||
# Applied AFTER sliding window so mm_prefix ranges override SW restriction.
|
||||
if USE_MM_PREFIX:
|
||||
for i in range(MAX_MM_RANGES):
|
||||
range_start = tl.load(
|
||||
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2
|
||||
)
|
||||
range_end = tl.load(
|
||||
mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2 + 1
|
||||
)
|
||||
|
||||
is_valid = range_start < range_end
|
||||
q_in_range = (
|
||||
(query_abs_pos >= range_start)
|
||||
& (query_abs_pos <= range_end)
|
||||
& is_valid
|
||||
)
|
||||
k_in_range = (
|
||||
(seq_offset[None, :] >= range_start)
|
||||
& (seq_offset[None, :] <= range_end)
|
||||
& is_valid
|
||||
)
|
||||
seq_mask |= q_in_range & k_in_range
|
||||
|
||||
# S : (BLOCK_M, TILE_SIZE)
|
||||
S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
|
||||
@@ -572,13 +633,6 @@ def kernel_unified_attention_3d(
|
||||
query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
|
||||
)
|
||||
|
||||
if SLIDING_WINDOW > 0:
|
||||
S = tl.where(
|
||||
(context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW,
|
||||
S,
|
||||
float("-inf"),
|
||||
)
|
||||
|
||||
if USE_ALIBI_SLOPES:
|
||||
S += alibi_slope[:, None] * (seq_offset - context_len)
|
||||
|
||||
@@ -732,6 +786,38 @@ def reduce_segments(
|
||||
tl.store(output_ptr + output_offset, acc, mask=dim_mask)
|
||||
|
||||
|
||||
def _is_gemma3_attention(head_size: int, sliding_window: int) -> bool:
|
||||
"""Detect Gemma3 models via unique (head_size, sliding_window) signature.
|
||||
|
||||
Gemma3 models are the only ones using sliding_window=1024 with
|
||||
head_size 128 (27B) or 256 (1B, 4B, 12B). Other SWA models use
|
||||
different window sizes (Mistral=4096, Phi-3=2047).
|
||||
"""
|
||||
return sliding_window == 1024 and head_size in (128, 256)
|
||||
|
||||
|
||||
def _get_tile_size(
|
||||
head_size: int,
|
||||
sliding_window: int,
|
||||
element_size: int,
|
||||
is_prefill: bool,
|
||||
) -> int:
|
||||
"""Select tile size with Gemma3-specific optimization.
|
||||
|
||||
For Gemma3, use 32 for both prefill and decode to better utilize
|
||||
the larger head dimension (128/256). For other models, use
|
||||
the default vLLM behavior.
|
||||
"""
|
||||
if _is_gemma3_attention(head_size, sliding_window):
|
||||
# Gemma3: use 32 for decode (default is 16)
|
||||
return 32
|
||||
|
||||
# Default behavior
|
||||
if is_prefill:
|
||||
return 32
|
||||
return 16 if element_size >= 2 else 32
|
||||
|
||||
|
||||
def unified_attention(
|
||||
q,
|
||||
k,
|
||||
@@ -759,6 +845,8 @@ def unified_attention(
|
||||
qq_bias=None,
|
||||
# Optional tensor for sinks
|
||||
sinks=None,
|
||||
# Optional tensor for prefix lengths (PrefixLM support)
|
||||
mm_prefix_range=None,
|
||||
):
|
||||
assert causal, "Only causal attention is supported"
|
||||
assert q_descale is None, "Q scales not supported"
|
||||
@@ -766,6 +854,17 @@ def unified_attention(
|
||||
if sinks is not None:
|
||||
assert sinks.shape[0] == q.shape[1], "Sinks must be num_query_heads size"
|
||||
|
||||
use_mm_prefix = False
|
||||
max_mm_ranges = 0
|
||||
if mm_prefix_range is not None:
|
||||
if mm_prefix_range.ndim == 3:
|
||||
use_mm_prefix = True
|
||||
max_mm_ranges = mm_prefix_range.shape[1]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported mm_prefix_range shape: {mm_prefix_range.shape}"
|
||||
)
|
||||
|
||||
use_alibi_slopes = alibi_slopes is not None
|
||||
use_qq_bias = qq_bias is not None
|
||||
|
||||
@@ -792,11 +891,21 @@ def unified_attention(
|
||||
# = floor(q.shape[0] / BLOCK_Q) + num_seqs
|
||||
total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
|
||||
|
||||
# Assigning default tile sizes for prefill and decode.
|
||||
# Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1)
|
||||
# and at least 16 for all other data types.
|
||||
TILE_SIZE_PREFILL = 32
|
||||
TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32
|
||||
# Tile sizes for prefill and decode. Gemma3 models use optimized values.
|
||||
# Note: tile size must be at least 32 for fp8 (element_size == 1).
|
||||
sliding_window_val = 1 + window_size[0] if window_size[0] >= 0 else 0
|
||||
TILE_SIZE_PREFILL = _get_tile_size(
|
||||
head_size,
|
||||
sliding_window_val,
|
||||
q.element_size(),
|
||||
is_prefill=True,
|
||||
)
|
||||
TILE_SIZE_DECODE = _get_tile_size(
|
||||
head_size,
|
||||
sliding_window_val,
|
||||
q.element_size(),
|
||||
is_prefill=False,
|
||||
)
|
||||
|
||||
# Launch the 2D kernel if
|
||||
# 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or
|
||||
@@ -847,6 +956,9 @@ def unified_attention(
|
||||
USE_QQ_BIAS=use_qq_bias,
|
||||
USE_SOFTCAP=(softcap > 0),
|
||||
USE_SINKS=(sinks is not None),
|
||||
USE_MM_PREFIX=use_mm_prefix,
|
||||
MAX_MM_RANGES=max_mm_ranges,
|
||||
mm_prefix_range_ptr=mm_prefix_range,
|
||||
SLIDING_WINDOW=(1 + window_size[0]),
|
||||
stride_k_cache_0=k.stride(0),
|
||||
stride_k_cache_1=k.stride(1),
|
||||
@@ -895,6 +1007,9 @@ def unified_attention(
|
||||
USE_QQ_BIAS=use_qq_bias,
|
||||
USE_SOFTCAP=(softcap > 0),
|
||||
USE_SINKS=(sinks is not None),
|
||||
USE_MM_PREFIX=use_mm_prefix,
|
||||
MAX_MM_RANGES=max_mm_ranges,
|
||||
mm_prefix_range_ptr=mm_prefix_range,
|
||||
SLIDING_WINDOW=(1 + window_size[0]),
|
||||
stride_k_cache_0=k.stride(0),
|
||||
stride_k_cache_1=k.stride(1),
|
||||
|
||||
@@ -915,8 +915,6 @@ class CompilationConfig:
|
||||
"mode is CompilationMode.VLLM_COMPILE"
|
||||
)
|
||||
|
||||
added_default_splitting_ops = False
|
||||
|
||||
if self.pass_config.fuse_attn_quant and not self.use_inductor_graph_partition:
|
||||
self.set_splitting_ops_for_attn_fusion()
|
||||
else:
|
||||
@@ -930,7 +928,6 @@ class CompilationConfig:
|
||||
# for details. Make a copy to avoid mutating the class-level
|
||||
# list via reference.
|
||||
self.splitting_ops = list(self._attention_ops)
|
||||
added_default_splitting_ops = True
|
||||
elif len(self.splitting_ops) == 0:
|
||||
if (
|
||||
self.cudagraph_mode == CUDAGraphMode.PIECEWISE
|
||||
@@ -958,44 +955,25 @@ class CompilationConfig:
|
||||
self.cudagraph_mode = CUDAGraphMode.FULL
|
||||
self.splitting_ops = []
|
||||
|
||||
# split MoE ops for cudagraph
|
||||
moe_ops = [
|
||||
"vllm::moe_forward",
|
||||
"vllm::moe_forward_shared",
|
||||
]
|
||||
# Disable CUDA graphs for DeepEP high-throughput since its not CG compatible
|
||||
backend = all2all_backend or envs.VLLM_ALL2ALL_BACKEND
|
||||
dp_size = data_parallel_size if data_parallel_size is not None else 1
|
||||
need_moe_splitting = (
|
||||
if (
|
||||
backend == "deepep_high_throughput"
|
||||
and dp_size > 1
|
||||
# pure attn-fusion without inductor partition deliberately disables
|
||||
# piecewise graphs and MoE splitting.
|
||||
and not (
|
||||
self.pass_config.fuse_attn_quant
|
||||
and not self.use_inductor_graph_partition
|
||||
and self.cudagraph_mode != CUDAGraphMode.NONE
|
||||
):
|
||||
# TODO: Piecewise Cuda graph might be enabled
|
||||
# if torch compile cache key issue fixed
|
||||
# See https://github.com/vllm-project/vllm/pull/25093
|
||||
logger.info(
|
||||
"DeepEP: Disabling CUDA Graphs since DeepEP high-throughput kernels "
|
||||
"are optimized for prefill and are incompatible with CUDA Graphs. "
|
||||
"In order to use CUDA Graphs for decode-optimized workloads, "
|
||||
"use --all2all-backend with another option, such as "
|
||||
"deepep_low_latency, pplx, or allgather_reducescatter."
|
||||
)
|
||||
)
|
||||
|
||||
if need_moe_splitting and self.cudagraph_mode != CUDAGraphMode.NONE:
|
||||
# if we just initialized default splitting_ops for this config,
|
||||
# automatically append the MoE ops
|
||||
if added_default_splitting_ops:
|
||||
for op in moe_ops:
|
||||
if op not in self.splitting_ops:
|
||||
self.splitting_ops.append(op)
|
||||
|
||||
# make sure MoE ops are split out
|
||||
if not any(op in self.splitting_ops for op in moe_ops):
|
||||
self.cudagraph_mode = CUDAGraphMode.NONE
|
||||
logger.warning_once(
|
||||
"DeepEP high throughput backend with data_parallel_size > 1 "
|
||||
"requires splitting MoE ops from cudagraphs. Please ensure "
|
||||
"'vllm::moe_forward' or 'vllm::moe_forward_shared' are "
|
||||
"present in CompilationConfig.splitting_ops."
|
||||
)
|
||||
elif self.cudagraph_mode.has_full_cudagraphs():
|
||||
# fall back to piecewise when MoE splitting is required.
|
||||
self.cudagraph_mode = CUDAGraphMode.PIECEWISE
|
||||
self.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
def set_splitting_ops_for_attn_fusion(self):
|
||||
assert self.pass_config.fuse_attn_quant
|
||||
|
||||
@@ -795,7 +795,10 @@ class FusedMoEModularKernel(torch.nn.Module):
|
||||
top_k,
|
||||
global_num_experts,
|
||||
local_num_experts,
|
||||
expert_tokens_meta,
|
||||
# expert_tokens_meta help in allocating optimal/minimal
|
||||
# amount of workspace. Mark it None, so we allocate for
|
||||
# the worst-case scenario.
|
||||
expert_tokens_meta=None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ from collections.abc import Iterable
|
||||
from itertools import islice
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers import Gemma3TextConfig
|
||||
|
||||
@@ -223,77 +222,9 @@ class Gemma3Attention(nn.Module):
|
||||
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
|
||||
if not kwargs.get("has_images", False):
|
||||
# Fast path for text-only inputs. The performance for the text-only
|
||||
# inputs are not affected by the naive attention below.
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
# NOTE(woosuk): Gemma3 uses bidirectional attention between image tokens
|
||||
# that correspond to the same image while using causal attention
|
||||
# otherwise. Current attention backends cannot handle this pattern, so
|
||||
# we temporarily use a naive attention implementation with mask tensors.
|
||||
|
||||
# We intentionally keep the attention backend as-is and only override
|
||||
# `attn_output` with the naive implementation's output. This minimizes
|
||||
# changes to existing model runners and attention backends. The call to
|
||||
# `self.attn(q, k, v)` is only used to populate the KV cache - its
|
||||
# output is discarded and overwritten below. While this duplicates
|
||||
# computation, it maintains compatibility.
|
||||
# TODO(woosuk): Optimize by implementing custom attention kernels.
|
||||
attn_output = self.naive_attn_with_masks(q, k, v, out=attn_output, **kwargs)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
def naive_attn_with_masks(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
# NOTE(woosuk): As described in the comment above, this code is not
|
||||
# meant to be performant. It is only meant to be correct.
|
||||
q = q.view(-1, self.num_heads, self.head_dim)
|
||||
# Expand the key and value to handle GQA.
|
||||
num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
k = k.view(-1, self.num_kv_heads, self.head_dim)
|
||||
k = k.repeat_interleave(num_queries_per_kv, dim=-2)
|
||||
v = v.view(-1, self.num_kv_heads, self.head_dim)
|
||||
v = v.repeat_interleave(num_queries_per_kv, dim=-2)
|
||||
|
||||
if self.is_sliding:
|
||||
attn_masks = kwargs["local_attn_masks"]
|
||||
else:
|
||||
attn_masks = kwargs["global_attn_masks"]
|
||||
|
||||
seq_lens = kwargs["seq_lens"]
|
||||
start_idx = 0
|
||||
for seq_len, attn_mask in zip(seq_lens, attn_masks):
|
||||
end_idx = start_idx + seq_len
|
||||
query = q[start_idx:end_idx].unsqueeze(0)
|
||||
key = k[start_idx:end_idx].unsqueeze(0)
|
||||
value = v[start_idx:end_idx].unsqueeze(0)
|
||||
|
||||
# Transpose.
|
||||
query = query.transpose(1, 2)
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
output = F.scaled_dot_product_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
self.scaling,
|
||||
)
|
||||
output = output.transpose(1, 2).flatten(-2, -1)
|
||||
out[start_idx:end_idx] = output
|
||||
start_idx = end_idx
|
||||
return out
|
||||
|
||||
|
||||
class Gemma3DecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
|
||||
529
vllm/model_executor/models/jais2.py
Normal file
529
vllm/model_executor/models/jais2.py
Normal file
@@ -0,0 +1,529 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Inference-only Jais2 model compatible with HuggingFace weights."""
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import Jais2Config
|
||||
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (
|
||||
get_pp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.model_executor.layers.activation import ReLUSquaredActivation
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE,
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
maybe_remap_kv_scale_name,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
PPMissingLayer,
|
||||
extract_layer_index,
|
||||
is_pp_missing_parameter,
|
||||
make_empty_intermediate_tensors_factory,
|
||||
make_layers,
|
||||
maybe_prefix,
|
||||
)
|
||||
|
||||
|
||||
class Jais2MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
bias: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.up_proj = ColumnParallelLinear(
|
||||
input_size=hidden_size,
|
||||
output_size=intermediate_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.up_proj",
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
input_size=intermediate_size,
|
||||
output_size=hidden_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
self.act_fn = ReLUSquaredActivation()
|
||||
|
||||
def forward(self, x):
|
||||
x, _ = self.up_proj(x)
|
||||
x = self.act_fn(x)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class Jais2Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: Jais2Config,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
max_position_embeddings: int = 8192,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
bias: bool = False,
|
||||
cache_config: CacheConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
layer_idx = extract_layer_index(prefix)
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
|
||||
self.head_dim = getattr(
|
||||
config, "head_dim", self.hidden_size // self.total_num_heads
|
||||
)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size=hidden_size,
|
||||
head_size=self.head_dim,
|
||||
total_num_heads=self.total_num_heads,
|
||||
total_num_kv_heads=self.total_num_kv_heads,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
|
||||
self.o_proj = RowParallelLinear(
|
||||
input_size=self.total_num_heads * self.head_dim,
|
||||
output_size=hidden_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
is_neox_style = True
|
||||
if quant_config is not None and quant_config.get_name() == "gguf":
|
||||
is_neox_style = False
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
rope_parameters=getattr(config, "rope_parameters", None),
|
||||
is_neox_style=is_neox_style,
|
||||
)
|
||||
|
||||
if hasattr(config, "interleaved_sliding_window"):
|
||||
interleaved_sliding_window = config.interleaved_sliding_window
|
||||
if isinstance(interleaved_sliding_window, int):
|
||||
sliding_window = interleaved_sliding_window
|
||||
elif isinstance(interleaved_sliding_window, list):
|
||||
sw_idx = layer_idx % len(interleaved_sliding_window)
|
||||
sliding_window = interleaved_sliding_window[sw_idx]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"{type(interleaved_sliding_window)} is not supported."
|
||||
)
|
||||
else:
|
||||
sliding_window = None
|
||||
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
per_layer_sliding_window=sliding_window,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class Jais2DecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
config: Jais2Config,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
config = config or vllm_config.model_config.hf_config
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = self.get_quant_config(vllm_config)
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||
# Support abacusai/Smaug-72B-v0.1 with attention_bias
|
||||
# Support internlm/internlm-7b with bias
|
||||
attention_bias = getattr(config, "attention_bias", False) or getattr(
|
||||
config, "bias", False
|
||||
)
|
||||
self.self_attn = Jais2Attention(
|
||||
config=config,
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
num_kv_heads=getattr(
|
||||
config, "num_key_value_heads", config.num_attention_heads
|
||||
),
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
bias=attention_bias,
|
||||
cache_config=cache_config,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
self.mlp = Jais2MLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
bias=getattr(config, "mlp_bias", False),
|
||||
prefix=f"{prefix}.mlp",
|
||||
)
|
||||
self.input_layernorm = nn.LayerNorm(
|
||||
config.hidden_size, eps=config.layer_norm_eps
|
||||
)
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
config.hidden_size, eps=config.layer_norm_eps
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = (
|
||||
self.input_layernorm(hidden_states + residual),
|
||||
hidden_states + residual,
|
||||
)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = (
|
||||
self.post_attention_layernorm(hidden_states + residual),
|
||||
hidden_states + residual,
|
||||
)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
|
||||
"""Get quantization config for this layer. Override in subclasses."""
|
||||
return vllm_config.quant_config
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class Jais2Model(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
layer_type: type[nn.Module] = Jais2DecoderLayer,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.padding_idx = config.pad_token_id
|
||||
lora_vocab = (
|
||||
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
|
||||
if lora_config
|
||||
else 0
|
||||
)
|
||||
self.vocab_size = config.vocab_size + lora_vocab
|
||||
self.org_vocab_size = config.vocab_size
|
||||
if get_pp_group().is_first_rank or (
|
||||
config.tie_word_embeddings and get_pp_group().is_last_rank
|
||||
):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
else:
|
||||
self.embed_tokens = PPMissingLayer()
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
config.num_hidden_layers,
|
||||
lambda prefix: layer_type(
|
||||
config=config,
|
||||
vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
),
|
||||
prefix=f"{prefix}.layers",
|
||||
)
|
||||
if get_pp_group().is_last_rank:
|
||||
self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
else:
|
||||
self.norm = PPMissingLayer()
|
||||
|
||||
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
|
||||
["hidden_states", "residual"], config.hidden_size
|
||||
)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor | None,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: IntermediateTensors | None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors(
|
||||
{"hidden_states": hidden_states, "residual": residual}
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states + residual), residual
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
(".qkv_proj", ".k_proj", "k"),
|
||||
(".qkv_proj", ".v_proj", "v"),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
if self.quant_config is not None and (
|
||||
scale_name := self.quant_config.get_cache_scale(name)
|
||||
):
|
||||
# Loading kv cache scales for compressed-tensors quantization
|
||||
param = params_dict[scale_name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
loaded_weight = loaded_weight[0]
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(scale_name)
|
||||
continue
|
||||
if "scale" in name:
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
if is_pp_missing_parameter(name, self):
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class Jais2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
||||
}
|
||||
|
||||
embedding_modules = {
|
||||
"embed_tokens": "input_embeddings",
|
||||
"lm_head": "output_embeddings",
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.model = self._init_model(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
|
||||
)
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=(
|
||||
DEFAULT_VOCAB_PADDING_SIZE
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config
|
||||
else lora_config.lora_vocab_padding_size
|
||||
),
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(prefix, "lm_head"),
|
||||
)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
|
||||
|
||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||
self.logits_processor = LogitsProcessor(
|
||||
self.unpadded_vocab_size, config.vocab_size, logit_scale
|
||||
)
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors
|
||||
)
|
||||
|
||||
def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
|
||||
return Jais2Model(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.embed_input_ids(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
model_output = self.model(
|
||||
input_ids, positions, intermediate_tensors, inputs_embeds
|
||||
)
|
||||
return model_output
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor | None:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
@@ -127,6 +127,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
"InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
|
||||
"InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
|
||||
"Jais2ForCausalLM": ("jais2", "Jais2ForCausalLM"),
|
||||
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
|
||||
"KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"), # noqa: E501
|
||||
"Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
|
||||
|
||||
@@ -318,15 +318,15 @@ def patch_rope_parameters(config: PretrainedConfig) -> None:
|
||||
# Transformers v4 installed, legacy config fields may be present
|
||||
if (rope_scaling := getattr(config, "rope_scaling", None)) is not None:
|
||||
config.rope_parameters = rope_scaling
|
||||
if (
|
||||
rope_theta is not None or partial_rotary_factor is not None
|
||||
) and not getattr(config, "rope_parameters", None):
|
||||
config.rope_parameters = {"rope_type": "default"}
|
||||
if rope_theta is not None:
|
||||
if not hasattr(config, "rope_parameters"):
|
||||
config.rope_parameters = {"rope_type": "default"}
|
||||
config.rope_parameters["rope_theta"] = rope_theta
|
||||
if partial_rotary_factor is not None:
|
||||
if not hasattr(config, "rope_parameters"):
|
||||
config.rope_parameters = {"rope_type": "default"}
|
||||
config.rope_parameters["partial_rotary_factor"] = partial_rotary_factor
|
||||
elif rope_theta is not None or hasattr(config, "rope_parameters"):
|
||||
elif rope_theta is not None or getattr(config, "rope_parameters", None):
|
||||
# Transformers v5 installed
|
||||
# Patch these fields in case they used non-standard names
|
||||
if rope_theta is not None:
|
||||
|
||||
@@ -160,7 +160,7 @@ def physical_to_logical_mapping(
|
||||
└───────────────────────────────────────────┘
|
||||
|
||||
If multiple logical blocks map to the same physical block,
|
||||
this function returns the first (minimum) logical block index.
|
||||
this function returns the latest (maximum) logical block index.
|
||||
|
||||
If a physical block is not mapped to by any logical block,
|
||||
its value in the result will be -1.
|
||||
@@ -183,6 +183,15 @@ def physical_to_logical_mapping(
|
||||
To prevent this, we use seq_lens and block_size to mask out unused
|
||||
entries, ensuring only valid block references are processed.
|
||||
|
||||
IMPORTANT: Reused physical blocks (sliding-window / hybrid attention)
|
||||
────────────────────────────────────────────────────────────────────
|
||||
For some attention types, physical cache blocks can be reused over time.
|
||||
This can cause the same physical block id to appear multiple times in a row
|
||||
of `block_table` at different logical block indices. In that case, only the
|
||||
latest logical block index corresponds to the current contents of that
|
||||
physical block. Therefore, the inverse mapping must pick the maximum logical
|
||||
block index for each physical block id.
|
||||
|
||||
Args:
|
||||
block_table: Tensor of shape [max_reqs, max_num_blocks]
|
||||
mapping logical blocks to physical locations. May contain
|
||||
@@ -217,8 +226,8 @@ def physical_to_logical_mapping(
|
||||
mask, torch.arange(max_num_blocks, device=device)[None, :], 0
|
||||
)
|
||||
|
||||
physical_to_logical.scatter_(
|
||||
-1, valid_block_table.to(torch.int64), valid_logical_indices
|
||||
physical_to_logical.scatter_reduce_(
|
||||
-1, valid_block_table.to(torch.int64), valid_logical_indices, reduce="amax"
|
||||
)
|
||||
# NB - Seems like block 0 is always empty so we reset it manually
|
||||
physical_to_logical[:, 0] = -1
|
||||
|
||||
@@ -76,6 +76,39 @@ class TritonAttentionMetadata:
|
||||
# Optional aot scheduling
|
||||
scheduler_metadata: torch.Tensor | None = None
|
||||
prefix_scheduler_metadata: torch.Tensor | None = None
|
||||
mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
|
||||
|
||||
@property
|
||||
def mm_prefix_range_tensor(self) -> torch.Tensor | None:
|
||||
"""Convert mm_prefix_range dict to padded tensor for Triton kernel.
|
||||
|
||||
Returns shape: (num_seqs, max_ranges, 2) with 0-padding for empty ranges.
|
||||
Empty ranges have start==end==0, which kernel skips via is_valid check.
|
||||
"""
|
||||
# TODO(Isotr0py): Move to model runner's attention metadata
|
||||
# preparation to avoid duplicate computation.
|
||||
if self.mm_prefix_range is None:
|
||||
return None
|
||||
|
||||
num_seqs = self.seq_lens.shape[0]
|
||||
device = self.seq_lens.device
|
||||
|
||||
# Collect ranges, using [(0,0)] for empty sequences to ensure uniform dims
|
||||
range_lists = [
|
||||
self.mm_prefix_range.get(i, [(0, 0)]) or [(0, 0)] for i in range(num_seqs)
|
||||
]
|
||||
|
||||
# Return None if all ranges are trivial (only (0,0) placeholders)
|
||||
if all(r == [(0, 0)] for r in range_lists):
|
||||
return None
|
||||
|
||||
# Create 2D tensors with shape (num_ranges, 2) for each sequence
|
||||
range_tensors = [
|
||||
torch.tensor(r, dtype=torch.int32, device=device).view(-1, 2)
|
||||
for r in range_lists
|
||||
]
|
||||
|
||||
return torch.nested.nested_tensor(range_tensors).to_padded_tensor(0)
|
||||
|
||||
|
||||
class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
|
||||
@@ -268,6 +301,10 @@ class TritonAttentionBackend(AttentionBackend):
|
||||
def supports_head_size(cls, head_size: int) -> bool:
|
||||
return head_size >= 32
|
||||
|
||||
@classmethod
|
||||
def supports_mm_prefix(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def supports_sink(cls) -> bool:
|
||||
return True
|
||||
@@ -427,6 +464,7 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
softmax_segm_expsum = attn_metadata.softmax_segm_expsum
|
||||
|
||||
descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
|
||||
mm_prefix_range_tensor = attn_metadata.mm_prefix_range_tensor
|
||||
|
||||
unified_attention(
|
||||
q=query[:num_actual_tokens],
|
||||
@@ -453,6 +491,7 @@ class TritonAttentionImpl(AttentionImpl):
|
||||
softmax_segm_expsum=softmax_segm_expsum,
|
||||
sinks=self.sinks,
|
||||
output_scale=output_scale,
|
||||
mm_prefix_range=mm_prefix_range_tensor,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user