Compare commits
11 Commits: v0.13.0rc2...v0.13.0

| SHA1 |
|---|
| 72506c9834 |
| b2eb84de77 |
| ac43367ced |
| 30fe765e9f |
| 2c0ee0fde8 |
| 55f1fc1b1b |
| 17f3988094 |
| 682c38583c |
| f124b56786 |
| d78e128b8b |
| 761b730dcb |
@@ -1223,6 +1223,8 @@ steps:
   # FIXIT: find out which code initialize cuda before running the test
   # before the fix, we need to use spawn to test it
   - export VLLM_WORKER_MULTIPROC_METHOD=spawn
+  # Alot of these tests are on the edge of OOMing
+  - export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
   # There is some Tensor Parallelism related processing logic in LoRA that
   # requires multi-GPU testing for validation.
   - pytest -v -s -x lora/test_chatglm3_tp.py
@@ -406,6 +406,7 @@ th {
 | `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ |
 | `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ |
 | `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ |
+| `Jais2ForCausalLM` | Jais2 | `inceptionai/Jais-2-8B-Chat`, `inceptionai/Jais-2-70B-Chat`, etc. | | ✅︎ |
 | `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ |
 | `KimiLinearForCausalLM` | Kimi-Linear-48B-A3B-Base, Kimi-Linear-48B-A3B-Instruct | `moonshotai/Kimi-Linear-48B-A3B-Base`, `moonshotai/Kimi-Linear-48B-A3B-Instruct` | | ✅︎ |
 | `Lfm2ForCausalLM` | LFM2 | `LiquidAI/LFM2-1.2B`, `LiquidAI/LFM2-700M`, `LiquidAI/LFM2-350M`, etc. | ✅︎ | ✅︎ |
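The row added above documents the new Jais2 architecture. A minimal usage sketch (not part of the diff; it assumes the listed checkpoints are available on the Hugging Face Hub and that this vLLM build registers `Jais2ForCausalLM`) using vLLM's standard offline API:

```python
# Hedged sketch: load one of the newly listed Jais2 checkpoints with vLLM.
from vllm import LLM, SamplingParams

llm = LLM(model="inceptionai/Jais-2-8B-Chat")
params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["Write a one-line greeting."], params)
print(outputs[0].outputs[0].text)
```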
@@ -233,24 +233,6 @@ def test_splitting_ops_dynamic():
     assert config.compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE
 
 
-def test_moe_splitting_ops_deepep_ht_piecewise():
-    # Non-inductor, non-attn-fusion case: DeepEP HT with dp>1
-    # should add MoE ops to splitting_ops on top of attention ops.
-    config = VllmConfig(
-        parallel_config=ParallelConfig(
-            all2all_backend="deepep_high_throughput",
-            data_parallel_size=8,
-        ),
-        compilation_config=CompilationConfig(
-            mode=CompilationMode.VLLM_COMPILE,
-        ),
-    )
-    splitting_ops = config.compilation_config.splitting_ops
-    assert splitting_ops is not None
-    assert "vllm::moe_forward" in splitting_ops
-    assert "vllm::moe_forward_shared" in splitting_ops
-
-
 def test_moe_splitting_ops_deepep_ht_inductor_partition():
     # Inductor partition case: user-provided splitting_ops should be
     # preserved and MoE ops should be appended for DeepEP HT with dp>1.
@@ -277,26 +259,6 @@ def test_moe_splitting_ops_deepep_ht_inductor_partition():
     ]
 
 
-def test_moe_splitting_ops_deepep_ht_attn_fusion_no_inductor():
-    # Pure attn-fusion case without inductor partition: even with
-    # DeepEP HT and dp>1, we should not re-enable piecewise compilation
-    # or add MoE ops into splitting_ops.
-    config = VllmConfig(
-        parallel_config=ParallelConfig(
-            all2all_backend="deepep_high_throughput",
-            data_parallel_size=8,
-        ),
-        compilation_config=CompilationConfig(
-            mode=CompilationMode.VLLM_COMPILE,
-            pass_config={"fuse_attn_quant": True, "eliminate_noops": True},
-            custom_ops=["+quant_fp8"],
-            cudagraph_mode=CUDAGraphMode.PIECEWISE,
-        ),
-    )
-    assert config.compilation_config.splitting_ops == []
-    assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL
-
-
 def test_should_split():
     import torch
 
@@ -254,7 +254,9 @@ async def test_single_chat_session_input_audio(
 async def test_chat_streaming_audio(
     client: openai.AsyncOpenAI, model_name: str, audio_url: str
 ):
-    messages = dummy_messages_from_audio_url(audio_url)
+    messages = dummy_messages_from_audio_url(
+        audio_url, "What's a short title for this audio?"
+    )
 
     # test single completion
     chat_completion = await client.chat.completions.create(
@@ -15,7 +15,10 @@ from tests.v1.attention.utils import (
     create_standard_kv_cache_spec,
     create_vllm_config,
 )
-from vllm.v1.attention.backends.flex_attention import FlexAttentionMetadataBuilder
+from vllm.v1.attention.backends.flex_attention import (
+    FlexAttentionMetadataBuilder,
+    physical_to_logical_mapping,
+)
 
 from ..models.utils import check_embeddings_close, check_logprobs_close
 
@@ -212,5 +215,31 @@ def test_block_mask_direct_vs_slow_path():
     )
 
 
+def test_physical_to_logical_mapping_handles_reused_blocks():
+    """Regression test: reused physical blocks map to the latest logical block.
+
+    For sliding-window / hybrid attention layers, physical KV-cache blocks can be
+    reused over time. The inverse mapping must therefore select the latest
+    logical block index for a physical block id.
+    """
+    # Padding should not make physical block 0 look live.
+    block_table = torch.tensor([[6, 0, 0, 0]], dtype=torch.int32)
+    seq_lens = torch.tensor([1 * 16], dtype=torch.int32)  # only 1 block valid
+    out = physical_to_logical_mapping(
+        block_table=block_table, seq_lens=seq_lens, block_size=16, total_blocks=10
+    )
+    assert out[0, 0].item() == -1
+    assert out[0, 6].item() == 0
+
+    # If a physical block id appears multiple times (block reuse), mapping should
+    # point to the latest logical block index.
+    block_table2 = torch.tensor([[2, 2, 5]], dtype=torch.int32)
+    seq_lens2 = torch.tensor([3 * 16], dtype=torch.int32)
+    out2 = physical_to_logical_mapping(
+        block_table=block_table2, seq_lens=seq_lens2, block_size=16, total_blocks=8
+    )
+    assert out2[0, 2].item() == 1
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
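The test above pins down the semantics of `physical_to_logical_mapping`: each live physical block id points to the latest logical block that uses it, and unused physical blocks map to -1. An illustrative sketch of those semantics (assumptions: this is a standalone re-implementation inferred from the test, not vLLM's actual kernel):

```python
import torch


def inverse_block_mapping(block_table, seq_lens, block_size, total_blocks):
    """Map physical block id -> latest logical block index (-1 if unused)."""
    num_seqs, _ = block_table.shape
    out = torch.full((num_seqs, total_blocks), -1, dtype=torch.long)
    for s in range(num_seqs):
        num_live = (int(seq_lens[s]) + block_size - 1) // block_size
        # Later logical blocks overwrite earlier ones, so reused physical
        # blocks end up pointing at the latest logical index.
        for logical in range(num_live):
            out[s, int(block_table[s, logical])] = logical
    return out


block_table = torch.tensor([[2, 2, 5]], dtype=torch.int32)
seq_lens = torch.tensor([3 * 16], dtype=torch.int32)
print(inverse_block_mapping(block_table, seq_lens, 16, 8)[0, 2])  # tensor(1)
```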
@@ -1,17 +1,23 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Literal, NamedTuple
+import os
+
+os.environ["TOKENIZERS_PARALLELISM"] = "true"
+
+from typing import Any, NamedTuple
 
 import pytest
 from huggingface_hub import hf_hub_download
 from pytest import MarkDecorator
+from transformers import AutoModelForImageTextToText
 
 from tests.quantization.utils import is_quant_method_supported
 from vllm.assets.image import ImageAsset
+from vllm.multimodal.image import rescale_image_size
 from vllm.utils.torch_utils import set_default_torch_num_threads
 
-from ....conftest import PromptImageInput, VllmRunner
+from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner
 from ...utils import check_logprobs_close
 
 
@@ -21,9 +27,10 @@ class GGUFMMTestConfig(NamedTuple):
     gguf_backbone: str
     gguf_mmproj: str
     prompt: list[str]
-    mm_data: dict[Literal["images"], PromptImageInput]
+    image_names: list[str]  # Store names, load PIL images at runtime
     max_model_len: int = 4096
     marks: list[MarkDecorator] = []
+    mm_processor_kwargs: dict[str, Any] = {}
 
     @property
     def gguf_model(self):
@@ -31,27 +38,75 @@ class GGUFMMTestConfig(NamedTuple):
         return hf_hub_download(self.gguf_repo, filename=self.gguf_backbone)
 
 
+# Common prompts aligned with test_common.py "gemma3" entry format
+_GEMMA3_PROMPTS = IMAGE_ASSETS.prompts(
+    {
+        "stop_sign": (
+            "<bos><start_of_turn>user\n"
+            "<start_of_image>What's the content in the center of the image?"
+            "<end_of_turn>\n<start_of_turn>model\n"
+        ),
+        "cherry_blossom": (
+            "<bos><start_of_turn>user\n"
+            "<start_of_image>What is the season?"
+            "<end_of_turn>\n<start_of_turn>model\n"
+        ),
+    }
+)
+
+# Image asset names - load at runtime to avoid pickle issues with subprocess
+_GEMMA3_IMAGE_NAMES = ["stop_sign", "cherry_blossom"]
+
+# Regular multimodal (no pan-and-scan) - uses QAT Q4_0 GGUF
 GEMMA3_CONFIG = GGUFMMTestConfig(
     original_model="google/gemma-3-4b-it",
     gguf_repo="google/gemma-3-4b-it-qat-q4_0-gguf",
     gguf_backbone="gemma-3-4b-it-q4_0.gguf",
     gguf_mmproj="mmproj-model-f16-4B.gguf",
-    prompt=["<start_of_image>Describe this image in detail:"],
-    mm_data={"images": [ImageAsset("stop_sign").pil_image]},
+    prompt=_GEMMA3_PROMPTS,
+    image_names=_GEMMA3_IMAGE_NAMES,
+    max_model_len=4096,
     marks=[pytest.mark.core_model],
+    mm_processor_kwargs={},
 )
 
-MODELS_TO_TEST = [GEMMA3_CONFIG]
+# Pan-and-scan multimodal - uses unquantized BF16 GGUF
+GEMMA3_CONFIG_PAN_AND_SCAN = GGUFMMTestConfig(
+    original_model="google/gemma-3-4b-it",
+    gguf_repo="unsloth/gemma-3-4b-it-GGUF",
+    gguf_backbone="gemma-3-4b-it-BF16.gguf",
+    gguf_mmproj="mmproj-BF16.gguf",
+    prompt=_GEMMA3_PROMPTS,
+    image_names=_GEMMA3_IMAGE_NAMES,
+    max_model_len=4096,
+    marks=[pytest.mark.core_model],
+    mm_processor_kwargs={"do_pan_and_scan": True},
+)
+
+MODELS_TO_TEST = [GEMMA3_CONFIG, GEMMA3_CONFIG_PAN_AND_SCAN]
 
 
 def run_multimodal_gguf_test(
+    hf_runner: type[HfRunner],
     vllm_runner: type[VllmRunner],
     model: GGUFMMTestConfig,
     dtype: str,
     max_tokens: int,
     num_logprobs: int,
 ):
-    # Run gguf model.
+    # Load images at runtime (inside subprocess) to avoid pickle issues
+    images = [ImageAsset(name).pil_image for name in model.image_names]
+    size_factors = [0.25, 0.5, 1.0]
+    inputs_per_image = [
+        (
+            [prompt for _ in size_factors],
+            [rescale_image_size(image, factor) for factor in size_factors],
+        )
+        for image, prompt in zip(images, model.prompt)
+    ]
+
+    # NOTE: Run vLLM first to avoid CUDA init issues with multiprocessing fork.
+    # Run GGUF model via vLLM.
     with (
         set_default_torch_num_threads(1),
         vllm_runner(
@@ -60,35 +115,42 @@ def run_multimodal_gguf_test(
             tokenizer_name=model.original_model,
             dtype=dtype,
             max_model_len=model.max_model_len,
+            mm_processor_kwargs=model.mm_processor_kwargs,
         ) as gguf_model,
     ):
-        gguf_outputs = gguf_model.generate_greedy_logprobs(
-            prompts=model.prompt,
-            max_tokens=max_tokens,
-            num_logprobs=num_logprobs,
-            **model.mm_data,
-        )
+        gguf_outputs_per_case = [
+            gguf_model.generate_greedy_logprobs(
+                prompts,
+                max_tokens,
+                num_logprobs=num_logprobs,
+                images=images,
+            )
+            for prompts, images in inputs_per_image
+        ]
 
-    # Run unquantized model.
-    with vllm_runner(
-        model_name=model.original_model,
-        enforce_eager=True,  # faster tests
+    # Then run HfRunner for HuggingFace baseline comparison.
+    with hf_runner(
+        model.original_model,
         dtype=dtype,
-        max_model_len=model.max_model_len,
-    ) as original_model:
-        original_outputs = original_model.generate_greedy_logprobs(
-            prompts=model.prompt,
-            max_tokens=max_tokens,
-            num_logprobs=num_logprobs,
-            **model.mm_data,
-        )
+        auto_cls=AutoModelForImageTextToText,
+    ) as hf_model:
+        hf_outputs_per_case = [
+            hf_model.generate_greedy_logprobs_limit(
+                prompts,
+                max_tokens,
+                num_logprobs=num_logprobs,
+                images=images,
+            )
+            for prompts, images in inputs_per_image
+        ]
 
-    check_logprobs_close(
-        outputs_0_lst=original_outputs,
-        outputs_1_lst=gguf_outputs,
-        name_0="original",
-        name_1="gguf",
-    )
+    for hf_outputs, gguf_outputs in zip(hf_outputs_per_case, gguf_outputs_per_case):
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=gguf_outputs,
+            name_0="hf",
+            name_1="gguf",
+        )
 
 
 @pytest.mark.skipif(
@@ -105,11 +167,14 @@ def run_multimodal_gguf_test(
 @pytest.mark.parametrize("dtype", ["bfloat16"])
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("num_logprobs", [10])
-def test_models(
+def test_gemma3_mm_gguf(
+    hf_runner: type[HfRunner],
     vllm_runner: type[VllmRunner],
     model: GGUFMMTestConfig,
     dtype: str,
     max_tokens: int,
     num_logprobs: int,
 ) -> None:
-    run_multimodal_gguf_test(vllm_runner, model, dtype, max_tokens, num_logprobs)
+    run_multimodal_gguf_test(
+        hf_runner, vllm_runner, model, dtype, max_tokens, num_logprobs
+    )
@@ -295,6 +295,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
         "internlm/internlm3-8b-instruct", trust_remote_code=True
     ),
     "JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"),
+    "Jais2ForCausalLM": _HfExamplesInfo(
+        "inceptionai/Jais-2-8B-Chat", min_transformers_version="4.58"
+    ),
     "JambaForCausalLM": _HfExamplesInfo(
         "ai21labs/AI21-Jamba-1.5-Mini",
         extras={
@@ -86,6 +86,9 @@ def kernel_unified_attention_2d(
     USE_SOFTCAP: tl.constexpr,  # bool
     USE_SINKS: tl.constexpr,  # bool
     SLIDING_WINDOW: tl.constexpr,  # int
+    USE_MM_PREFIX: tl.constexpr,  # bool
+    MAX_MM_RANGES: tl.constexpr,  # int
+    mm_prefix_range_ptr,  # [num_seqs] - prefix length for each sequence
     stride_k_cache_0: tl.int64,  # int
     stride_k_cache_1: tl.int64,  # int
     stride_k_cache_2: tl.int64,  # int
@@ -270,7 +273,38 @@ def kernel_unified_attention_2d(
         else:
             V = V_load
 
-        seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
+        # Compute attention mask: causal by default (key <= query)
+        query_abs_pos = context_len + query_pos[:, None]
+        seq_mask = seq_offset[None, :] <= query_abs_pos
+
+        # Apply sliding window to base mask BEFORE mm_prefix OR.
+        # Order must match FlexAttention: (causal AND sliding_window) OR mm_prefix
+        if SLIDING_WINDOW > 0:
+            seq_mask = seq_mask & ((query_abs_pos - seq_offset) < SLIDING_WINDOW)
+
+        # PrefixLM: extend mask with bidirectional ranges for multimodal tokens.
+        # Applied AFTER sliding window so mm_prefix ranges override SW restriction.
+        if USE_MM_PREFIX:
+            for i in range(MAX_MM_RANGES):
+                range_start = tl.load(
+                    mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2
+                )
+                range_end = tl.load(
+                    mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2 + 1
+                )
+
+                is_valid = range_start < range_end
+                q_in_range = (
+                    (query_abs_pos >= range_start)
+                    & (query_abs_pos <= range_end)
+                    & is_valid
+                )
+                k_in_range = (
+                    (seq_offset[None, :] >= range_start)
+                    & (seq_offset[None, :] <= range_end)
+                    & is_valid
+                )
+                seq_mask |= q_in_range & k_in_range
 
         # S : (BLOCK_M, TILE_SIZE)
         S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
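The comments in the hunk above stress the mask composition order: the sliding window must restrict the causal mask before the multimodal (PrefixLM) ranges are OR-ed in, so that bidirectional image-token ranges are not cut by the window. A hedged plain-PyTorch illustration on a toy 8-token sequence (positions 3..5 stand in for one multimodal range; this is not the Triton kernel, just the boolean logic):

```python
import torch

T, window = 8, 2
q = torch.arange(T)[:, None]   # query absolute positions
k = torch.arange(T)[None, :]   # key absolute positions

causal = k <= q
sliding = (q - k) < window
mm_start, mm_end = 3, 5
mm_prefix = (q >= mm_start) & (q <= mm_end) & (k >= mm_start) & (k <= mm_end)

# (causal AND sliding_window) OR mm_prefix: the mm range stays fully visible
# even where the sliding window alone would have masked it out.
mask = (causal & sliding) | mm_prefix
print(mask.int())
```

Applying the window after the OR instead would re-mask part of the bidirectional range, which is exactly the mismatch with FlexAttention that the comment warns about.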
@@ -284,13 +318,6 @@ def kernel_unified_attention_2d(
             query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
         )
 
-        if SLIDING_WINDOW > 0:
-            S = tl.where(
-                (context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW,
-                S,
-                float("-inf"),
-            )
-
         if USE_ALIBI_SLOPES:
             S += alibi_slope[:, None] * (seq_offset - context_len)
 
@@ -398,6 +425,9 @@ def kernel_unified_attention_3d(
     num_seqs: tl.int32,
     BLOCK_M: tl.constexpr,  # int
     NUM_SEGMENTS_PER_SEQ: tl.constexpr,  # int
+    USE_MM_PREFIX: tl.constexpr,  # bool
+    MAX_MM_RANGES: tl.constexpr,  # int
+    mm_prefix_range_ptr,  # [num_seqs] - prefix length for each sequence
 ):
     q_block_global_idx = tl.program_id(0)
     kv_head_idx = tl.program_id(1)
@@ -559,7 +589,38 @@ def kernel_unified_attention_3d(
         else:
             V = V_load
 
-        seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1
+        # Compute attention mask: causal by default (key <= query)
+        query_abs_pos = context_len + query_pos[:, None]
+        seq_mask = seq_offset[None, :] <= query_abs_pos
+
+        # Apply sliding window to base mask BEFORE mm_prefix OR.
+        # Order must match FlexAttention: (causal AND sliding_window) OR mm_prefix
+        if SLIDING_WINDOW > 0:
+            seq_mask = seq_mask & ((query_abs_pos - seq_offset) < SLIDING_WINDOW)
+
+        # PrefixLM: extend mask with bidirectional ranges for multimodal tokens.
+        # Applied AFTER sliding window so mm_prefix ranges override SW restriction.
+        if USE_MM_PREFIX:
+            for i in range(MAX_MM_RANGES):
+                range_start = tl.load(
+                    mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2
+                )
+                range_end = tl.load(
+                    mm_prefix_range_ptr + seq_idx * MAX_MM_RANGES * 2 + i * 2 + 1
+                )
+
+                is_valid = range_start < range_end
+                q_in_range = (
+                    (query_abs_pos >= range_start)
+                    & (query_abs_pos <= range_end)
+                    & is_valid
+                )
+                k_in_range = (
+                    (seq_offset[None, :] >= range_start)
+                    & (seq_offset[None, :] <= range_end)
+                    & is_valid
+                )
+                seq_mask |= q_in_range & k_in_range
 
         # S : (BLOCK_M, TILE_SIZE)
         S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32)
@@ -572,13 +633,6 @@ def kernel_unified_attention_3d(
             query_mask_1[:, None] & query_mask_0[:, None] & seq_mask, S, float("-inf")
         )
 
-        if SLIDING_WINDOW > 0:
-            S = tl.where(
-                (context_len + query_pos[:, None] - seq_offset) < SLIDING_WINDOW,
-                S,
-                float("-inf"),
-            )
-
         if USE_ALIBI_SLOPES:
             S += alibi_slope[:, None] * (seq_offset - context_len)
 
@@ -732,6 +786,38 @@ def reduce_segments(
     tl.store(output_ptr + output_offset, acc, mask=dim_mask)
 
 
+def _is_gemma3_attention(head_size: int, sliding_window: int) -> bool:
+    """Detect Gemma3 models via unique (head_size, sliding_window) signature.
+
+    Gemma3 models are the only ones using sliding_window=1024 with
+    head_size 128 (27B) or 256 (1B, 4B, 12B). Other SWA models use
+    different window sizes (Mistral=4096, Phi-3=2047).
+    """
+    return sliding_window == 1024 and head_size in (128, 256)
+
+
+def _get_tile_size(
+    head_size: int,
+    sliding_window: int,
+    element_size: int,
+    is_prefill: bool,
+) -> int:
+    """Select tile size with Gemma3-specific optimization.
+
+    For Gemma3, use 32 for both prefill and decode to better utilize
+    the larger head dimension (128/256). For other models, use
+    the default vLLM behavior.
+    """
+    if _is_gemma3_attention(head_size, sliding_window):
+        # Gemma3: use 32 for decode (default is 16)
+        return 32
+
+    # Default behavior
+    if is_prefill:
+        return 32
+    return 16 if element_size >= 2 else 32
+
+
 def unified_attention(
     q,
     k,
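A small usage sketch of the selection logic just added (assumes the two helpers above are in scope; the expected values simply follow the branches shown in the diff):

```python
# Gemma3-style layer: sliding_window=1024, head_size 256 -> always tile size 32.
print(_get_tile_size(head_size=256, sliding_window=1024, element_size=2, is_prefill=False))  # 32
# Non-Gemma3 decode with 16-bit elements keeps the smaller default tile.
print(_get_tile_size(head_size=128, sliding_window=4096, element_size=2, is_prefill=False))  # 16
# fp8 (element_size == 1) still needs at least 32.
print(_get_tile_size(head_size=128, sliding_window=4096, element_size=1, is_prefill=False))  # 32
```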
@@ -759,6 +845,8 @@ def unified_attention(
     qq_bias=None,
     # Optional tensor for sinks
     sinks=None,
+    # Optional tensor for prefix lengths (PrefixLM support)
+    mm_prefix_range=None,
 ):
     assert causal, "Only causal attention is supported"
     assert q_descale is None, "Q scales not supported"
@@ -766,6 +854,17 @@ def unified_attention(
     if sinks is not None:
         assert sinks.shape[0] == q.shape[1], "Sinks must be num_query_heads size"
 
+    use_mm_prefix = False
+    max_mm_ranges = 0
+    if mm_prefix_range is not None:
+        if mm_prefix_range.ndim == 3:
+            use_mm_prefix = True
+            max_mm_ranges = mm_prefix_range.shape[1]
+        else:
+            raise ValueError(
+                f"Unsupported mm_prefix_range shape: {mm_prefix_range.shape}"
+            )
+
     use_alibi_slopes = alibi_slopes is not None
     use_qq_bias = qq_bias is not None
 
@@ -792,11 +891,21 @@ def unified_attention(
     # = floor(q.shape[0] / BLOCK_Q) + num_seqs
     total_num_q_blocks = q.shape[0] // BLOCK_Q + num_seqs
 
-    # Assigning default tile sizes for prefill and decode.
-    # Note: each tile size must be at least 32 for "fp8" (q.element_size() == 1)
-    # and at least 16 for all other data types.
-    TILE_SIZE_PREFILL = 32
-    TILE_SIZE_DECODE = 16 if q.element_size() >= 2 else 32
+    # Tile sizes for prefill and decode. Gemma3 models use optimized values.
+    # Note: tile size must be at least 32 for fp8 (element_size == 1).
+    sliding_window_val = 1 + window_size[0] if window_size[0] >= 0 else 0
+    TILE_SIZE_PREFILL = _get_tile_size(
+        head_size,
+        sliding_window_val,
+        q.element_size(),
+        is_prefill=True,
+    )
+    TILE_SIZE_DECODE = _get_tile_size(
+        head_size,
+        sliding_window_val,
+        q.element_size(),
+        is_prefill=False,
+    )
 
     # Launch the 2D kernel if
     # 1. No intermediate tiled softmax buffers for the 3D kernel have been allocated, or
@@ -847,6 +956,9 @@ def unified_attention(
         USE_QQ_BIAS=use_qq_bias,
         USE_SOFTCAP=(softcap > 0),
         USE_SINKS=(sinks is not None),
+        USE_MM_PREFIX=use_mm_prefix,
+        MAX_MM_RANGES=max_mm_ranges,
+        mm_prefix_range_ptr=mm_prefix_range,
         SLIDING_WINDOW=(1 + window_size[0]),
         stride_k_cache_0=k.stride(0),
         stride_k_cache_1=k.stride(1),
@@ -895,6 +1007,9 @@ def unified_attention(
         USE_QQ_BIAS=use_qq_bias,
         USE_SOFTCAP=(softcap > 0),
         USE_SINKS=(sinks is not None),
+        USE_MM_PREFIX=use_mm_prefix,
+        MAX_MM_RANGES=max_mm_ranges,
+        mm_prefix_range_ptr=mm_prefix_range,
         SLIDING_WINDOW=(1 + window_size[0]),
         stride_k_cache_0=k.stride(0),
         stride_k_cache_1=k.stride(1),
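For orientation, a hedged sketch of the `mm_prefix_range` layout these launches expect, inferred from the `ndim == 3` check and the `seq_idx * MAX_MM_RANGES * 2 + i * 2` indexing in the hunks above (the exact producer of this tensor is not shown in this diff): shape `[num_seqs, max_ranges, 2]` holding `(start, end)` pairs, where slots with `start >= end` are ignored by the `is_valid` guard.

```python
import torch

num_seqs, max_ranges = 2, 2
mm_prefix_range = torch.zeros((num_seqs, max_ranges, 2), dtype=torch.int32)
mm_prefix_range[0, 0] = torch.tensor([5, 12])   # sequence 0: one bidirectional range
mm_prefix_range[1, 0] = torch.tensor([3, 9])    # sequence 1: two ranges
mm_prefix_range[1, 1] = torch.tensor([20, 27])

# Matches the kernel's validation: a 3-D tensor of (start, end) pairs.
assert mm_prefix_range.ndim == 3 and mm_prefix_range.shape[-1] == 2
```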
@@ -915,8 +915,6 @@ class CompilationConfig:
                 "mode is CompilationMode.VLLM_COMPILE"
             )
 
-        added_default_splitting_ops = False
-
         if self.pass_config.fuse_attn_quant and not self.use_inductor_graph_partition:
             self.set_splitting_ops_for_attn_fusion()
         else:
@@ -930,7 +928,6 @@ class CompilationConfig:
                 # for details. Make a copy to avoid mutating the class-level
                 # list via reference.
                 self.splitting_ops = list(self._attention_ops)
-                added_default_splitting_ops = True
         elif len(self.splitting_ops) == 0:
             if (
                 self.cudagraph_mode == CUDAGraphMode.PIECEWISE
@@ -958,44 +955,25 @@ class CompilationConfig:
             self.cudagraph_mode = CUDAGraphMode.FULL
             self.splitting_ops = []
 
-        # split MoE ops for cudagraph
-        moe_ops = [
-            "vllm::moe_forward",
-            "vllm::moe_forward_shared",
-        ]
+        # Disable CUDA graphs for DeepEP high-throughput since its not CG compatible
         backend = all2all_backend or envs.VLLM_ALL2ALL_BACKEND
         dp_size = data_parallel_size if data_parallel_size is not None else 1
-        need_moe_splitting = (
+        if (
             backend == "deepep_high_throughput"
             and dp_size > 1
-            # pure attn-fusion without inductor partition deliberately disables
-            # piecewise graphs and MoE splitting.
-            and not (
-                self.pass_config.fuse_attn_quant
-                and not self.use_inductor_graph_partition
+            and self.cudagraph_mode != CUDAGraphMode.NONE
+        ):
+            # TODO: Piecewise Cuda graph might be enabled
+            # if torch compile cache key issue fixed
+            # See https://github.com/vllm-project/vllm/pull/25093
+            logger.info(
+                "DeepEP: Disabling CUDA Graphs since DeepEP high-throughput kernels "
+                "are optimized for prefill and are incompatible with CUDA Graphs. "
+                "In order to use CUDA Graphs for decode-optimized workloads, "
+                "use --all2all-backend with another option, such as "
+                "deepep_low_latency, pplx, or allgather_reducescatter."
             )
-        )
-
-        if need_moe_splitting and self.cudagraph_mode != CUDAGraphMode.NONE:
-            # if we just initialized default splitting_ops for this config,
-            # automatically append the MoE ops
-            if added_default_splitting_ops:
-                for op in moe_ops:
-                    if op not in self.splitting_ops:
-                        self.splitting_ops.append(op)
-
-            # make sure MoE ops are split out
-            if not any(op in self.splitting_ops for op in moe_ops):
-                self.cudagraph_mode = CUDAGraphMode.NONE
-                logger.warning_once(
-                    "DeepEP high throughput backend with data_parallel_size > 1 "
-                    "requires splitting MoE ops from cudagraphs. Please ensure "
-                    "'vllm::moe_forward' or 'vllm::moe_forward_shared' are "
-                    "present in CompilationConfig.splitting_ops."
-                )
-            elif self.cudagraph_mode.has_full_cudagraphs():
-                # fall back to piecewise when MoE splitting is required.
-                self.cudagraph_mode = CUDAGraphMode.PIECEWISE
+            self.cudagraph_mode = CUDAGraphMode.NONE
 
     def set_splitting_ops_for_attn_fusion(self):
         assert self.pass_config.fuse_attn_quant
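In short, the MoE-splitting workaround is replaced by a hard rule: DeepEP high-throughput plus data parallelism now turns CUDA graphs off entirely. A hedged sketch of the resulting behaviour, mirroring the style of the tests removed earlier in this compare (the import location is an assumption, not taken from this diff):

```python
# Assumed import location for these config classes.
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    CUDAGraphMode,
    ParallelConfig,
    VllmConfig,
)

config = VllmConfig(
    parallel_config=ParallelConfig(
        all2all_backend="deepep_high_throughput",
        data_parallel_size=8,
    ),
    compilation_config=CompilationConfig(mode=CompilationMode.VLLM_COMPILE),
)
# Expected under the new logic above: CUDA graphs are disabled outright;
# the log message suggests deepep_low_latency, pplx, or allgather_reducescatter
# as alternatives when CUDA graphs are wanted for decode-heavy workloads.
assert config.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
```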
@@ -795,7 +795,10 @@ class FusedMoEModularKernel(torch.nn.Module):
                 top_k,
                 global_num_experts,
                 local_num_experts,
-                expert_tokens_meta,
+                # expert_tokens_meta help in allocating optimal/minimal
+                # amount of workspace. Mark it None, so we allocate for
+                # the worst-case scenario.
+                expert_tokens_meta=None,
             )
         )
 
@@ -27,6 +27,10 @@ from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8LinearMethod
 from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
 from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    maybe_create_device_identity,
+)
+from vllm.model_executor.parameter import ModelWeightParameter
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 
@@ -305,6 +309,37 @@ class XPUFp8LinearMethod(Fp8LinearMethod):
     def __init__(self, quant_config: Fp8Config):
         super().__init__(quant_config)
 
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: list[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        maybe_create_device_identity()
+
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        layer.orig_dtype = params_dtype
+        layer.weight_block_size = None
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition,
+                dtype=params_dtype,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
     def process_weights_after_loading(self, layer: Module) -> None:
         # If checkpoint not serialized fp8, quantize the weights.
         if not self.quant_config.is_checkpoint_fp8_serialized:
@@ -264,6 +264,15 @@ class ApplyRotaryEmb(CustomOp):
 
         return output
 
+    def forward_cpu(
+        self,
+        x: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+    ) -> torch.Tensor:
+        # TODO (bigPYJ1151): need to enable fused CPU ROPE here
+        return self.forward_native(x, cos, sin)
+
     def extra_repr(self) -> str:
         s = f"is_neox_style={self.is_neox_style}"
         s += f"enable_fp32_compute={self.enable_fp32_compute}"
@@ -19,7 +19,6 @@ from collections.abc import Iterable
 from itertools import islice
 
 import torch
-import torch.nn.functional as F
 from torch import nn
 from transformers import Gemma3TextConfig
 
@@ -223,77 +222,9 @@ class Gemma3Attention(nn.Module):
 
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v)
-
-        if not kwargs.get("has_images", False):
-            # Fast path for text-only inputs. The performance for the text-only
-            # inputs are not affected by the naive attention below.
-            output, _ = self.o_proj(attn_output)
-            return output
-
-        # NOTE(woosuk): Gemma3 uses bidirectional attention between image tokens
-        # that correspond to the same image while using causal attention
-        # otherwise. Current attention backends cannot handle this pattern, so
-        # we temporarily use a naive attention implementation with mask tensors.
-
-        # We intentionally keep the attention backend as-is and only override
-        # `attn_output` with the naive implementation's output. This minimizes
-        # changes to existing model runners and attention backends. The call to
-        # `self.attn(q, k, v)` is only used to populate the KV cache - its
-        # output is discarded and overwritten below. While this duplicates
-        # computation, it maintains compatibility.
-        # TODO(woosuk): Optimize by implementing custom attention kernels.
-        attn_output = self.naive_attn_with_masks(q, k, v, out=attn_output, **kwargs)
         output, _ = self.o_proj(attn_output)
         return output
 
-    def naive_attn_with_masks(
-        self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        out: torch.Tensor,
-        **kwargs,
-    ) -> torch.Tensor:
-        # NOTE(woosuk): As described in the comment above, this code is not
-        # meant to be performant. It is only meant to be correct.
-        q = q.view(-1, self.num_heads, self.head_dim)
-        # Expand the key and value to handle GQA.
-        num_queries_per_kv = self.num_heads // self.num_kv_heads
-        k = k.view(-1, self.num_kv_heads, self.head_dim)
-        k = k.repeat_interleave(num_queries_per_kv, dim=-2)
-        v = v.view(-1, self.num_kv_heads, self.head_dim)
-        v = v.repeat_interleave(num_queries_per_kv, dim=-2)
-
-        if self.is_sliding:
-            attn_masks = kwargs["local_attn_masks"]
-        else:
-            attn_masks = kwargs["global_attn_masks"]
-
-        seq_lens = kwargs["seq_lens"]
-        start_idx = 0
-        for seq_len, attn_mask in zip(seq_lens, attn_masks):
-            end_idx = start_idx + seq_len
-            query = q[start_idx:end_idx].unsqueeze(0)
-            key = k[start_idx:end_idx].unsqueeze(0)
-            value = v[start_idx:end_idx].unsqueeze(0)
-
-            # Transpose.
-            query = query.transpose(1, 2)
-            key = key.transpose(1, 2)
-            value = value.transpose(1, 2)
-
-            output = F.scaled_dot_product_attention(
-                query,
-                key,
-                value,
-                attn_mask,
-                self.scaling,
-            )
-            output = output.transpose(1, 2).flatten(-2, -1)
-            out[start_idx:end_idx] = output
-            start_idx = end_idx
-        return out
-
 
 class Gemma3DecoderLayer(nn.Module):
     def __init__(
vllm/model_executor/models/jais2.py (new file, 529 lines)
@@ -0,0 +1,529 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
# Adapted from
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Inference-only Jais2 model compatible with HuggingFace weights."""
|
||||||
|
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from transformers import Jais2Config
|
||||||
|
|
||||||
|
from vllm.attention.layer import Attention
|
||||||
|
from vllm.compilation.decorators import support_torch_compile
|
||||||
|
from vllm.config import CacheConfig, VllmConfig
|
||||||
|
from vllm.distributed import (
|
||||||
|
get_pp_group,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.activation import ReLUSquaredActivation
|
||||||
|
from vllm.model_executor.layers.linear import (
|
||||||
|
ColumnParallelLinear,
|
||||||
|
QKVParallelLinear,
|
||||||
|
RowParallelLinear,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
|
DEFAULT_VOCAB_PADDING_SIZE,
|
||||||
|
ParallelLMHead,
|
||||||
|
VocabParallelEmbedding,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.model_loader.weight_utils import (
|
||||||
|
default_weight_loader,
|
||||||
|
maybe_remap_kv_scale_name,
|
||||||
|
)
|
||||||
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
|
from .interfaces import SupportsLoRA, SupportsPP
|
||||||
|
from .utils import (
|
||||||
|
AutoWeightsLoader,
|
||||||
|
PPMissingLayer,
|
||||||
|
extract_layer_index,
|
||||||
|
is_pp_missing_parameter,
|
||||||
|
make_empty_intermediate_tensors_factory,
|
||||||
|
make_layers,
|
||||||
|
maybe_prefix,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Jais2MLP(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
hidden_act: str,
|
||||||
|
quant_config: QuantizationConfig | None = None,
|
||||||
|
bias: bool = False,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.up_proj = ColumnParallelLinear(
|
||||||
|
input_size=hidden_size,
|
||||||
|
output_size=intermediate_size,
|
||||||
|
bias=bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.up_proj",
|
||||||
|
)
|
||||||
|
self.down_proj = RowParallelLinear(
|
||||||
|
input_size=intermediate_size,
|
||||||
|
output_size=hidden_size,
|
||||||
|
bias=bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.down_proj",
|
||||||
|
)
|
||||||
|
self.act_fn = ReLUSquaredActivation()
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x, _ = self.up_proj(x)
|
||||||
|
x = self.act_fn(x)
|
||||||
|
x, _ = self.down_proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Jais2Attention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: Jais2Config,
|
||||||
|
hidden_size: int,
|
||||||
|
num_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
max_position_embeddings: int = 8192,
|
||||||
|
quant_config: QuantizationConfig | None = None,
|
||||||
|
bias: bool = False,
|
||||||
|
cache_config: CacheConfig | None = None,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
layer_idx = extract_layer_index(prefix)
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.total_num_heads = num_heads
|
||||||
|
assert self.total_num_heads % tp_size == 0
|
||||||
|
self.num_heads = self.total_num_heads // tp_size
|
||||||
|
self.total_num_kv_heads = num_kv_heads
|
||||||
|
if self.total_num_kv_heads >= tp_size:
|
||||||
|
# Number of KV heads is greater than TP size, so we partition
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert self.total_num_kv_heads % tp_size == 0
|
||||||
|
else:
|
||||||
|
# Number of KV heads is less than TP size, so we replicate
|
||||||
|
# the KV heads across multiple tensor parallel GPUs.
|
||||||
|
assert tp_size % self.total_num_kv_heads == 0
|
||||||
|
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||||
|
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
|
||||||
|
self.head_dim = getattr(
|
||||||
|
config, "head_dim", self.hidden_size // self.total_num_heads
|
||||||
|
)
|
||||||
|
self.q_size = self.num_heads * self.head_dim
|
||||||
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
|
self.scaling = self.head_dim**-0.5
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
|
||||||
|
self.qkv_proj = QKVParallelLinear(
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
head_size=self.head_dim,
|
||||||
|
total_num_heads=self.total_num_heads,
|
||||||
|
total_num_kv_heads=self.total_num_kv_heads,
|
||||||
|
bias=bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.qkv_proj",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.o_proj = RowParallelLinear(
|
||||||
|
input_size=self.total_num_heads * self.head_dim,
|
||||||
|
output_size=hidden_size,
|
||||||
|
bias=bias,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=f"{prefix}.o_proj",
|
||||||
|
)
|
||||||
|
|
||||||
|
is_neox_style = True
|
||||||
|
if quant_config is not None and quant_config.get_name() == "gguf":
|
||||||
|
is_neox_style = False
|
||||||
|
|
||||||
|
self.rotary_emb = get_rope(
|
||||||
|
self.head_dim,
|
||||||
|
rotary_dim=self.head_dim,
|
||||||
|
max_position=max_position_embeddings,
|
||||||
|
rope_parameters=getattr(config, "rope_parameters", None),
|
||||||
|
is_neox_style=is_neox_style,
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(config, "interleaved_sliding_window"):
|
||||||
|
interleaved_sliding_window = config.interleaved_sliding_window
|
||||||
|
if isinstance(interleaved_sliding_window, int):
|
||||||
|
sliding_window = interleaved_sliding_window
|
||||||
|
elif isinstance(interleaved_sliding_window, list):
|
||||||
|
sw_idx = layer_idx % len(interleaved_sliding_window)
|
||||||
|
sliding_window = interleaved_sliding_window[sw_idx]
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"{type(interleaved_sliding_window)} is not supported."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sliding_window = None
|
||||||
|
|
||||||
|
self.attn = Attention(
|
||||||
|
self.num_heads,
|
||||||
|
self.head_dim,
|
||||||
|
self.scaling,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
cache_config=cache_config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
per_layer_sliding_window=sliding_window,
|
||||||
|
prefix=f"{prefix}.attn",
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
|
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||||
|
q, k = self.rotary_emb(positions, q, k)
|
||||||
|
attn_output = self.attn(q, k, v)
|
||||||
|
output, _ = self.o_proj(attn_output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class Jais2DecoderLayer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
config: Jais2Config,
|
||||||
|
prefix: str = "",
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
config = config or vllm_config.model_config.hf_config
|
||||||
|
cache_config = vllm_config.cache_config
|
||||||
|
quant_config = self.get_quant_config(vllm_config)
|
||||||
|
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||||
|
# Support abacusai/Smaug-72B-v0.1 with attention_bias
|
||||||
|
# Support internlm/internlm-7b with bias
|
||||||
|
attention_bias = getattr(config, "attention_bias", False) or getattr(
|
||||||
|
config, "bias", False
|
||||||
|
)
|
||||||
|
self.self_attn = Jais2Attention(
|
||||||
|
config=config,
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
num_heads=config.num_attention_heads,
|
||||||
|
num_kv_heads=getattr(
|
||||||
|
config, "num_key_value_heads", config.num_attention_heads
|
||||||
|
),
|
||||||
|
max_position_embeddings=max_position_embeddings,
|
||||||
|
quant_config=quant_config,
|
||||||
|
bias=attention_bias,
|
||||||
|
cache_config=cache_config,
|
||||||
|
prefix=f"{prefix}.self_attn",
|
||||||
|
)
|
||||||
|
self.mlp = Jais2MLP(
|
||||||
|
hidden_size=self.hidden_size,
|
||||||
|
intermediate_size=config.intermediate_size,
|
||||||
|
hidden_act=config.hidden_act,
|
||||||
|
quant_config=quant_config,
|
||||||
|
bias=getattr(config, "mlp_bias", False),
|
||||||
|
prefix=f"{prefix}.mlp",
|
||||||
|
)
|
||||||
|
self.input_layernorm = nn.LayerNorm(
|
||||||
|
config.hidden_size, eps=config.layer_norm_eps
|
||||||
|
)
|
||||||
|
self.post_attention_layernorm = nn.LayerNorm(
|
||||||
|
config.hidden_size, eps=config.layer_norm_eps
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
positions: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
residual: torch.Tensor | None,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# Self Attention
|
||||||
|
if residual is None:
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
else:
|
||||||
|
hidden_states, residual = (
|
||||||
|
self.input_layernorm(hidden_states + residual),
|
||||||
|
hidden_states + residual,
|
||||||
|
)
|
||||||
|
hidden_states = self.self_attn(
|
||||||
|
positions=positions,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
hidden_states, residual = (
|
||||||
|
self.post_attention_layernorm(hidden_states + residual),
|
||||||
|
hidden_states + residual,
|
||||||
|
)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
        return hidden_states, residual

    def get_quant_config(self, vllm_config: VllmConfig) -> QuantizationConfig | None:
        """Get quantization config for this layer. Override in subclasses."""
        return vllm_config.quant_config


@support_torch_compile
class Jais2Model(nn.Module):
    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        layer_type: type[nn.Module] = Jais2DecoderLayer,
    ):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        self.padding_idx = config.pad_token_id
        lora_vocab = (
            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
            if lora_config
            else 0
        )
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                self.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: layer_type(
                config=config,
                vllm_config=vllm_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )
        if get_pp_group().is_last_rank:
            self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        else:
            self.norm = PPMissingLayer()

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states + residual), residual
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache scales for compressed-tensors quantization
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                loaded_weight = loaded_weight[0]
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            if "scale" in name:
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

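Jais2Model.load_weights above folds the checkpoint's separate q/k/v projections into vLLM's packed qkv_proj parameter via stacked_params_mapping. A minimal standalone sketch of just that renaming step, with hypothetical parameter names (this is not vLLM's actual loader):

# Standalone illustration of the stacked_params_mapping rename; names are made up.
stacked_params_mapping = [
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
]

def rename(checkpoint_name: str) -> tuple[str, str | None]:
    """Map a checkpoint parameter name to (vLLM param name, shard id)."""
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in checkpoint_name:
            return checkpoint_name.replace(shard_name, param_name), shard_id
    return checkpoint_name, None

print(rename("model.layers.0.self_attn.k_proj.weight"))
# ('model.layers.0.self_attn.qkv_proj.weight', 'k')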

class Jais2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    }

    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config

        self.model = self._init_model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        if get_pp_group().is_last_rank:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=(
                    DEFAULT_VOCAB_PADDING_SIZE
                    # We need bigger padding if using lora for kernel
                    # compatibility
                    if not lora_config
                    else lora_config.lora_vocab_padding_size
                ),
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if config.tie_word_embeddings:
                self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(
                self.unpadded_vocab_size, config.vocab_size, logit_scale
            )
        else:
            self.lm_head = PPMissingLayer()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def _init_model(self, vllm_config: VllmConfig, prefix: str = ""):
        return Jais2Model(vllm_config=vllm_config, prefix=prefix)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        model_output = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return model_output

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)

@@ -127,6 +127,7 @@ _TEXT_GENERATION_MODELS = {
     "InternLM2VEForCausalLM": ("internlm2_ve", "InternLM2VEForCausalLM"),
     "InternLM3ForCausalLM": ("llama", "LlamaForCausalLM"),
     "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
+    "Jais2ForCausalLM": ("jais2", "Jais2ForCausalLM"),
     "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
     "KimiLinearForCausalLM": ("kimi_linear", "KimiLinearForCausalLM"),  # noqa: E501
     "Lfm2ForCausalLM": ("lfm2", "Lfm2ForCausalLM"),
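With the architecture registered above, any checkpoint whose config lists "Jais2ForCausalLM" should resolve to the new implementation through vLLM's standard offline entrypoint. A minimal sketch using vLLM's public LLM API; the model path below is a placeholder, not a reference to a specific published checkpoint:

from vllm import LLM, SamplingParams

# Placeholder path; substitute a checkpoint whose config.json declares
# "Jais2ForCausalLM" in its architectures list.
llm = LLM(model="/path/to/jais2-checkpoint")
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)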

@@ -318,15 +318,15 @@ def patch_rope_parameters(config: PretrainedConfig) -> None:
         # Transformers v4 installed, legacy config fields may be present
         if (rope_scaling := getattr(config, "rope_scaling", None)) is not None:
             config.rope_parameters = rope_scaling
+        if (
+            rope_theta is not None or partial_rotary_factor is not None
+        ) and not getattr(config, "rope_parameters", None):
+            config.rope_parameters = {"rope_type": "default"}
         if rope_theta is not None:
-            if not hasattr(config, "rope_parameters"):
-                config.rope_parameters = {"rope_type": "default"}
             config.rope_parameters["rope_theta"] = rope_theta
         if partial_rotary_factor is not None:
-            if not hasattr(config, "rope_parameters"):
-                config.rope_parameters = {"rope_type": "default"}
             config.rope_parameters["partial_rotary_factor"] = partial_rotary_factor
-    elif rope_theta is not None or hasattr(config, "rope_parameters"):
+    elif rope_theta is not None or getattr(config, "rope_parameters", None):
         # Transformers v5 installed
         # Patch these fields in case they used non-standard names
         if rope_theta is not None:
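The patched branch only initializes rope_parameters to {"rope_type": "default"} when neither rope_scaling nor an existing rope_parameters dict is present, then folds the legacy fields into that single dict. A rough illustration of the intended normalization on a toy config object (a SimpleNamespace stand-in, not the real PretrainedConfig):

from types import SimpleNamespace

# Toy stand-in for a legacy Transformers v4 config with separate RoPE fields.
config = SimpleNamespace(rope_theta=10000.0, partial_rotary_factor=0.5)

rope_theta = getattr(config, "rope_theta", None)
partial_rotary_factor = getattr(config, "partial_rotary_factor", None)

if (rope_theta is not None or partial_rotary_factor is not None) and not getattr(
    config, "rope_parameters", None
):
    config.rope_parameters = {"rope_type": "default"}
if rope_theta is not None:
    config.rope_parameters["rope_theta"] = rope_theta
if partial_rotary_factor is not None:
    config.rope_parameters["partial_rotary_factor"] = partial_rotary_factor

print(config.rope_parameters)
# {'rope_type': 'default', 'rope_theta': 10000.0, 'partial_rotary_factor': 0.5}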

@@ -160,7 +160,7 @@ def physical_to_logical_mapping(
     └───────────────────────────────────────────┘
 
     If multiple logical blocks map to the same physical block,
-    this function returns the first (minimum) logical block index.
+    this function returns the latest (maximum) logical block index.
 
     If a physical block is not mapped to by any logical block,
     its value in the result will be -1.

@@ -183,6 +183,15 @@ def physical_to_logical_mapping(
     To prevent this, we use seq_lens and block_size to mask out unused
     entries, ensuring only valid block references are processed.
 
+    IMPORTANT: Reused physical blocks (sliding-window / hybrid attention)
+    ────────────────────────────────────────────────────────────────────
+    For some attention types, physical cache blocks can be reused over time.
+    This can cause the same physical block id to appear multiple times in a row
+    of `block_table` at different logical block indices. In that case, only the
+    latest logical block index corresponds to the current contents of that
+    physical block. Therefore, the inverse mapping must pick the maximum logical
+    block index for each physical block id.
+
     Args:
         block_table: Tensor of shape [max_reqs, max_num_blocks]
             mapping logical blocks to physical locations. May contain

@@ -217,8 +226,8 @@ def physical_to_logical_mapping(
         mask, torch.arange(max_num_blocks, device=device)[None, :], 0
     )
 
-    physical_to_logical.scatter_(
-        -1, valid_block_table.to(torch.int64), valid_logical_indices
+    physical_to_logical.scatter_reduce_(
+        -1, valid_block_table.to(torch.int64), valid_logical_indices, reduce="amax"
    )
     # NB - Seems like block 0 is always empty so we reset it manually
     physical_to_logical[:, 0] = -1
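The switch from scatter_ to scatter_reduce_(..., reduce="amax") is what makes "latest logical block wins" hold when a physical block id repeats within a row. A small standalone demonstration with a toy block table (not vLLM's actual shapes or masking):

import torch

# One request whose block table reuses physical block 3 at logical indices 1 and 4.
block_table = torch.tensor([[2, 3, 5, 7, 3]])
num_physical_blocks = 8

logical_indices = torch.arange(block_table.shape[1]).unsqueeze(0)
physical_to_logical = torch.full((1, num_physical_blocks), -1, dtype=torch.int64)
physical_to_logical.scatter_reduce_(
    -1, block_table.to(torch.int64), logical_indices, reduce="amax"
)
print(physical_to_logical)
# tensor([[-1, -1,  0,  4, -1,  2, -1,  3]])  -> physical block 3 keeps logical 4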

@@ -76,6 +76,39 @@ class TritonAttentionMetadata:
     # Optional aot scheduling
     scheduler_metadata: torch.Tensor | None = None
     prefix_scheduler_metadata: torch.Tensor | None = None
+    mm_prefix_range: dict[int, list[tuple[int, int]]] | None = None
+
+    @property
+    def mm_prefix_range_tensor(self) -> torch.Tensor | None:
+        """Convert mm_prefix_range dict to padded tensor for Triton kernel.
+
+        Returns shape: (num_seqs, max_ranges, 2) with 0-padding for empty ranges.
+        Empty ranges have start==end==0, which kernel skips via is_valid check.
+        """
+        # TODO(Isotr0py): Move to model runner's attention metadata
+        # preparation to avoid duplicate computation.
+        if self.mm_prefix_range is None:
+            return None
+
+        num_seqs = self.seq_lens.shape[0]
+        device = self.seq_lens.device
+
+        # Collect ranges, using [(0,0)] for empty sequences to ensure uniform dims
+        range_lists = [
+            self.mm_prefix_range.get(i, [(0, 0)]) or [(0, 0)] for i in range(num_seqs)
+        ]
+
+        # Return None if all ranges are trivial (only (0,0) placeholders)
+        if all(r == [(0, 0)] for r in range_lists):
+            return None
+
+        # Create 2D tensors with shape (num_ranges, 2) for each sequence
+        range_tensors = [
+            torch.tensor(r, dtype=torch.int32, device=device).view(-1, 2)
+            for r in range_lists
+        ]
+
+        return torch.nested.nested_tensor(range_tensors).to_padded_tensor(0)
 
 
 class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
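The property above pads per-sequence range lists into one uniform (num_seqs, max_ranges, 2) tensor via torch's nested-tensor utilities. A toy demonstration of just that padding behavior, detached from the metadata class (values are made up):

import torch

# Two sequences: the first has two (start, end) prefix ranges, the second has none.
mm_prefix_range = {0: [(0, 4), (10, 16)]}
num_seqs = 2

range_lists = [mm_prefix_range.get(i, [(0, 0)]) or [(0, 0)] for i in range(num_seqs)]
range_tensors = [torch.tensor(r, dtype=torch.int32).view(-1, 2) for r in range_lists]
padded = torch.nested.nested_tensor(range_tensors).to_padded_tensor(0)
print(padded.shape)  # torch.Size([2, 2, 2])
print(padded[1])     # all zeros -> (0, 0) placeholders the kernel treats as invalid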

@@ -268,6 +301,10 @@ class TritonAttentionBackend(AttentionBackend):
     def supports_head_size(cls, head_size: int) -> bool:
         return head_size >= 32
 
+    @classmethod
+    def supports_mm_prefix(cls) -> bool:
+        return True
+
     @classmethod
     def supports_sink(cls) -> bool:
         return True

@@ -427,6 +464,7 @@ class TritonAttentionImpl(AttentionImpl):
         softmax_segm_expsum = attn_metadata.softmax_segm_expsum
 
         descale_shape = (cu_seqlens_q.shape[0] - 1, key_cache.shape[2])
+        mm_prefix_range_tensor = attn_metadata.mm_prefix_range_tensor
 
         unified_attention(
             q=query[:num_actual_tokens],

@@ -453,6 +491,7 @@ class TritonAttentionImpl(AttentionImpl):
             softmax_segm_expsum=softmax_segm_expsum,
             sinks=self.sinks,
             output_scale=output_scale,
+            mm_prefix_range=mm_prefix_range_tensor,
         )
 
         return output

@@ -145,12 +145,20 @@ class WorkspaceManager:
 
         for ubatch_id in range(self._num_ubatches):
             current_workspace = self._current_workspaces[ubatch_id]
-            if current_workspace is None:
+            if (
+                current_workspace is None
+                or self._workspace_size_bytes(current_workspace) < required_bytes
+            ):
+                # Delete old tensor before allocating new one to avoid
+                # memory spike from resize_(). resize_() allocates new
+                # memory before freeing old, which can cause OOM.
+                # Must clear the list reference first since local var
+                # is just a copy of the reference.
+                self._current_workspaces[ubatch_id] = None
+                del current_workspace
                 self._current_workspaces[ubatch_id] = torch.empty(
                     (required_bytes,), dtype=torch.uint8, device=self._device
                 )
-            elif self._workspace_size_bytes(current_workspace) < required_bytes:
-                current_workspace.resize_(required_bytes)
 
         if envs.VLLM_DEBUG_WORKSPACE:
             logger.info(
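The rewritten branch drops every reference to the old workspace before allocating its replacement, rather than calling resize_(), so the old and new buffers are never alive at the same time. A toy sketch of the same drop-then-allocate pattern on a hypothetical buffer list (CPU tensors, sizes in bytes; not the WorkspaceManager itself):

import torch

def grow_workspace(
    workspaces: list[torch.Tensor | None], idx: int, required_bytes: int
) -> torch.Tensor:
    current = workspaces[idx]
    if current is None or current.numel() < required_bytes:
        # Clear the list slot and the local reference first so the old buffer
        # can be freed before the new allocation, avoiding a transient spike.
        workspaces[idx] = None
        del current
        workspaces[idx] = torch.empty((required_bytes,), dtype=torch.uint8)
    return workspaces[idx]

buffers: list[torch.Tensor | None] = [None]
print(grow_workspace(buffers, 0, 1024).shape)  # torch.Size([1024])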