Compare commits

...

8 Commits

Author SHA1 Message Date
Or Ozeri
fe18ce4d3f Revert "Enable Cross layers KV cache layout at NIXL Connector (#30207)" (#33241)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
Co-authored-by: Kevin H. Luu <khluu000@gmail.com>
(cherry picked from commit 2e8de86777)
2026-01-28 11:44:59 -08:00
Jeffrey Wang
5f7f9ea884 Relax protobuf library version constraints (#33202)
Signed-off-by: Jeffrey Wang <jeffreywang@anyscale.com>
(cherry picked from commit a97b5e206d)
2026-01-28 02:17:19 -08:00
Nick Hill
7779de34da [BugFix] Fix P/D with non-MoE DP (#33037)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
(cherry picked from commit 0cd259b2d8)
2026-01-28 02:17:08 -08:00
Nicolò Lucchesi
0d8ce320a2 [Bugfix] Fix DeepseekV32 AssertionError: num_kv_heads == 1 (#33090)
Signed-off-by: NickLucche <nlucches@redhat.com>
(cherry picked from commit 492a7983dd)
2026-01-28 02:16:56 -08:00
Nicolò Lucchesi
d51e1f8b62 [Bugfix] Disable CG for Whisper+FA2 (#33164)
Signed-off-by: NickLucche <nlucches@redhat.com>
(cherry picked from commit 1f3a2c2944)
2026-01-28 02:16:41 -08:00
Roger Wang
5042815ab6 [Models] Kimi-K2.5 (#33131)
Signed-off-by: wanglinian <wanglinian@stu.pku.edu.cn>
Signed-off-by: wangln19 <96399074+wangln19@users.noreply.github.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: wanglinian <wanglinian@stu.pku.edu.cn>
Co-authored-by: wangln19 <96399074+wangln19@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
(cherry picked from commit b539f988e1)
2026-01-28 02:16:28 -08:00
Chauncey
afb390ab02 [CI] Fix AssertionError: MCP tool call not found in output_messages (#33093)
Signed-off-by: chaunceyjiang <chaunceyjiang@gmail.com>
(cherry picked from commit a2393ed496)
2026-01-28 02:16:14 -08:00
Robert Shaw
cf1167e50b [Bugfix] Fix Dtypes for Pynccl Wrapper (#33030)
Signed-off-by: Robert Shaw <robshaw@redhat.com>
Co-authored-by: Robert Shaw <robshaw@redhat.com>
(cherry picked from commit 43a013c3a2)
2026-01-26 12:37:16 -08:00
30 changed files with 1941 additions and 336 deletions

View File

@@ -184,15 +184,6 @@ Support use case: Prefill with 'HND' and decode with 'NHD' with experimental con
--kv-transfer-config '{..., "enable_permute_local_kv":"True"}'
```
### Cross layers blocks
By default, this feature is disabled. On attention backends that support this feature, each logical block is contiguous in physical memory. This reduces the number of buffers that need to be transferred.
To enable this feature:
```bash
--kv-transfer-config '{..., "kv_connector_extra_config": {"enable_cross_layers_blocks": "True"}}'
```
## Example Scripts/Code
Refer to these example scripts in the vLLM repository:

View File

@@ -686,6 +686,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | ✅︎ | ✅︎ |
| `KeyeVL1_5ForConditionalGeneration` | Keye-VL-1_5-8B | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-1_5-8B` | ✅︎ | ✅︎ |
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | ✅︎ |
| `KimiK25ForConditionalGeneration` | Kimi-K2.5 | T + I<sup>+</sup> | `moonshotai/Kimi-K2.5` | | ✅︎ |
| `LightOnOCRForConditionalGeneration` | LightOnOCR-1B | T + I<sup>+</sup> | `lightonai/LightOnOCR-1B`, etc | ✅︎ | ✅︎ |
| `Lfm2VlForConditionalGeneration` | LFM2-VL | T + I<sup>+</sup> | `LiquidAI/LFM2-VL-450M`, `LiquidAI/LFM2-VL-3B`, `LiquidAI/LFM2-VL-8B-A1B`, etc. | ✅︎ | ✅︎ |
| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | ✅︎ | ✅︎ |

View File

@@ -9,7 +9,7 @@ requires = [
"torch == 2.9.1",
"wheel",
"jinja2",
"grpcio-tools>=1.76.0",
"grpcio-tools",
]
build-backend = "setuptools.build_meta"

View File

@@ -9,5 +9,5 @@ wheel
jinja2>=3.1.6
regex
build
protobuf>=6.33.2
grpcio-tools>=1.76.0
protobuf
grpcio-tools

View File

@@ -9,7 +9,7 @@ blake3
py-cpuinfo
transformers >= 4.56.0, < 5
tokenizers >= 0.21.1 # Required for fast incremental detokenization.
protobuf >= 6.30.0 # Required by LlamaTokenizer, gRPC.
protobuf # Required by LlamaTokenizer, gRPC.
fastapi[standard] >= 0.115.0 # Required by FastAPI's form models in the OpenAI API server's audio transcriptions endpoint.
aiohttp
openai >= 1.99.1 # For Responses API with reasoning content
@@ -51,5 +51,5 @@ openai-harmony >= 0.0.3 # Required for gpt-oss
anthropic >= 0.71.0
model-hosting-container-standards >= 0.1.13, < 1.0.0
mcp
grpcio>=1.76.0
grpcio-reflection>=1.76.0
grpcio
grpcio-reflection

View File

@@ -992,7 +992,7 @@ async def test_mcp_tool_multi_turn(client: OpenAI, model_name: str, server):
# First turn - make a calculation
response1 = await client.responses.create(
model=model_name,
input="Calculate 123 * 456 using python and print the result.",
input="Calculate 1234 * 4567 using python tool and print the result.",
tools=tools,
temperature=0.0,
instructions=(

View File

@@ -771,6 +771,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
)
},
),
"KimiK25ForConditionalGeneration": _HfExamplesInfo(
"moonshotai/Kimi-K2.5",
trust_remote_code=True,
is_available_online=False,
),
"LightOnOCRForConditionalGeneration": _HfExamplesInfo(
"lightonai/LightOnOCR-1B-1025"
),

View File

@@ -34,18 +34,11 @@ else
KV_CONFIG_HETERO_LAYOUT=''
fi
CROSS_LAYERS_BLOCKS=${CROSS_LAYERS_BLOCKS:-"False"} # Default to non cross layers
if [[ "$CROSS_LAYERS_BLOCKS" == "True" ]]; then
KV_EXTRA_CONFIG=',"kv_connector_extra_config":{"cross_layers_blocks": "True"}'
else
KV_EXTRA_CONFIG=''
fi
# Build the kv-transfer-config once
if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}'}'
KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}'}'
else
KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}"}"
KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}"}"
fi
# Models to run

View File

@@ -18,12 +18,8 @@ import ray
import torch
from vllm import LLM
from vllm.config import KVTransferConfig, set_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.utils import (
KVOutputAggregator,
TpKVTopology,
get_current_attn_backend,
)
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.distributed.kv_transfer.kv_connector.v1 import nixl_connector
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
@@ -52,11 +48,8 @@ from vllm.sampling_params import SamplingParams
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheTensor
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import RequestStatus
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
from vllm.v1.worker.utils import AttentionGroup
from .utils import create_request, create_scheduler, create_vllm_config
@@ -373,7 +366,6 @@ def test_kv_transfer_handshake(dist_init):
# Decode connector will be able to create handshake with the prefill connector.
decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
decode_connector.register_kv_caches(kv_caches)
# Here we are testing the retrieval of NIXLAgentMetadata.
# Knowing the implementation detail, we override the add_remote_agent
@@ -410,23 +402,6 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
self.kv_cache_layout = kv_cache_layout
# Mock register_kv_caches attribute needed for tests that do not call it.
self.src_xfer_handles_by_block_size = {self.block_size: 1}
test_shape = self.attn_backend.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
self.kv_topo = TpKVTopology(
tp_rank=self.tp_rank,
engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=self.attn_backend,
tensor_shape=test_shape,
)
self.compat_hash = compute_nixl_compatibility_hash(
self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks
)
def _nixl_handshake(
self, host: str, port: int, remote_tp_size: int, expected_engine_id: str
@@ -1395,7 +1370,6 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
),
),
"TRITON_ATTN",
"FLASHINFER",
],
)
def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
@@ -1412,11 +1386,6 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
vllm_config = create_vllm_config(attention_backend=attn_backend)
# Enable cross layers blocks
vllm_config.kv_transfer_config.kv_connector_extra_config[
"enable_cross_layers_blocks"
] = True
# Import the appropriate backend based on the parameter
if attn_backend == "FLASH_ATTN":
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
@@ -1426,11 +1395,49 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
from vllm.v1.attention.backends.rocm_attn import RocmAttentionBackend
backend_cls = RocmAttentionBackend
else: # TRITON
else: # TRITON_ATTN
from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend
backend_cls = TritonAttentionBackend
# Create test kv cache tensors using proper backend shape
kv_cache_shape = backend_cls.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}
# Store tensor info for validation
test_shape = backend_cls.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1
if is_blocks_first:
expected_tensor_size = shared_tensor.element_size() * shared_tensor.numel()
expected_base_addrs = [
shared_tensor.data_ptr(),
unique_tensor.data_ptr(),
]
expected_num_entries = 2
else:
expected_tensor_size = (
shared_tensor[0].element_size() * shared_tensor[0].numel()
)
expected_base_addrs = [
shared_tensor[0].data_ptr(),
shared_tensor[1].data_ptr(),
unique_tensor[0].data_ptr(),
unique_tensor[1].data_ptr(),
]
expected_num_entries = 4
nixl_module = "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector"
with (
patch(f"{nixl_module}.NixlWrapper") as mock_nixl_wrapper,
@@ -1459,107 +1466,6 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
# Reassure the shutdown() check that the thread is terminated
mock_thread.return_value.is_alive.return_value = False
expected_tensor_size: int
expected_base_addrs: list[int]
expected_num_entries: int
kv_caches: dict[str, torch.Tensor]
if connector.prefer_cross_layer_blocks:
num_layers = 32
block_size = 16
num_blocks = 8
kv_cache_spec = AttentionSpec(
block_size=block_size,
num_kv_heads=4,
head_size=64,
dtype=torch.bfloat16,
)
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[
KVCacheTensor(
size=kv_cache_spec.page_size_bytes * num_blocks,
shared_by=["dummy-layer"],
)
for i in range(num_layers)
],
# allocate_uniform_kv_caches does not use this
kv_cache_groups=[],
)
with set_current_vllm_config(vllm_config):
_, cross_layers_kv_cache, _ = (
KVConnectorModelRunnerMixin.allocate_uniform_kv_caches(
kv_cache_config=kv_cache_config,
attn_groups=[
[
AttentionGroup(
backend=backend_cls,
layer_names=[],
kv_cache_spec=kv_cache_spec,
kv_cache_group_id=0,
)
]
],
cache_dtype=torch.bfloat16,
device=torch.cuda.current_device(),
kernel_block_sizes=[block_size],
)
)
# Store tensor info for validation
expected_tensor_size = (
cross_layers_kv_cache.element_size() * cross_layers_kv_cache.numel()
)
expected_base_addrs = [
cross_layers_kv_cache.data_ptr(),
]
expected_num_entries = 1
expected_blocks_count = 8
kv_caches = {"all-layers": cross_layers_kv_cache}
else:
# Create test kv cache tensors using proper backend shape
kv_cache_shape = backend_cls.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}
# Store tensor info for validation
test_shape = backend_cls.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1
if is_blocks_first:
expected_tensor_size = (
shared_tensor.element_size() * shared_tensor.numel()
)
expected_base_addrs = [
shared_tensor.data_ptr(),
unique_tensor.data_ptr(),
]
expected_num_entries = 2
else:
expected_tensor_size = (
shared_tensor[0].element_size() * shared_tensor[0].numel()
)
expected_base_addrs = [
shared_tensor[0].data_ptr(),
shared_tensor[1].data_ptr(),
unique_tensor[0].data_ptr(),
unique_tensor[1].data_ptr(),
]
expected_num_entries = 4
expected_blocks_count = 8
# Execute register_kv_caches
connector.register_kv_caches(kv_caches)
@@ -1583,19 +1489,16 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0]
# Validate blocks_data structure and size
expected_blocks_count = 8
assert len(blocks_data) == expected_blocks_count, (
f"Expected {expected_blocks_count} blocks, got {len(blocks_data)}"
)
if connector.prefer_cross_layer_blocks:
num_blocks = 8
expected_block_len = expected_tensor_size // num_blocks
num_blocks = 2
if is_blocks_first:
expected_block_len = expected_tensor_size // num_blocks // 2
else:
num_blocks = 2
if is_blocks_first:
expected_block_len = expected_tensor_size // num_blocks // 2
else:
expected_block_len = expected_tensor_size // num_blocks
expected_block_len = expected_tensor_size // num_blocks
for i, block_entry in enumerate(blocks_data):
block_start_addr, block_len, tp_rank = block_entry
@@ -2146,17 +2049,6 @@ def test_compatibility_hash_validation(
)
decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER)
decode_worker = decode_connector.connector_worker
kv_cache_shape = decode_worker.attn_backend.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}
decode_connector.register_kv_caches(kv_caches)
remote_config_params: dict[str, Any] = {
"model": "facebook/opt-125m",
@@ -2179,9 +2071,7 @@ def test_compatibility_hash_validation(
)
)
remote_hash = compute_nixl_compatibility_hash(
remote_vllm_config,
decode_worker.backend_name,
decode_worker.kv_topo.cross_layers_blocks,
remote_vllm_config, decode_worker.backend_name
)
prefill_block_size = config_overrides.get("block_size", 16)
@@ -2260,27 +2150,6 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER)
decode_worker = decode_connector.connector_worker
backend = get_current_attn_backend(local_vllm_config)
test_shape = backend.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
decode_worker.kv_topo = TpKVTopology(
tp_rank=decode_worker.tp_rank,
engine_id=decode_worker.engine_id,
remote_tp_size=decode_worker._tp_size, # shared state
remote_block_size=decode_worker._block_size, # shared state
is_mla=decode_worker.use_mla,
total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(),
attn_backend=backend,
tensor_shape=test_shape,
)
decode_worker.compat_hash = compute_nixl_compatibility_hash(
decode_worker.vllm_config,
decode_worker.backend_name,
decode_worker.kv_topo.cross_layers_blocks,
)
if error_scenario == "handshake_decode_error":
msg_bytes = b"this is not valid msgpack data"
elif error_scenario == "handshake_validation_error":

View File

@@ -72,7 +72,8 @@ class ncclDataTypeEnum:
ncclFloat64 = 8
ncclDouble = 8
ncclBfloat16 = 9
ncclNumTypes = 10
ncclFloat8e4m3 = 10
ncclNumTypes = 11
@classmethod
def from_torch(cls, dtype: torch.dtype) -> int:
@@ -92,9 +93,12 @@ class ncclDataTypeEnum:
return cls.ncclFloat64
if dtype == torch.bfloat16:
return cls.ncclBfloat16
if dtype == torch.float8_e4m3fn:
return cls.ncclFloat8e4m3
raise ValueError(
f"Unsupported dtype {dtype}: should be one of "
f"int8, uint8, int32, int64, float16, float32, float64, bfloat16."
f"int8, uint8, int32, int64, float16, float32, float64, bfloat16,"
" float8e4m3."
)

View File

@@ -316,13 +316,12 @@ class TpKVTopology:
attn_backend: type[AttentionBackend]
engine_id: EngineId
remote_block_size: dict[EngineId, int]
tensor_shape: torch.Size | None = None
def __post_init__(self):
# Figure out whether the first dimension of the cache is K/V
# or num_blocks. This is used to register the memory regions correctly.
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=4, head_size=1
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
# Non-MLA backends caches have 5 dims [2, num_blocks, H,N,D],
# we just mock num_blocks to 1 for the dimension check below.
@@ -330,32 +329,6 @@ class TpKVTopology:
len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
)
self._kv_heads_position: int | None = None
self._cross_layers_blocks = False
if self.tensor_shape is not None:
self._cross_layers_blocks = (
len(self.tensor_shape) == len(kv_cache_shape) + 1
)
if self._cross_layers_blocks:
# prepend layers dimension
kv_cache_shape = (80,) + kv_cache_shape
try:
kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=self._cross_layers_blocks
)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(self.tensor_shape)))
# permute kv_cache_shape according to stride_order
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
physical_block_size_position = kv_cache_shape.index(16)
assert physical_block_size_position is not None
self._physical_block_size_position = -(
len(kv_cache_shape) - physical_block_size_position
)
@property
def is_kv_layout_blocks_first(self) -> bool:
return self._is_kv_layout_blocks_first
@@ -363,9 +336,7 @@ class TpKVTopology:
@property
def split_k_and_v(self) -> bool:
# Whether to register regions for K and V separately (when present).
return not (
self._cross_layers_blocks or self.is_mla or self.is_kv_layout_blocks_first
)
return not (self.is_mla or self.is_kv_layout_blocks_first)
@property
def tp_size(self) -> int:
@@ -375,14 +346,6 @@ class TpKVTopology:
def block_size(self) -> int:
return self.remote_block_size[self.engine_id]
@property
def cross_layers_blocks(self) -> bool:
return self._cross_layers_blocks
@property
def block_size_position(self) -> int:
return self._physical_block_size_position
def tp_ratio(
self,
remote_tp_size: int,

View File

@@ -54,7 +54,7 @@ from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.block_table import BlockTable
@@ -173,7 +173,7 @@ class NixlHandshakePayload(KVConnectorHandshakeMetadata):
def compute_nixl_compatibility_hash(
vllm_config: VllmConfig, attn_backend_name: str, cross_layers_blocks: bool
vllm_config: VllmConfig, attn_backend_name: str
) -> str:
"""
Compute compatibility hash for NIXL KV transfer.
@@ -216,7 +216,6 @@ def compute_nixl_compatibility_hash(
# Attention backend and KV cache dtype affect memory layout
"attn_backend_name": attn_backend_name,
"cache_dtype": str(cache_config.cache_dtype),
"cross_layers_blocks": cross_layers_blocks,
}
compat_hash = hash_factors(factors)
@@ -299,20 +298,6 @@ class NixlConnectorMetadata(KVConnectorMetadata):
class NixlConnector(KVConnectorBase_V1):
@property
def prefer_cross_layer_blocks(self) -> bool:
backend = get_current_attn_backend(self._vllm_config)
if backend().get_name() not in (
"FLASH_ATTN",
"FLASHINFER",
):
# For now there is no benefit to run cross layers when backend
# does not support on HND
return False
extra_config = self.kv_transfer_config.kv_connector_extra_config
return bool(str(extra_config.get("enable_cross_layers_blocks", "False")))
def __init__(
self,
vllm_config: VllmConfig,
@@ -324,7 +309,6 @@ class NixlConnector(KVConnectorBase_V1):
assert vllm_config.kv_transfer_config is not None
assert vllm_config.kv_transfer_config.engine_id is not None
self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
self.kv_transfer_config = vllm_config.kv_transfer_config
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler: NixlConnectorScheduler | None = (
@@ -411,16 +395,6 @@ class NixlConnector(KVConnectorBase_V1):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
):
assert self.connector_worker is not None
cross_layer_name = "ALL_LAYERS"
kv_caches = {cross_layer_name: kv_cache}
self.connector_worker.register_kv_caches(kv_caches)
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
assert self.connector_worker is not None
self.connector_worker.set_host_xfer_buffer_ops(copy_operation)
@@ -1002,17 +976,20 @@ class NixlConnectorWorker:
# Get the attention backend from the first layer
# NOTE (NickLucche) models with multiple backends are not supported yet
self.attn_backend = get_current_attn_backend(vllm_config)
backend = get_current_attn_backend(vllm_config)
self.backend_name = self.attn_backend.get_name()
self.backend_name = backend.get_name()
self.kv_cache_layout = get_kv_cache_layout()
self.host_buffer_kv_cache_layout = self.kv_cache_layout
logger.debug("Detected attention backend %s", self.backend_name)
logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
# lazy initialized in register_kv_caches
self.compat_hash: str | None = None
self.kv_topo: TpKVTopology | None = None
self.compat_hash = compute_nixl_compatibility_hash(
self.vllm_config, self.backend_name
)
self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config(
"enforce_handshake_compat", True
)
self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
@@ -1021,11 +998,16 @@ class NixlConnectorWorker:
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
self.xfer_stats = NixlKVConnectorStats()
self._physical_blocks_per_logical_kv_block = 1
self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config(
"enforce_handshake_compat", True
self.kv_topo = TpKVTopology(
tp_rank=self.tp_rank,
engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=backend,
)
self._physical_blocks_per_logical_kv_block = 1
def _nixl_handshake(
self,
@@ -1040,7 +1022,6 @@ class NixlConnectorWorker:
# Regardless, only handshake with the remote TP rank(s) that current
# local rank will read from. Note that With homogeneous TP,
# this happens to be the same single rank_i.
assert self.kv_topo is not None
p_remote_ranks = self.kv_topo.get_target_remote_ranks(remote_tp_size)
remote_rank_to_agent_name = {}
path = make_zmq_path("tcp", host, port)
@@ -1078,7 +1059,6 @@ class NixlConnectorWorker:
)
# Check compatibility hash BEFORE decoding agent metadata
assert self.compat_hash is not None
if (
self.enforce_compat_hash
and handshake_payload.compatibility_hash != self.compat_hash
@@ -1287,20 +1267,6 @@ class NixlConnectorWorker:
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in nixl."""
self.kv_topo = TpKVTopology(
tp_rank=self.tp_rank,
engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=self.attn_backend,
tensor_shape=next(iter(kv_caches.values())).shape,
)
self.compat_hash = compute_nixl_compatibility_hash(
self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks
)
if self.use_host_buffer:
self.initialize_host_xfer_buffer(kv_caches=kv_caches)
assert len(self.host_xfer_buffers) == len(kv_caches), (
@@ -1335,21 +1301,29 @@ class NixlConnectorWorker:
# (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are registered in the same region
# to better exploit the memory layout (ie num_blocks is the first dim).
split_k_and_v = self.kv_topo.split_k_and_v
tensor_size_bytes = None
# TODO (NickLucche): Get kernel_block_size in a cleaner way
# NHD default "view" for non-MLA cache
if self.device_type == "cpu":
block_size_position = -2
else:
block_size_position = -2 if self.use_mla else -3
# Enable different block lengths for different layers when MLA is used.
self.block_len_per_layer = list[int]()
self.slot_size_per_layer = list[int]() # HD bytes in kv terms
for layer_name, cache_or_caches in xfer_buffers.items():
cache_list = (
cache_or_caches if self.kv_topo.split_k_and_v else [cache_or_caches]
)
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
for cache in cache_list:
base_addr = cache.data_ptr()
if base_addr in seen_base_addresses:
continue
kernel_block_size = cache.shape[self.kv_topo.block_size_position]
kernel_block_size = cache.shape[block_size_position]
if self.block_size != kernel_block_size:
logger.info_once(
"User-specified logical block size (%s) does not match"
@@ -1411,7 +1385,6 @@ class NixlConnectorWorker:
self.device_kv_caches = kv_caches
self.dst_num_blocks[self.engine_id] = self.num_blocks
if self.kv_topo.is_kv_layout_blocks_first:
for i in range(len(self.slot_size_per_layer)):
assert self.slot_size_per_layer[i] % 2 == 0
@@ -1467,7 +1440,6 @@ class NixlConnectorWorker:
block_size=self.block_size,
)
# Wrap metadata in payload with hash for defensive decoding
assert self.compat_hash is not None
encoder = msgspec.msgpack.Encoder()
self.xfer_handshake_metadata = NixlHandshakePayload(
compatibility_hash=self.compat_hash,
@@ -1489,8 +1461,6 @@ class NixlConnectorWorker:
register another local_xfer_handler using remote block len to ensure
data copy correctness.
"""
assert self.kv_topo is not None
block_size_ratio = self.block_size // block_size
blocks_data = []
for i, base_addr in enumerate(self.seen_base_addresses):
@@ -1603,7 +1573,6 @@ class NixlConnectorWorker:
# remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|
# local origin:| 0| 1| 8| 12|
# local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15|
assert self.kv_topo is not None
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(engine_id)
if engine_id not in self.dst_num_blocks:
@@ -1731,10 +1700,7 @@ class NixlConnectorWorker:
"""
remote_engine_id = nixl_agent_meta.engine_id
assert (
self._tp_size[remote_engine_id] == remote_tp_size
and self.kv_topo is not None
)
assert self._tp_size[remote_engine_id] == remote_tp_size
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
@@ -1871,7 +1837,6 @@ class NixlConnectorWorker:
if len(self.device_kv_caches) == 0:
return
assert block_size_ratio >= 1, "Only nP < nD supported currently."
assert self.kv_topo is not None
if self.enable_permute_local_kv and block_size_ratio > 1:
logger.debug(
"Post-processing device kv cache on receive by converting "
@@ -1891,7 +1856,7 @@ class NixlConnectorWorker:
block_size_ratio,
)
split_k_and_v = self.kv_topo.split_k_and_v
split_k_and_v = not (self.use_mla or self.kv_topo.is_kv_layout_blocks_first)
for block_ids in block_ids_list:
indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long)
@@ -1916,7 +1881,6 @@ class NixlConnectorWorker:
The scheduler process (via the MultiprocExecutor) will use this output
to track which workers are done.
"""
assert self.kv_topo is not None
done_sending = self._get_new_notifs()
done_recving = self._pop_done_transfers(self._recving_transfers)
@@ -1986,7 +1950,6 @@ class NixlConnectorWorker:
are reading from the same producer (heterogeneous TP scenario), wait
for all consumers to be done pulling.
"""
assert self.kv_topo is not None
notified_req_ids: set[str] = set()
for notifs in self.nixl_wrapper.get_new_notifs().values():
for notif in notifs:
@@ -2146,7 +2109,7 @@ class NixlConnectorWorker:
self._reqs_to_send[req_id] = expiration_time
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
assert meta.remote is not None and self.kv_topo is not None
assert meta.remote is not None
remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id(
meta.remote.engine_id
)
@@ -2215,7 +2178,10 @@ class NixlConnectorWorker:
local_xfer_side_handle: int,
remote_xfer_side_handle: int,
):
assert self.kv_topo is not None
"""
Post a READ point-to-point xfer request from a single local worker to
a single remote worker.
"""
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id)
if block_size_ratio > 1:
local_block_ids = self.get_mapped_blocks(
@@ -2448,7 +2414,6 @@ class NixlConnectorWorker:
For FlashInfer, this is half the length of the whole block, as K and V
share the same region.
"""
assert self.kv_topo is not None
if self.kv_topo.is_kv_layout_blocks_first:
# For indexing only half (either just the K or V part).
block_len = self.block_len_per_layer[layer_idx] // 2

View File

@@ -46,6 +46,9 @@ from vllm.multimodal.inputs import (
MultiModalBatchedField,
MultiModalFlatField,
MultiModalSharedField,
VisionChunk,
VisionChunkImage,
VisionChunkVideo,
)
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
@@ -336,7 +339,9 @@ ChatTemplateContentFormatOption = Literal["auto", "string", "openai"]
ChatTemplateContentFormat = Literal["string", "openai"]
ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds"]
ModalityStr = Literal[
"image", "audio", "video", "image_embeds", "audio_embeds", "vision_chunk"
]
_T = TypeVar("_T")
@@ -449,6 +454,78 @@ def _get_embeds_data(
raise NotImplementedError(type(data_items))
def rebuild_mm_uuids_from_mm_data(
    mm_uuids: MultiModalUUIDDict,
    mm_data: MultiModalDataDict,
) -> MultiModalUUIDDict:
    """Rebuild ``mm_uuids`` after vision-chunk processing.

    When videos are split into chunks, fresh per-chunk UUIDs are generated;
    this collects them into a copy of the original UUID mapping.

    Args:
        mm_uuids: Original UUIDs dictionary.
        mm_data: Processed multimodal data that may contain ``vision_chunk``
            items.

    Returns:
        A new UUIDs dictionary with a ``vision_chunk`` entry gathered from the
        chunk items, or ``mm_uuids`` unchanged when there are no chunks.
    """
    chunks = mm_data.get("vision_chunk")
    if chunks is None:
        return mm_uuids

    # vision_chunk items are always dicts (VisionChunkImage/VisionChunkVideo);
    # assert to surface a malformed item early.
    for chunk in chunks:
        assert isinstance(chunk, dict)

    collected = [c["uuid"] for c in chunks if c.get("uuid") is not None]

    rebuilt = dict(mm_uuids)
    if collected:
        rebuilt["vision_chunk"] = collected
    return rebuilt
def build_video_prompts_from_mm_data(
    mm_data: MultiModalDataDict,
) -> list[str]:
    """Build per-video prompts from ``vision_chunk`` data.

    Collects the ``prompt`` text of every video chunk and concatenates the
    chunks belonging to the same ``video_idx``.

    Args:
        mm_data: Processed multimodal data that may contain ``vision_chunk``
            items.

    Returns:
        One joined prompt string per video, ordered by ``video_idx``.
    """
    chunks = mm_data.get("vision_chunk")
    if chunks is None:
        return []

    # Group chunk prompts by the video they belong to.
    grouped: dict[int, list[str]] = defaultdict(list)
    for chunk in chunks:
        # vision_chunk items are always dicts (VisionChunkImage/VisionChunkVideo)
        assert isinstance(chunk, dict)
        if chunk.get("type") != "video_chunk":
            continue
        grouped[chunk.get("video_idx", 0)].append(chunk.get("prompt", ""))

    # Emit prompts in ascending video order.
    return ["".join(grouped[idx]) for idx in sorted(grouped)]
class BaseMultiModalItemTracker(ABC, Generic[_T]):
"""
Tracks multi-modal items in a given request and ensures that the number
@@ -462,6 +539,13 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
self._model_config = model_config
self._items_by_modality = defaultdict[str, list[_T]](list)
# Track original modality for each vision_chunk item (image or video)
self._modality_order = defaultdict[str, list[str]](list)
@cached_property
def use_unified_vision_chunk_modality(self) -> bool:
"""Check if model uses unified vision_chunk modality for images/videos."""
return getattr(self._model_config.hf_config, "use_unified_vision_chunk", False)
@property
def model_config(self) -> ModelConfig:
@@ -499,11 +583,31 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
media.
"""
input_modality = modality.replace("_embeds", "")
num_items = len(self._items_by_modality[modality]) + 1
original_modality = modality
use_vision_chunk = (
self.use_unified_vision_chunk_modality
and original_modality in ["video", "image"]
)
# If use_unified_vision_chunk_modality is enabled,
# map image/video to vision_chunk
if use_vision_chunk:
# To avoid validation fail
# because models with use_unified_vision_chunk_modality=True
# will only accept vision_chunk modality.
input_modality = "vision_chunk"
num_items = len(self._items_by_modality[input_modality]) + 1
else:
num_items = len(self._items_by_modality[original_modality]) + 1
self.mm_processor.validate_num_items(input_modality, num_items)
self._items_by_modality[modality].append(item)
# Track original modality for vision_chunk items
if use_vision_chunk:
self._items_by_modality[input_modality].append(item) # type: ignore
self._modality_order["vision_chunk"].append(original_modality)
else:
self._items_by_modality[original_modality].append(item)
return self.model_cls.get_placeholder_str(modality, num_items)
@@ -515,6 +619,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
def _resolve_items(
items_by_modality: dict[str, list[tuple[object, str | None]]],
mm_processor: BaseMultiModalProcessor,
vision_chunk_modality_order: dict[str, list[str]],
) -> tuple[MultiModalDataDict, MultiModalUUIDDict]:
if "image" in items_by_modality and "image_embeds" in items_by_modality:
raise ValueError("Mixing raw image and embedding inputs is not allowed")
@@ -546,6 +651,74 @@ def _resolve_items(
if "video" in items_by_modality:
mm_data["video"] = [data for data, uuid in items_by_modality["video"]]
mm_uuids["video"] = [uuid for data, uuid in items_by_modality["video"]]
if "vision_chunk" in items_by_modality:
# Process vision_chunk items - extract from (data, modality) tuples
# and convert to VisionChunk types with proper UUID handling
vision_chunk_items = items_by_modality["vision_chunk"]
modality_order = vision_chunk_modality_order.get("vision_chunk", [])
mm_uuids["vision_chunk"] = [
uuid for data, uuid in items_by_modality["vision_chunk"]
]
# Filter out None items (from asyncio.sleep(0) placeholders)
filtered_items = [
(idx, item)
for idx, item in enumerate(vision_chunk_items)
if item is not None
]
assert len(filtered_items) == len(modality_order), (
f"vision_chunk items ({len(filtered_items)}) and "
f"modality_order ({len(modality_order)}) must have same length"
)
processed_chunks: list[VisionChunk] = []
video_idx = 0
for i, (idx, item) in enumerate(filtered_items):
inner_modality = modality_order[i]
data, uuid = item
uuid_val = uuid if idx < len(mm_uuids["vision_chunk"]) else None
if inner_modality == "image":
# Cast data to proper type for image
# Use .media (PIL.Image) directly to avoid redundant
# bytes→PIL conversion in media_processor
if hasattr(data, "media"):
image_data = data.media # type: ignore[union-attr]
processed_chunks.append(
VisionChunkImage(type="image", image=image_data, uuid=uuid_val)
)
else:
processed_chunks.append(data) # type: ignore[arg-type]
elif inner_modality == "video":
# For video, we may need to split into chunks
# if processor supports it
# For now, just wrap as a video chunk placeholder
if hasattr(mm_processor, "split_video_chunks") and data is not None:
try:
video_uuid = uuid_val or random_uuid()
# video await result is (video_data, video_meta) tuple
if isinstance(data, tuple) and len(data) >= 1:
video_data = data[0]
else:
video_data = data
video_chunks = mm_processor.split_video_chunks(video_data)
for i, vc in enumerate(video_chunks):
processed_chunks.append(
VisionChunkVideo(
type="video_chunk",
video_chunk=vc["video_chunk"],
uuid=f"{video_uuid}-{i}",
video_idx=video_idx,
prompt=vc["prompt"],
)
)
video_idx += 1
except Exception as e:
logger.warning("Failed to split video chunks: %s", e)
processed_chunks.append(data) # type: ignore[arg-type]
else:
processed_chunks.append(data) # type: ignore[arg-type]
mm_data["vision_chunk"] = processed_chunks
return mm_data, mm_uuids
@@ -557,7 +730,9 @@ class MultiModalItemTracker(BaseMultiModalItemTracker[tuple[object, str | None]]
if not self._items_by_modality:
return None, None
return _resolve_items(dict(self._items_by_modality), self.mm_processor)
return _resolve_items(
dict(self._items_by_modality), self.mm_processor, self._modality_order
)
def create_parser(self) -> "BaseMultiModalContentParser":
return MultiModalContentParser(self)
@@ -577,7 +752,9 @@ class AsyncMultiModalItemTracker(
for modality, coros in self._items_by_modality.items()
}
return _resolve_items(resolved_items_by_modality, self.mm_processor)
return _resolve_items(
resolved_items_by_modality, self.mm_processor, self._modality_order
)
def create_parser(self) -> "BaseMultiModalContentParser":
return AsyncMultiModalContentParser(self)

View File

@@ -782,6 +782,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
),
# Backend for Video IO
# - "opencv": Default backend that uses OpenCV stream buffered backend.
# - "identity": Returns raw video bytes for model processor to handle.
#
# Custom backend implementations can be registered
# via `@VIDEO_LOADER_REGISTRY.register("my_custom_video_loader")` and

View File

@@ -224,6 +224,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin
a1q_scale = None
if is_nvfp4 and a1q_scale is not None:
if a1q_scale.element_size() == 1:
a1q_scale = a1q_scale.view(torch.uint8)
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
return a1q, a1q_scale, None, topk_ids, topk_weights

View File

@@ -0,0 +1,581 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: E501
"""
Kimi-K2.5 Model Implementation for vLLM.
Kimi-K2.5 extends Kimi-K2 with vision support
This module defines:
- KimiK25ProcessingInfo/KimiK25MultiModalProcessor: Processing logic
- KimiK25ForConditionalGeneration: Main model class
"""
import copy
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from typing import Annotated, Any, Literal
import torch
from torch import nn
from transformers import BatchFeature
from transformers.processing_utils import ProcessorMixin
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.deepseek_v2 import DeepseekV2Model
from vllm.model_executor.models.interfaces import SupportsMultiModal, SupportsPP
from vllm.model_executor.models.kimi_k25_vit import (
KimiK25MultiModalProjector,
MoonViT3dPretrainedModel,
vision_tower_forward,
)
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
NestedTensors,
VisionChunk,
VisionChunkImage,
VisionChunkVideo,
)
from vllm.multimodal.parse import MultiModalDataItems, VisionChunkProcessorItems
from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseMultiModalProcessor,
BaseProcessingInfo,
InputProcessingContext,
PromptReplacement,
PromptUpdate,
)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import KimiK25Config
from vllm.transformers_utils.processor import cached_get_image_processor
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix
logger = init_logger(__name__)
# Dummy input dimensions for profiling.
@dataclass
class MaxImageTokenMeta:
    """Worst-case media dimensions (px) used to build dummy inputs for profiling."""

    # Both images and video frames are profiled at this 3000x3000 upper bound.
    width: int = 3000
    height: int = 3000
class KimiK25MediaPixelInputs(TensorSchema):
    """
    Media input schema for K2-VL model.

    Dimensions:
        - np: Number of patches (flattened from all media items)
        - ps: Patch size
        - nm: Number of media items
    """

    # Discriminator for the input kind; only raw pixel values are supported.
    type: Literal["pixel_values"] = "pixel_values"
    # Patches flattened across all media items: (np, 3, ps, ps).
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("np", 3, "ps", "ps"),
    ]
    # Per-item (t, h, w) grids; prod over the last dim yields each item's
    # patch count (used to slice pixel_values back into items).
    grid_thws: Annotated[torch.Tensor, TensorShape("nm", 3)]
class MoonshotKimiVAutoProcessor(ProcessorMixin):
    """Minimal HF-style processor pairing the Kimi media processor with a tokenizer.

    Only the tokenizer is registered as a ProcessorMixin attribute; the media
    processor is attached manually in ``__init__``.
    """

    attributes = ["tokenizer"]
    tokenizer_class = "AutoTokenizer"

    def __init__(self, media_processor=None, tokenizer=None):
        super().__init__(tokenizer)
        # Stored directly on the instance, bypassing ProcessorMixin handling.
        self.media_processor = media_processor

    # We do not support str input for text here
    def __call__(
        self,
        vision_chunks: list[VisionChunk] | None = None,
        *,
        text: list[int],
        **kwargs,
    ) -> BatchFeature:
        """
        Args:
            vision_chunks: List of VisionChunk items to be processed.
                For image: VisionChunkImage with type='image', image=PIL.Image
                For video_chunk: VisionChunkVideo with type='video_chunk', video_chunk=list[PIL.Image]
            text: The token ids to be fed to a model (required).

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:
            - **input_ids** -- list of token ids to be fed to a model.
            - **pixel_values** -- Pixel values to be fed to a model. Returned when `vision_chunks` is not `None`.
            - **grid_thws** -- list of image 3D grid in LLM. Returned when `vision_chunks` is not `None`.
        """
        mm_inputs = {}
        if vision_chunks is not None:
            assert isinstance(vision_chunks, list)
            # All pixel/grid extraction is delegated to the media processor.
            mm_inputs = self.media_processor.preprocess(vision_chunks)
        # XXX: _apply_hf_processor_text_mm will call tolist() on input_ids
        return BatchFeature(
            data={
                "input_ids": torch.tensor([text]),
                **mm_inputs,
            }
        )
class KimiK25ProcessingInfo(BaseProcessingInfo):
    """Processing information for Kimi-K2.5 model.

    Provides configuration and utilities for processing both
    images and video-chunks.
    """

    def __init__(self, ctx: InputProcessingContext) -> None:
        super().__init__(ctx)
        self.hf_config = self.get_hf_config()
        # Token id marking where media embeddings are spliced into the prompt.
        self.media_token_id = self.hf_config.media_placeholder_token_id
        # The HF image processor doubles as the media (image + video-chunk)
        # processor; trust_remote_code is needed as it ships with the checkpoint.
        media_processor = cached_get_image_processor(
            self.ctx.model_config.model, trust_remote_code=True
        )
        self.media_processor = media_processor
        self.hf_processor = MoonshotKimiVAutoProcessor(
            media_processor=self.media_processor,
            tokenizer=self.get_tokenizer(),
        )
        # Callable that computes the placeholder-token count for a media item.
        self.media_tokens_calculator = self.media_processor.media_tokens_calculator

    def get_hf_processor(self):
        # Return the processor built once in __init__ (never rebuilt per call).
        return self.hf_processor

    def get_hf_config(self):
        return self.ctx.get_hf_config(KimiK25Config)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # None means unlimited
        return {"vision_chunk": None}
class KimiK25DummyInputsBuilder(BaseDummyInputsBuilder[KimiK25ProcessingInfo]):
    """Builds dummy inputs for Kimi-K2.5 model profiling."""

    def __init__(self, info: KimiK25ProcessingInfo) -> None:
        super().__init__(info)
        self.media_token_id = self.info.media_token_id
        # Number of frames bundled into one video chunk by the media processor.
        self.frame_per_chunk = self.info.media_processor.num_frames_per_chunk

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> list[int]:
        # One media placeholder token per dummy vision_chunk item.
        num_media = mm_counts.get("vision_chunk", 0)
        return [self.media_token_id] * num_media

    def get_dummy_mm_items(self):
        """Return the worst-case dummy item (video chunk vs. single image).

        Both candidates use MaxImageTokenMeta dimensions; the one expanding to
        more placeholder tokens is kept so profiling reserves enough space.
        """
        dummy_videos = self._get_dummy_images(
            height=MaxImageTokenMeta.height,
            width=MaxImageTokenMeta.width,
            num_images=self.frame_per_chunk,
        )
        video_chunk_dummy_item = VisionChunkVideo(
            type="video_chunk", video_chunk=dummy_videos
        )
        video_chunk_num_tokens = self.info.media_tokens_calculator(
            video_chunk_dummy_item
        )
        image_dummy_item = VisionChunkImage(
            type="image",
            image=self._get_dummy_images(
                height=MaxImageTokenMeta.height,
                width=MaxImageTokenMeta.width,
                num_images=1,
            )[0],
        )
        image_num_tokens = self.info.media_tokens_calculator(image_dummy_item)
        # return the larger one
        if video_chunk_num_tokens >= image_num_tokens:
            return [video_chunk_dummy_item]
        else:
            return [image_dummy_item]

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        # TODO: Support mm_options for vision_chunk to allow user configuration
        # NOTE(review): mm_counts is not consulted here — a single worst-case
        # item is returned regardless; confirm profiling expects this.
        dummy_items = self.get_dummy_mm_items()
        return {"vision_chunk": dummy_items}
class KimiK25MultiModalProcessor(BaseMultiModalProcessor[KimiK25ProcessingInfo]):
    """Multi-modal processor for Kimi-K2.5.

    Handles both image and video-chunk modalities.
    """

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Indicates how to slice media input into multiple items.

        pixel_values: [N, 3, patch_size, patch_size], all patches collected from B medias
        grid_thws: [B,3], each item: [N_t, N_h ,N_w], indicates the grid size in time/height/width direction
        for current item.
        by multiplying [N_t, N_h ,N_w], we get the number of patches for each media item, thus we can slice
        pixel_values by pixel_values[start:start + N_t*N_h*N_w] to get patches of one item.
        """
        # Empty (0, 3) fallback keeps grid_sizes well-defined for text-only inputs.
        grid_thws = hf_inputs.get("grid_thws", torch.empty((0, 3)))
        grid_sizes = grid_thws.prod(-1)
        return dict(
            pixel_values=MultiModalFieldConfig.flat_from_sizes(
                "vision_chunk", grid_sizes
            ),
            grid_thws=MultiModalFieldConfig.batched("vision_chunk"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Expand each media placeholder token to the item's full token count."""
        hf_config = self.info.get_hf_config()
        media_token_id = hf_config.media_placeholder_token_id

        def get_replacement(item_idx: int):
            # Token count depends on the concrete item (image vs. video chunk).
            media = mm_items.get_items("vision_chunk", (VisionChunkProcessorItems,))
            num_media_token = self.info.media_tokens_calculator(media[item_idx])
            return [media_token_id] * num_media_token

        return [
            PromptReplacement(
                modality="vision_chunk",
                target=[media_token_id],
                replacement=get_replacement,
            ),
        ]

    def split_video_chunks(self, video):
        # Thin delegation to the media processor's chunking logic.
        return self.info.media_processor.split_video_chunks(video)
@MULTIMODAL_REGISTRY.register_processor(
    KimiK25MultiModalProcessor,
    info=KimiK25ProcessingInfo,
    dummy_inputs=KimiK25DummyInputsBuilder,
)
class KimiK25ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """Kimi-K2.5 model for conditional generation.

    Supports both image and video-chunk modalities.
    Video-chunks are temporal segments (typically 4 frames) that are
    processed with temporal pooling.
    """

    supports_encoder_tp_data = True

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder string for the i-th item of *modality*."""
        # Kimi-K2.5 uses video_chunk for all media types
        if modality == "image":
            return "<|media_begin|>image<|media_content|><|media_pad|><|media_end|>"
        elif modality == "video":
            # return a placeholder, to be replaced in the future.
            return "<|kimi_k25_video_placeholder|>"
        raise ValueError(f"Unsupported modality: {modality}")

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        super().__init__()
        model_config = vllm_config.model_config
        config: KimiK25Config = model_config.hf_config
        self.config = config
        quant_config = vllm_config.quant_config
        # Check for MoonViT config compatibility
        self.use_data_parallel = (
            model_config.multimodal_config.mm_encoder_tp_mode == "data"
        )
        self.hidden_size = config.text_config.hidden_size
        self.device = torch.cuda.current_device()
        # Build vision tower directly with KimiK25VisionConfig
        self.vision_tower = MoonViT3dPretrainedModel(
            config.vision_config,
            prefix=maybe_prefix(prefix, "vision_tower"),
        )
        self.vision_tower = self.vision_tower.to(
            device=self.device, dtype=model_config.dtype
        )
        self.mm_projector = KimiK25MultiModalProjector(
            config=config.vision_config,
            use_data_parallel=self.use_data_parallel,
            prefix=maybe_prefix(prefix, "mm_projector"),
        )
        self.mm_projector = self.mm_projector.to(
            device=self.device, dtype=model_config.dtype
        )
        self.quant_config = quant_config
        # The language model is a DeepseekV2Model driven by the nested
        # text_config; deepcopy keeps the hf_config mutation local.
        sub_vllm_config = copy.deepcopy(vllm_config)
        sub_vllm_config.model_config.hf_config = (
            sub_vllm_config.model_config.hf_config.text_config
        )
        self.language_model = DeepseekV2Model(
            vllm_config=sub_vllm_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )
        # LM head only exists on the last pipeline-parallel rank.
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.text_config.hidden_size,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        else:
            self.lm_head = PPMissingLayer()
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )
        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale)
        # Token id marking where media embeddings are placed in the sequence.
        self.media_placeholder: int = self.config.media_placeholder_token_id

    def _parse_and_validate_media_input(
        self, **kwargs: object
    ) -> KimiK25MediaPixelInputs | None:
        """Normalize raw multimodal kwargs into a validated input schema.

        Returns None when no pixel data was provided.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        grid_thws = kwargs.pop("grid_thws", None)
        if pixel_values is None:
            return None
        if isinstance(pixel_values, list):
            pixel_values = torch.cat(pixel_values, dim=0)
        if len(pixel_values.shape) == 5 or len(pixel_values.shape) == 3:
            # Fold an extra leading dim into the patch dim.
            pixel_values = pixel_values.reshape(
                pixel_values.shape[0] * pixel_values.shape[1], *pixel_values.shape[2:]
            )
        # The batch dimension of pixel_values has been flattened into shape[0]
        target_dtype = next(self.vision_tower.parameters()).dtype
        pixel_values = pixel_values.to(target_dtype)
        assert isinstance(grid_thws, torch.Tensor), (
            f"expect grid_thws to be a tensor, get {type(grid_thws)}"
        )
        # In some cases (e.g. with merger), grid_thws has an extra middle dimension
        grid_thws = grid_thws.reshape(-1, grid_thws.shape[-1])
        assert grid_thws.ndim == 2 and grid_thws.size(1) == 3, (
            f"unexpected shape for grid_thws: {grid_thws.shape}"
        )
        return KimiK25MediaPixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            grid_thws=grid_thws,
        )

    def _process_media_input(
        self, media_input: KimiK25MediaPixelInputs
    ) -> list[torch.Tensor]:
        """Run validated pixel inputs through the vision tower and projector."""
        # NOTE(moyan): This forward will automatically batch the forward pass internally
        media_features = vision_tower_forward(
            self.vision_tower,
            media_input["pixel_values"],
            media_input["grid_thws"],
            mm_projector=self.mm_projector,
            use_data_parallel=self.use_data_parallel,
        )
        return media_features

    def embed_multimodal(self, **kwargs: object) -> NestedTensors | None:
        """Encode multimodal kwargs into embeddings; None when absent."""
        # Validate the multimodal input keyword arguments
        media_input = self._parse_and_validate_media_input(**kwargs)
        if media_input is None:
            return None
        # Run multimodal inputs through encoder and projector
        vision_embeddings = self._process_media_input(media_input)
        return vision_embeddings

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> IntermediateTensors:
        # On non-first PP ranks hidden states arrive via intermediate_tensors,
        # so any provided input embeddings are discarded.
        if intermediate_tensors is not None:
            inputs_embeds = None
        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states, **kwargs)
        return logits

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        config = self.config.text_config
        # Dense checkpoints declare no routed experts and need no mapping.
        if not getattr(config, "n_routed_experts", None):
            return []
        return SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=config.n_routed_experts,
            num_redundant_experts=0,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load checkpoint weights into this model.

        Handles checkpoint->vLLM name remapping, gate/up and qkv_a fusion,
        per-expert MoE parameters, and skips spec-decode / rope-cache weights.
        """
        config = self.config.text_config
        _KEYS_TO_MODIFY_MAPPING = {
            "language_model.lm_head": "lm_head",
            "language_model.model": "language_model",
            # mm_projector -> mm_projector mapping
            # "mm_projector": "mm_projector",
            "mm_projector.proj.0": "mm_projector.linear_1",
            "mm_projector.proj.2": "mm_projector.linear_2",
        }
        # (fused param suffix, checkpoint shard suffix, shard index)
        stacked_params_mapping = [
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        if getattr(config, "kv_lora_rank", None) and getattr(
            config, "q_lora_rank", None
        ):
            # MLA: q_a and kv_a projections are fused into one parameter.
            stacked_params_mapping += [
                (".fused_qkv_a_proj", ".q_a_proj", 0),
                (".fused_qkv_a_proj", ".kv_a_proj_with_mqa", 1),
            ]
        expert_params_mapping = self.get_expert_mapping()
        params_dict = dict(self.named_parameters())
        for args in weights:
            # Weight tuples may carry extra loader kwargs as a third element.
            name, loaded_weight = args[:2]
            kwargs = args[2] if len(args) > 2 else {}
            if "rotary_emb.inv_freq" in name:
                continue
            spec_layer = get_spec_layer_idx_from_weight_name(config, name)
            if spec_layer is not None:
                continue  # skip spec decode layers for main model
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                continue
            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            use_default_weight_loading = False
            if "vision" in name:
                # Vision-tower weights load verbatim (no fusion/shard remap).
                if self.vision_tower is not None:
                    use_default_weight_loading = True
            else:
                for param_name, weight_name, shard_id in stacked_params_mapping:
                    if weight_name not in name:
                        continue
                    # Expert weights are handled by the expert mapping below.
                    if ("mlp.experts." in name) and name not in params_dict:
                        continue
                    name = name.replace(weight_name, param_name)
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param, loaded_weight, shard_id, **kwargs)
                    break
                else:
                    # Not a stacked param: try the per-expert MoE mapping.
                    for _, (
                        param_name,
                        weight_name,
                        expert_id,
                        shard_id,
                    ) in enumerate(expert_params_mapping):
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)
                        if is_pp_missing_parameter(name, self):
                            continue
                        param = params_dict[name]
                        weight_loader = param.weight_loader
                        weight_loader(
                            param,
                            loaded_weight,
                            name,
                            expert_id=expert_id,
                            shard_id=shard_id,
                            **kwargs,
                        )
                        break
                    else:
                        use_default_weight_loading = True
            if use_default_weight_loading:
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remap kv-cache scale names when quantization provides them.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight, **kwargs)
def get_spec_layer_idx_from_weight_name(
    config: KimiK25Config, weight_name: str
) -> int | None:
    """Map *weight_name* to its speculative (next-n prediction) layer index.

    Next-n layers are numbered directly after the regular decoder layers,
    i.e. starting at ``config.num_hidden_layers``. Returns None when the
    config declares no next-n layers or the weight belongs to the main model.
    """
    num_spec_layers = getattr(config, "num_nextn_predict_layers", 0)
    if not num_spec_layers > 0:
        return None
    base = config.num_hidden_layers
    for offset in range(num_spec_layers):
        # Substring match: the name might start with language_model.model.layers
        if f"model.layers.{base + offset}." in weight_name:
            return base + offset
    return None

View File

@@ -0,0 +1,678 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Vision tower implementation for Kimi-K2.5 model.
This module provides the vision encoder components for Kimi-K2.5,
including 3D patch embedding, RoPE position embedding, and
temporal pooling for video chunks.
"""
from collections.abc import Sequence
from copy import deepcopy
from typing import Any
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.activations import GELUActivation
from vllm.distributed import divide, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.models.vision import (
is_vit_use_data_parallel,
run_dp_sharded_mrope_vision_model,
)
from vllm.transformers_utils.configs.kimi_k25 import KimiK25VisionConfig
logger = init_logger(__name__)
def _apply_rope_input_validation(x, freqs_cis):
    """Assert that *x* and *freqs_cis* are shape/dtype compatible for RoPE.

    Requirements (each failure reports the offending shapes/dtype):
      - x has exactly one more dim than freqs_cis (the heads dim),
      - their leading dims match,
      - x's last dim is twice freqs_cis's last dim (real pairs vs. complex),
      - freqs_cis is complex64.
    """
    shapes = (x.shape, freqs_cis.shape)
    assert x.ndim == freqs_cis.ndim + 1, shapes
    assert x.shape[:-2] == freqs_cis.shape[:-1], shapes
    assert x.shape[-1] == 2 * freqs_cis.shape[-1], shapes
    assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype
def get_rope_shape_decorate(func):
    """Decorator that warms *func* up once per (grad, mode) configuration.

    Before the first real call for a given key, the wrapped function is
    invoked with a fixed (64, 64) shape and the result discarded — presumably
    to trigger the torch.compile trace ahead of time (NOTE(review): confirm
    the intent of the warm-up shape).
    """
    # Keys already warmed up; closure state shared across all calls.
    _get_rope_shape_first_call_flag = set()

    def wrapper(org, interpolation_mode, shape):
        # Key captures everything that changes compilation/autograd behavior.
        key = (org.requires_grad, torch.is_grad_enabled(), interpolation_mode)
        if key not in _get_rope_shape_first_call_flag:
            _get_rope_shape_first_call_flag.add(key)
            # Warm-up call; result is discarded.
            _ = func(org, interpolation_mode, shape=(64, 64))
        return func(org, interpolation_mode, shape)

    return wrapper


@get_rope_shape_decorate
@torch.compile(dynamic=True)
def get_rope_shape(org, interpolation_mode, shape):
    """Interpolate a (H, W, C) embedding table to *shape*.

    Returns a (shape[0] * shape[1], C) tensor: the table is moved to NCHW
    layout for F.interpolate, then restored and flattened over the spatial
    dims.
    """
    return (
        F.interpolate(
            org.permute((2, 0, 1)).unsqueeze(0),
            size=shape,
            mode=interpolation_mode,
        )
        .squeeze(0)
        .permute((1, 2, 0))
        .flatten(end_dim=1)
    )
def apply_rope(
    xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position embedding to query/key tensors.

    Args: (The leading dimensions of all inputs should be the same)
        xq: query, tensor of shape (..., num_heads, head_dim)
        xk: key, tensor of shape (..., num_kv_heads, head_dim)
        freqs_cis: tensor of shape (..., head_dim/2), dtype=torch.complex64.

    Returns:
        xq_out, xk_out: rotated tensors with the same shapes and dtypes as
        the inputs.
    """
    _apply_rope_input_validation(xq, freqs_cis)
    _apply_rope_input_validation(xk, freqs_cis)

    freqs_cis = freqs_cis.unsqueeze(-2)  # ..., 1, head_dim/2
    # View consecutive (real, imag) pairs as complex: ..., heads, head_dim/2
    xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2))
    # BUGFIX: reshape xk with its OWN shape (previously xq.shape, which breaks
    # whenever the query and key head counts differ, e.g. under GQA; identical
    # when they match since validation pins all other dims to freqs_cis).
    xk_ = torch.view_as_complex(xk.float().view(*xk.shape[:-1], -1, 2))
    # Complex multiplication rotates each pair by the per-position angle.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)  # ..., num_heads, head_dim
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)  # ..., num_heads, head_dim
    return xq_out.type_as(xq), xk_out.type_as(xk)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Generate a 1D sin/cos positional embedding table.

    Args:
        embed_dim: Embedding dimension; must be even (half sin, half cos).
        pos: Array of positions, flattened to shape (M,).

    Returns:
        np.ndarray of shape (M, embed_dim): [sin | cos] features per position.
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # Geometric frequency ladder: 1 / 10000^(i / (embed_dim/2)), i in [0, half).
    freqs = 1.0 / 10000 ** (np.arange(half, dtype=np.float32) / (embed_dim / 2.0))
    # Outer product of positions and frequencies gives the phase matrix (M, half).
    angles = np.outer(pos.reshape(-1), freqs)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
def get_1d_sincos_pos_embed(embed_dim, t_size, cls_token=False):
    """Build a sin/cos embedding table for positions 0..t_size-1.

    Args:
        embed_dim: Embedding dimension.
        t_size: Number of (temporal) positions.
        cls_token: When True, prepend a zero row for a class token.

    Returns:
        np.ndarray of shape (t_size [+ 1], embed_dim).
    """
    positions = np.arange(t_size, dtype=np.float32)
    table = get_1d_sincos_pos_embed_from_grid(embed_dim, positions)
    if not cls_token:
        return table
    return np.concatenate([np.zeros([1, embed_dim]), table], axis=0)
class Learnable2DInterpPosEmbDivided_fixed(nn.Module):
    """2D learnable position embedding with temporal extension.

    The spatial part is a learnable (height, width, dim) table that is
    interpolated to each item's grid on demand; the temporal part is a fixed
    (non-learnable) 1D sincos table added per frame.
    """

    def __init__(
        self,
        height: int,
        width: int,
        num_frames: int,
        dim: int,
        interpolation_mode: str = "bicubic",
    ) -> None:
        super().__init__()
        self.height = height
        self.width = width
        self.num_frames = num_frames
        self.dim = dim
        self.interpolation_mode = interpolation_mode
        # Learnable spatial table; interpolated to other grid sizes in forward.
        self.weight = nn.Parameter(torch.empty(height, width, dim))
        # Fixed sincos temporal table, shape (num_frames, 1, dim); excluded
        # from checkpoints (persistent=False) since it is deterministic.
        self.register_buffer(
            "time_weight",
            torch.from_numpy(get_1d_sincos_pos_embed(self.dim, self.num_frames))
            .float()
            .unsqueeze(1),
            persistent=False,
        )
        self.reset_parameters()

    def reset_parameters(self):
        # Standard-normal init for the spatial table.
        nn.init.normal_(self.weight)

    def forward(self, x: torch.Tensor, grid_thws: torch.Tensor) -> torch.Tensor:
        """Add positional embeddings to *x*.

        Args:
            x: Patch embeddings concatenated over all items, matching the
                order and total length implied by grid_thws.
            grid_thws: Per-item (t, h, w) grids.
        """
        pos_embs = []
        for t, h, w in grid_thws.tolist():
            assert t <= self.num_frames, f"t:{t} > self.num_frames:{self.num_frames}"
            if (h, w) == self.weight.shape[:-1]:
                # Native resolution: use the table directly, as (h*w, dim).
                pos_emb_2d = self.weight.flatten(end_dim=1)
            else:
                # Interpolate the table to this item's (h, w) grid.
                pos_emb_2d = get_rope_shape(
                    self.weight,
                    interpolation_mode=self.interpolation_mode,
                    shape=(h, w),
                )
            if t == 1:
                pos_emb_3d = pos_emb_2d
            else:
                # Repeat the spatial embedding per frame, add the temporal term.
                pos_emb_3d = (
                    pos_emb_2d.unsqueeze(0).repeat(t, 1, 1) + self.time_weight[0:t]
                )
            pos_embs.append(pos_emb_3d.reshape(-1, pos_emb_3d.shape[-1]))
        out = x + torch.cat(pos_embs)
        return out
class MoonVision3dPatchEmbed(nn.Module):
    """3D patch embedding for vision tower.

    A 2D convolution projects each patch to *out_dim*, then a positional
    embedding module adds spatial (and temporal) position information.
    """

    def __init__(
        self,
        out_dim: int,
        in_dim: int = 3,
        patch_size: int | tuple[int, int] = (14, 14),
        pos_emb_height: int = 14,
        pos_emb_width: int = 14,
        pos_emb_time: int = 4,
        pos_emb_type: str = "divided_fixed",
    ):
        super().__init__()
        assert isinstance(patch_size, int | Sequence), (
            f"Invalid patch_size type: {type(patch_size)}"
        )
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        assert len(patch_size) == 2, (
            f"Expected patch_size to be a tuple of 2, got {patch_size}"
        )
        self.patch_size = patch_size
        # Non-overlapping patches: kernel == stride == patch_size.
        self.proj = nn.Conv2d(
            in_dim, out_dim, kernel_size=patch_size, stride=patch_size
        )
        if pos_emb_type == "divided_fixed":
            self.pos_emb = Learnable2DInterpPosEmbDivided_fixed(
                height=pos_emb_height,
                width=pos_emb_width,
                num_frames=pos_emb_time,
                dim=out_dim,
            )
        else:
            raise NotImplementedError(f"Not support pos_emb_type: {pos_emb_type}")

    def forward(self, x: torch.Tensor, grid_thws: torch.Tensor) -> torch.Tensor:
        # NOTE(review): the flatten-to-(x.size(0), -1) implies x is pre-cut
        # patches of shape (num_patches, in_dim, ps, ps) so the conv output is
        # (num_patches, out_dim, 1, 1) — confirm with the caller.
        x = self.proj(x).view(x.size(0), -1)
        # apply positional embedding
        x = self.pos_emb(x, grid_thws)
        return x
class Rope2DPosEmbRepeated(nn.Module):
    """2D rotary position embedding with multi-resolution support.

    Precomputes complex rotation factors (cis) for every (y, x) position of a
    max_height x max_width grid; get_freqs_cis slices out each item's window
    and repeats it across that item's frames.
    """

    def __init__(self, dim: int, max_height: int, max_width: int, theta_base=10000):
        super().__init__()
        self.dim = dim
        # Channels split between x- and y-rotation; each pair needs 2 dims.
        assert self.dim % 4 == 0, "dim must be divisible by 4"
        self.max_height = max_height
        self.max_width = max_width
        self.theta_base = theta_base

    def extra_repr(self):
        return (
            f"dim={self.dim}, max_height={self.max_height}, "
            f"max_width={self.max_width}, theta_base={self.theta_base}"
        )

    def _precompute_freqs_cis(self, device: torch.device) -> torch.Tensor:
        """Calculate the cis(freqs) for each position in the 2D grid."""
        N = self.max_height * self.max_width
        flat_pos = torch.arange(0, N).float().to(device)
        # Row-major grid: x is the column index, y the row index.
        x_pos = flat_pos % self.max_width
        y_pos = flat_pos // self.max_width
        dim_range = (
            torch.arange(0, self.dim, 4)[: (self.dim // 4)].float().to(device)
        )  # C/4
        freqs = 1.0 / (self.theta_base ** (dim_range / self.dim))
        x_freqs = torch.outer(x_pos, freqs).float()  # N, C/4
        y_freqs = torch.outer(y_pos, freqs).float()  # N, C/4
        # Unit-magnitude complex numbers e^{i*theta}, one per axis.
        x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs)  # N, C/4
        y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs)  # N, C/4
        # Interleave the x/y factors: N, C/4, 2
        freqs_cis = torch.cat(
            [x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1
        )
        # max_height, max_width, C/2
        freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1)
        return freqs_cis

    def get_freqs_cis(
        self, grid_thws: torch.Tensor, device: torch.device
    ) -> torch.Tensor:
        """
        Args:
            grid_thws (torch.Tensor): grid time, height and width

        Returns:
            freqs_cis: tensor of shape (sum(t * height * width), dim//2)
        """
        # Lazily build and cache the full-grid table on first use; kept out of
        # the state dict (persistent=False).
        if not hasattr(self, "freqs_cis"):
            self.register_buffer(
                "freqs_cis", self._precompute_freqs_cis(device), persistent=False
            )
        shapes = grid_thws.tolist()
        assert all(
            1 <= h <= self.max_height and 1 <= w <= self.max_width for t, h, w in shapes
        ), (
            shapes,
            self.max_height,
            self.max_width,
        )
        # Slice each item's (h, w) window and repeat it t times — every frame
        # of an item shares the same spatial rotation factors.
        freqs_cis = torch.cat(
            [
                self.freqs_cis[:h, :w].reshape(-1, self.dim // 2).repeat(t, 1)
                for t, h, w in shapes
            ],
            dim=0,
        )
        return freqs_cis
class MLP2(nn.Module):
    """Two-layer MLP with tensor parallel support.

    Column-parallel up-projection, activation, row-parallel down-projection.
    Tensor parallelism is disabled when the vision tower runs data-parallel.
    """

    def __init__(
        self,
        dims: list[int],
        activation,
        bias: bool = True,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
        super().__init__()
        assert len(dims) == 3
        in_dim, mid_dim, out_dim = dims
        self.use_data_parallel = use_data_parallel
        self.fc0 = ColumnParallelLinear(
            in_dim,
            mid_dim,
            bias=bias,
            prefix=maybe_prefix(prefix, "fc0"),
            disable_tp=self.use_data_parallel,
        )
        self.fc1 = RowParallelLinear(
            mid_dim,
            out_dim,
            bias=bias,
            prefix=maybe_prefix(prefix, "fc1"),
            disable_tp=self.use_data_parallel,
        )
        self.activation = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """fc0 -> activation -> fc1 (bias outputs of the linears are dropped)."""
        hidden, _ = self.fc0(x)
        out, _ = self.fc1(self.activation(hidden))
        return out
class MoonViTEncoderLayer(nn.Module):
    """Single encoder layer for MoonViT with TP/DP support.

    Pre-norm transformer layer: LayerNorm -> packed-QKV self-attention ->
    residual add, then LayerNorm -> MLP -> residual add. The attention and
    MLP projections are tensor-parallel unless the vision tower runs in
    data-parallel mode, in which case TP sharding is disabled
    (``disable_tp=True`` on every linear layer).
    """
    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        prefix: str = "",
        *,
        activation=F.gelu,
        attn_bias: bool = False,
    ):
        super().__init__()
        self.use_data_parallel = is_vit_use_data_parallel()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
        # Under DP the layer is fully replicated, so effective TP size is 1.
        self.tp_size = (
            1 if self.use_data_parallel else get_tensor_model_parallel_world_size()
        )
        # Number of attention heads handled by this TP rank.
        self.num_attention_heads_per_partition = divide(num_heads, self.tp_size)
        self.norm0 = nn.LayerNorm(hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.mlp = MLP2(
            [hidden_dim, mlp_dim, hidden_dim],
            activation,
            prefix=f"{prefix}.mlp",
            use_data_parallel=self.use_data_parallel,
        )
        # Fused QKV projection; K/V use the same head count as Q (no GQA).
        self.wqkv = QKVParallelLinear(
            hidden_size=hidden_dim,
            head_size=self.hidden_size_per_attention_head,
            total_num_heads=num_heads,
            total_num_kv_heads=num_heads,
            bias=attn_bias,
            prefix=f"{prefix}.wqkv",
            disable_tp=self.use_data_parallel,
        )
        self.wo = RowParallelLinear(
            hidden_dim,
            hidden_dim,
            bias=attn_bias,
            prefix=f"{prefix}.wo",
            disable_tp=self.use_data_parallel,
        )
        self.attn = MMEncoderAttention(
            num_heads=self.num_attention_heads_per_partition,
            head_size=self.hidden_size_per_attention_head,
            scale=self.hidden_size_per_attention_head**-0.5,
            prefix=f"{prefix}.attn",
        )
    def attention_qkvpacked(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rope_freqs_cis: torch.Tensor | None = None,
    ):
        """Compute self-attention with packed QKV.
        Args:
            x (torch.Tensor): (seqlen, hidden_dim)
            cu_seqlens (torch.Tensor): cumulative sequence lengths
            rope_freqs_cis: rotary factors applied to Q/K before attention
                (behavior when None depends on apply_rope — TODO confirm)
        Returns:
            torch.Tensor: (seqlen, hidden_dim) attention output
        """
        seq_length = x.size(0)
        xqkv, _ = self.wqkv(x)
        qkv_shape = xqkv.size()[:-1] + (
            3,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        # xqkv: (seqlen, 3, nheads, headdim)
        xqkv = xqkv.view(*qkv_shape)
        xq, xk, xv = torch.unbind(xqkv, dim=-3)
        xq, xk = apply_rope(xq, xk, rope_freqs_cis)
        # Longest per-sample sequence in the packed (varlen) batch.
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        attn_out = self.attn(
            xq.unsqueeze(0),
            xk.unsqueeze(0),
            xv.unsqueeze(0),
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        # Flatten heads back into this rank's hidden slice for the out-proj.
        attn_out = attn_out.reshape(
            seq_length,
            self.num_attention_heads_per_partition
            * self.hidden_size_per_attention_head,
        )
        attn_out, _ = self.wo(attn_out)
        return attn_out
    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rope_freqs_cis: torch.Tensor | None = None,
    ):
        """Apply the pre-norm attention block, then the pre-norm MLP block."""
        residual = hidden_states
        hidden_states = self.norm0(hidden_states)
        hidden_states = self.attention_qkvpacked(
            hidden_states, cu_seqlens, rope_freqs_cis
        )
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
class MoonViT3dEncoder(nn.Module):
    """Full encoder stack for MoonViT 3D.

    Computes 2D rotary factors once per batch, derives packed-sequence
    boundaries from ``grid_thws``, and runs the token stream through the
    encoder layers followed by a final LayerNorm.
    """

    def __init__(
        self,
        hidden_dim: int,
        num_layers: int,
        block_cfg: dict,
        video_attn_type: str = "spatial_temporal",
        prefix: str = "",
    ) -> None:
        super().__init__()
        assert video_attn_type == "spatial_temporal", (
            f'video_attn_type must be "spatial_temporal", got {video_attn_type}'
        )
        self.video_attn_type = video_attn_type
        head_dim = block_cfg["hidden_dim"] // block_cfg["num_heads"]
        self.rope_2d = Rope2DPosEmbRepeated(head_dim, 512, 512)
        layers = [
            MoonViTEncoderLayer(
                **block_cfg,
                prefix=f"{prefix}.blocks.{layer_idx}",
            )
            for layer_idx in range(num_layers)
        ]
        self.blocks = nn.ModuleList(layers)
        self.final_layernorm = nn.LayerNorm(hidden_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        grid_thws: torch.Tensor,
    ) -> torch.Tensor:
        rope_freqs_cis = self.rope_2d.get_freqs_cis(
            grid_thws=grid_thws, device=hidden_states.device
        )
        # Tokens per sample = t * h * w; prepend 0 so the cumsum yields
        # cu_seqlens boundaries for varlen attention.
        zero = torch.zeros(1, dtype=grid_thws.dtype, device=grid_thws.device)
        counts = grid_thws[:, 0] * grid_thws[:, 1] * grid_thws[:, 2]
        cu_seqlens = (
            torch.cat((zero, counts))
            .to(hidden_states.device)
            .cumsum(dim=0, dtype=torch.int32)
        )
        for layer in self.blocks:
            hidden_states = layer(
                hidden_states, cu_seqlens, rope_freqs_cis=rope_freqs_cis
            )
        return self.final_layernorm(hidden_states)
def tpool_patch_merger(
    x: torch.Tensor,
    grid_thws: torch.Tensor,
    merge_kernel_size: tuple[int, int] = (2, 2),
) -> list[torch.Tensor]:
    """Temporal pooling patch merger.

    Splits ``x`` into per-sample token runs, mean-pools each run over its
    temporal dimension, and regroups the spatial patches into flattened
    ``kh * kw`` merge windows.

    Returns one (nh*nw, kh*kw, d) tensor per sample.
    """
    kh, kw = merge_kernel_size
    token_counts = [t * h * w for t, h, w in grid_thws.tolist()]
    merged = []
    for seq, (t, h, w) in zip(x.split(token_counts, dim=0), grid_thws.tolist()):
        nh, nw = h // kh, w // kw
        # (t*h*w, d) -> (t, nh, kh, nw, kw, d)
        grid = seq.view(t, nh, kh, nw, kw, -1)
        # Average over time first: shrinks the tensor before the permute.
        pooled = grid.mean(dim=0)  # (nh, kh, nw, kw, d)
        # (nh, kh, nw, kw, d) -> (nh, nw, kh, kw, d) -> (nh*nw, kh*kw, d)
        merged.append(pooled.permute(0, 2, 1, 3, 4).reshape(nh * nw, kh * kw, -1))
    return merged
class MoonViT3dPretrainedModel(nn.Module):
    """Main vision tower model.

    Uses KimiK25VisionConfig directly from transformers_utils/configs/kimi_k25.py.
    Pipeline: 3D patch embedding -> transformer encoder -> patch merging.
    """

    def __init__(
        self,
        config: KimiK25VisionConfig,
        prefix: str = "",
    ):
        super().__init__()
        # Own a copy so later mutations of the shared config don't leak in.
        config = deepcopy(config)
        self.config = config  # Required for run_dp_sharded_mrope_vision_model
        self.merge_kernel_size = config.merge_kernel_size
        self.patch_size = config.patch_size
        self.merge_type = config.merge_type
        self.patch_embed = MoonVision3dPatchEmbed(
            out_dim=config.hidden_size,
            patch_size=config.patch_size,
            pos_emb_height=config.init_pos_emb_height,
            pos_emb_width=config.init_pos_emb_width,
            pos_emb_time=config.init_pos_emb_time,
            pos_emb_type=config.pos_emb_type,
        )
        encoder_block_cfg = {
            "num_heads": config.num_attention_heads,
            "hidden_dim": config.hidden_size,
            "mlp_dim": config.intermediate_size,
            "activation": get_act_fn("gelu_pytorch_tanh"),
            "attn_bias": True,
        }
        self.encoder = MoonViT3dEncoder(
            hidden_dim=config.hidden_size,
            num_layers=config.num_hidden_layers,
            block_cfg=encoder_block_cfg,
            video_attn_type=config.video_attn_type,
            prefix=maybe_prefix(prefix, "encoder"),
        )

    def forward(
        self, pixel_values: torch.Tensor, grid_thws: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
            pixel_values (torch.Tensor): The input pixel values.
            grid_thws (torch.Tensor): Temporal, height and width.
        Returns:
            torch.Tensor: The output tokens.
        """
        hidden_states = self.patch_embed(pixel_values, grid_thws)
        hidden_states = self.encoder(hidden_states, grid_thws)
        if self.merge_type != "sd2_tpool":
            raise NotImplementedError(f"Not support {self.merge_type}")
        # "sd2_tpool": spatial downsampling 2x with temporal pooling all
        return tpool_patch_merger(
            hidden_states, grid_thws, merge_kernel_size=self.merge_kernel_size
        )
@torch.inference_mode()
def mm_projector_forward(mm_projector: torch.nn.Module, vt_output: list[torch.Tensor]):
    """Apply MM projector to vision tower outputs.

    Concatenates the per-item tensors into one batch, runs the projector
    once, flattens to (tokens, dim), and splits back by the original
    per-item row counts.
    """
    split_sizes = [chunk.shape[0] for chunk in vt_output]
    projected = mm_projector(torch.cat(vt_output, dim=0))
    flat = projected.reshape(-1, projected.shape[-1])
    return torch.split(flat, split_sizes)
@torch.inference_mode()
def vision_tower_forward(
    vision_tower: Any,
    pixel_values: torch.Tensor,
    grid_thw: torch.Tensor,
    mm_projector: Any,
    use_data_parallel: bool,
) -> list[torch.Tensor]:
    """DP-sharded vision tower forward with mrope.
    Uses vLLM's standard data parallelism utility to shard the batch
    across available GPUs, enabling parallel processing of vision features.
    """
    if use_data_parallel:
        # Shard the vision batch across DP ranks with the mrope-aware helper.
        vt_outputs = run_dp_sharded_mrope_vision_model(
            vision_model=vision_tower,
            pixel_values=pixel_values,
            grid_thw_list=grid_thw.tolist(),
            rope_type="rope_2d",
        )
    else:
        # Single-rank path: run the tower directly.
        vt_outputs = vision_tower(pixel_values, grid_thw)
    return list(mm_projector_forward(mm_projector, list(vt_outputs)))
class KimiK25MultiModalProjector(nn.Module):
    """Multi-modal projector with patch merging for Kimi-K2.5.

    Each ``merge_h * merge_w`` patch window is flattened into a single
    token of ``hidden_size * merge_h * merge_w`` features, then mapped to
    the language-model embedding width by a two-layer MLP.
    """

    def __init__(
        self,
        config: KimiK25VisionConfig,
        use_data_parallel: bool = False,
        prefix: str = "",
    ):
        super().__init__()
        self.use_data_parallel = use_data_parallel
        # Hidden size after patch merging: one token per merge window.
        merge_h, merge_w = config.merge_kernel_size
        self.hidden_size = config.hidden_size * merge_h * merge_w
        # Norm is applied per patch (pre-merge width), not per merged token.
        self.pre_norm = torch.nn.LayerNorm(config.hidden_size, eps=1e-5)
        self.linear_1 = ReplicatedLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            prefix=maybe_prefix(prefix, "linear_1"),
        )
        self.linear_2 = ReplicatedLinear(
            self.hidden_size,
            config.mm_hidden_size,
            bias=True,
            prefix=maybe_prefix(prefix, "linear_2"),
        )
        self.act = GELUActivation()

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        """Normalize, flatten merge windows, and project to mm_hidden_size."""
        normed = self.pre_norm(image_features)
        flat = normed.view(-1, self.hidden_size)
        projected, _ = self.linear_1(flat)
        projected = self.act(projected)
        projected, _ = self.linear_2(projected)
        return projected

View File

@@ -359,6 +359,7 @@ _MULTIMODAL_MODELS = {
),
"RForConditionalGeneration": ("rvl", "RForConditionalGeneration"),
"KimiVLForConditionalGeneration": ("kimi_vl", "KimiVLForConditionalGeneration"), # noqa: E501
"KimiK25ForConditionalGeneration": ("kimi_k25", "KimiK25ForConditionalGeneration"), # noqa: E501
"LightOnOCRForConditionalGeneration": (
"lightonocr",
"LightOnOCRForConditionalGeneration",

View File

@@ -20,6 +20,7 @@ from typing import (
)
import numpy as np
from PIL.Image import Image
from typing_extensions import NotRequired, TypeVar
from vllm.utils.collection_utils import full_groupby, is_list_of
@@ -29,7 +30,6 @@ from vllm.utils.jsontree import json_map_leaves
if TYPE_CHECKING:
import torch
import torch.types
from PIL.Image import Image
from transformers.feature_extraction_utils import BatchFeature
from .media import MediaWithBytes
@@ -105,6 +105,28 @@ The number of data items allowed per modality is restricted by
"""
class VisionChunkImage(TypedDict):
"""Represents an image wrapped as a vision chunk."""
type: Literal["image"]
image: Image
uuid: str | None
class VisionChunkVideo(TypedDict):
"""Represents a video chunk with metadata."""
type: Literal["video_chunk"]
video_chunk: list[Image]
uuid: str | None
prompt: str
video_idx: int
VisionChunk = VisionChunkImage | VisionChunkVideo
"""A vision chunk is either an image or a video chunk."""
@final
class MultiModalDataBuiltins(TypedDict, total=False):
"""Type annotations for modality types predefined by vLLM."""
@@ -118,6 +140,9 @@ class MultiModalDataBuiltins(TypedDict, total=False):
audio: ModalityData[AudioItem]
"""The input audio(s)."""
vision_chunk: ModalityData[VisionChunk]
"""The input visual atom(s) - unified modality for images and video chunks."""
MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
"""

View File

@@ -384,6 +384,13 @@ class VideoEmbeddingItems(EmbeddingItems):
super().__init__(data, "video", expected_hidden_size)
class VisionChunkProcessorItems(ProcessorBatchItems[Any]):
"""Processor items for vision chunks (unified image and video chunks)."""
def __init__(self, data: Sequence[Any]) -> None:
super().__init__(data, "vision_chunk")
_D = TypeVar("_D", bound=ModalityDataItems[Any, Any])
@@ -652,11 +659,23 @@ class MultiModalDataParser:
return VideoProcessorItems(new_videos, metadata=metadata_lst)
def _parse_vision_chunk_data(
self,
data: ModalityData[Any],
) -> ModalityDataItems[Any, Any] | None:
"""Parse vision chunk data (unified image and video chunks)."""
if data is None or self._is_empty(data):
return None
if self.is_embeddings(data):
raise ValueError("Do not support embedding data for vision_chunk right now")
return VisionChunkProcessorItems(data)
def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:
return {
"audio": self._parse_audio_data,
"image": self._parse_image_data,
"video": self._parse_video_data,
"vision_chunk": self._parse_vision_chunk_data,
}
def parse_mm_data(self, mm_data: MultiModalDataDict) -> MultiModalDataItems:

View File

@@ -235,6 +235,27 @@ class VideoLoader:
VIDEO_LOADER_REGISTRY = ExtensionManager()
@VIDEO_LOADER_REGISTRY.register("identity")
class IdentityVideoLoader(VideoLoader):
"""IdentityVideoLoader returns raw video bytes without decoding.
This allows the model processor to handle video decoding and
is required for models like Kimi-K2.5 that need custom video chunk splitting.
NOTE: This is temporary for Kimi-K2.5 testing. Remember to change back
to opencv before release if needed.
"""
@classmethod
def load_bytes(
cls,
data: bytes,
num_frames: int = -1,
**kwargs: Any,
) -> tuple[Any, Any]:
return data, None
@VIDEO_LOADER_REGISTRY.register("opencv")
class OpenCVVideoBackend(VideoLoader):
def get_cv2_video_api(self):

View File

@@ -53,8 +53,8 @@ _REASONING_PARSERS_TO_REGISTER = {
"HunyuanA13BReasoningParser",
),
"kimi_k2": (
"deepseek_r1_reasoning_parser",
"DeepSeekR1ReasoningParser",
"kimi_k2_reasoning_parser",
"KimiK2ReasoningParser",
),
"minimax_m2": (
"minimax_m2_reasoning_parser",

View File

@@ -0,0 +1,80 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser
from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
from .identity_reasoning_parser import IdentityReasoningParser
if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import (
ChatCompletionRequest,
)
else:
ChatCompletionRequest = Any
logger = init_logger(__name__)
class KimiK2ReasoningParser(ReasoningParser):
"""
Kimi K2 parser that delegates to either DeepSeekR1ReasoningParser or
IdentityReasoningParser based on `thinking` and `separate_reasoning`.
Unlike DeepSeekV3ReasoningParser which defaults to NOT thinking,
KimiK2ReasoningParser defaults to thinking mode (uses DeepSeekR1ReasoningParser).
"""
def __init__(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs)
chat_kwargs = kwargs.pop("chat_template_kwargs", {}) or {}
# Key difference: default to True instead of False
thinking = bool(chat_kwargs.pop("thinking", True))
if thinking:
self._parser = DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs)
else:
self._parser = IdentityReasoningParser(tokenizer, *args, **kwargs)
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
return self._parser.is_reasoning_end(input_ids)
def is_reasoning_end_streaming(
self, input_ids: list[int], delta_ids: list[int]
) -> bool:
return self._parser.is_reasoning_end_streaming(input_ids, delta_ids)
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
return self._parser.extract_content_ids(input_ids)
def extract_reasoning(
self, model_output: str, request: "ChatCompletionRequest"
) -> tuple[str | None, str | None]:
return self._parser.extract_reasoning(model_output, request)
def extract_reasoning_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
) -> DeltaMessage | None:
return self._parser.extract_reasoning_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
delta_token_ids,
)

View File

@@ -20,9 +20,11 @@ from vllm.entrypoints.chat_utils import (
ChatTemplateContentFormatOption,
ChatTemplateResolutionError,
ConversationMessage,
build_video_prompts_from_mm_data,
load_chat_template,
parse_chat_messages,
parse_chat_messages_async,
rebuild_mm_uuids_from_mm_data,
)
from vllm.inputs import TextPrompt, TokensPrompt
from vllm.logger import init_logger
@@ -547,6 +549,40 @@ class HfRenderer(RendererLike):
**kwargs,
)
# NOTE: use_unified_vision_chunk is currently specific to Kimi-K2.5
# model which uses unified vision chunks for both images and videos.
if (
getattr(model_config.hf_config, "use_unified_vision_chunk", False)
and mm_uuids is not None
and mm_data is not None
):
mm_uuids = rebuild_mm_uuids_from_mm_data(mm_uuids, mm_data)
# get video placeholder, replace it with runtime video-chunk prompts
video_placeholder = getattr(
model_config.hf_config, "video_placeholder", None
)
if video_placeholder and isinstance(prompt_raw, str):
video_prompts = build_video_prompts_from_mm_data(mm_data)
# replace in order
prompt_raw_parts = prompt_raw.split(video_placeholder)
if len(prompt_raw_parts) == len(video_prompts) + 1:
prompt_raw = "".join(
[
prompt_raw_parts[i] + video_prompts[i]
for i in range(len(video_prompts))
]
)
prompt_raw += prompt_raw_parts[-1]
else:
logger.warning(
"Number of video placeholders (%d) does not match "
"number of videos (%d) in the request.",
len(prompt_raw_parts) - 1,
len(video_prompts),
)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)
@@ -587,6 +623,40 @@ class HfRenderer(RendererLike):
**kwargs,
)
# NOTE: use_unified_vision_chunk is currently specific to Kimi-K2.5
# model which uses unified vision chunks for both images and videos.
if (
getattr(model_config.hf_config, "use_unified_vision_chunk", False)
and mm_uuids is not None
and mm_data is not None
):
mm_uuids = rebuild_mm_uuids_from_mm_data(mm_uuids, mm_data)
# get video placeholder, replace it with runtime video-chunk prompts
video_placeholder = getattr(
model_config.hf_config, "video_placeholder", None
)
if video_placeholder and isinstance(prompt_raw, str):
video_prompts = build_video_prompts_from_mm_data(mm_data)
# replace in order
prompt_raw_parts = prompt_raw.split(video_placeholder)
if len(prompt_raw_parts) == len(video_prompts) + 1:
prompt_raw = "".join(
[
prompt_raw_parts[i] + video_prompts[i]
for i in range(len(video_prompts))
]
)
prompt_raw += prompt_raw_parts[-1]
else:
logger.warning(
"Number of video placeholders (%d) does not match "
"number of videos (%d) in the request.",
len(prompt_raw_parts) - 1,
len(video_prompts),
)
prompt = (
TextPrompt(prompt=prompt_raw)
if isinstance(prompt_raw, str)

View File

@@ -81,6 +81,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
isaac="IsaacConfig",
kimi_linear="KimiLinearConfig",
kimi_vl="KimiVLConfig",
kimi_k25="KimiK25Config",
RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct)
RefinedWebModel="RWConfig", # For tiiuae/falcon-7b(-instruct)
jais="JAISConfig",

View File

@@ -38,6 +38,7 @@ _CLASS_TO_MODULE: dict[str, str] = {
"MoonViTConfig": "vllm.transformers_utils.configs.moonvit",
"KimiLinearConfig": "vllm.transformers_utils.configs.kimi_linear",
"KimiVLConfig": "vllm.transformers_utils.configs.kimi_vl",
"KimiK25Config": "vllm.transformers_utils.configs.kimi_k25",
"NemotronConfig": "vllm.transformers_utils.configs.nemotron",
"NemotronHConfig": "vllm.transformers_utils.configs.nemotron_h",
"Olmo3Config": "vllm.transformers_utils.configs.olmo3",
@@ -77,6 +78,7 @@ __all__ = [
"MoonViTConfig",
"KimiLinearConfig",
"KimiVLConfig",
"KimiK25Config",
"NemotronConfig",
"NemotronHConfig",
"Olmo3Config",

View File

@@ -0,0 +1,129 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Kimi-K2.5 Model Configuration.
This configuration supports video-chunk as an internal modality type.
A video-chunk is the smallest independently processable unit of video.
"""
from transformers import DeepseekV3Config
from transformers.configuration_utils import PretrainedConfig
class KimiK25VisionConfig(PretrainedConfig):
model_type = "kimi_k25_vision"
def __init__(
self,
# Vision Tower
patch_size: int = 14,
init_pos_emb_height: int = 64,
init_pos_emb_width: int = 64,
init_pos_emb_time: int = 4,
pos_emb_type: str = "divided_fixed",
num_attention_heads: int = 16,
num_hidden_layers: int = 27,
hidden_size: int = 1152,
intermediate_size: int = 4304,
merge_kernel_size: tuple[int, int] = (2, 2),
video_attn_type: str = "spatial_temporal",
merge_type: str = "sd2_tpool",
# MM Projector
mm_projector_type: str = "patchmerger",
mm_hidden_size: int | None = None,
projector_hidden_act: str = "gelu",
projector_ln_eps: float = 1e-5,
**kwargs,
):
super().__init__(**kwargs)
# Vision Tower
self.patch_size = patch_size
self.init_pos_emb_height = init_pos_emb_height
self.init_pos_emb_width = init_pos_emb_width
self.init_pos_emb_time = init_pos_emb_time
self.pos_emb_type = pos_emb_type
self.num_attention_heads = num_attention_heads
self.num_hidden_layers = num_hidden_layers
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.merge_kernel_size = merge_kernel_size
self.video_attn_type = video_attn_type
self.merge_type = merge_type
# MM Projector
self.mm_projector_type = mm_projector_type
if mm_hidden_size is not None:
self.mm_hidden_size = mm_hidden_size
else:
self.mm_hidden_size = hidden_size
self.projector_hidden_act = projector_hidden_act
self.projector_ln_eps = projector_ln_eps
class KimiK25Config(PretrainedConfig):
"""Kimi-K2.5 model configuration.
Kimi-K2.5 extends Kimi-K2 with vision support using video-chunks.
A video-chunk consists of multiple consecutive frames
that are processed together with temporal pooling.
Args:
vision_config: Configuration for the vision tower and projector.
text_config: Configuration for the text model (DeepseekV3).
ignore_index: The ignore index for the loss function.
media_placeholder_token_id: The token ID for media placeholders.
pad_token_id: The token ID for padding.
"""
model_type = "kimi_k25"
def __init__(
self,
vision_config: dict | KimiK25VisionConfig | None = None,
text_config: dict | DeepseekV3Config | None = None,
ignore_index: int = -100,
media_placeholder_token_id: int = 163605,
pad_token_id: int = 0,
use_unified_vision_chunk: bool = False,
video_placeholder: str = "<|kimi_k25_video_placeholder|>",
**kwargs,
):
# Vision config
if vision_config is None:
vision_config = KimiK25VisionConfig()
elif isinstance(vision_config, dict):
vision_config = KimiK25VisionConfig(**vision_config)
self.vision_config: KimiK25VisionConfig = vision_config
# Text config
if text_config is None:
text_config = DeepseekV3Config()
elif isinstance(text_config, dict):
text_config = DeepseekV3Config(**text_config)
self.text_config: DeepseekV3Config = text_config
# Set mm_hidden_size to text hidden size if not explicitly set
if self.vision_config.mm_hidden_size == self.vision_config.hidden_size:
self.vision_config.mm_hidden_size = self.text_config.hidden_size
# Other config
self.ignore_index = ignore_index
self.media_placeholder_token_id = media_placeholder_token_id
self.use_unified_vision_chunk = use_unified_vision_chunk
self.video_placeholder = video_placeholder
# Propagate quantization config from text model
if getattr(self.text_config, "quantization_config", None) is not None:
self.quantization_config = self.text_config.quantization_config
super().__init__(pad_token_id=pad_token_id, **kwargs)
@property
def hidden_size(self) -> int:
"""Get hidden size from text config for compatibility."""
return self.text_config.hidden_size
@property
def vocab_size(self) -> int:
"""Get vocab size from text config for compatibility."""
return self.text_config.vocab_size

View File

@@ -257,6 +257,26 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
)
supports_update_block_table: bool = True
@classmethod
def get_cudagraph_support(
cls,
vllm_config: "VllmConfig",
kv_cache_spec: "AttentionSpec",
) -> AttentionCGSupport:
# FA2 does not support CUDA graphs with encoder-decoder models due to
# accuracy issues reported in https://github.com/vllm-project/vllm/issues/33091
if (
vllm_config.model_config.is_encoder_decoder
and get_flash_attn_version() == 2
):
logger.warning_once(
"FlashAttention2 does not support CUDA graphs with "
"encoder-decoder models due to accuracy issues reported in #33091. "
"Disabling CUDA graph."
)
return AttentionCGSupport.NEVER
return cls._cudagraph_support
def __init__(
self,
kv_cache_spec: AttentionSpec,

View File

@@ -911,6 +911,17 @@ class EngineCoreProc(EngineCore):
set_process_title("EngineCore")
decorate_logs()
if data_parallel and vllm_config.kv_transfer_config is not None:
# modify the engine_id and append the local_dp_rank to it to ensure
# that the kv_transfer_config is unique for each DP rank.
vllm_config.kv_transfer_config.engine_id = (
f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
)
logger.debug(
"Setting kv_transfer_config.engine_id to %s",
vllm_config.kv_transfer_config.engine_id,
)
parallel_config.data_parallel_index = dp_rank
if data_parallel and vllm_config.model_config.is_moe:
# Set data parallel rank for this engine process.
@@ -1285,17 +1296,6 @@ class DPEngineCoreProc(EngineCoreProc):
assert local_dp_rank is not None
assert 0 <= local_dp_rank <= dp_rank < dp_size
if vllm_config.kv_transfer_config is not None:
# modify the engine_id and append the local_dp_rank to it to ensure
# that the kv_transfer_config is unique for each DP rank.
vllm_config.kv_transfer_config.engine_id = (
f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
)
logger.debug(
"Setting kv_transfer_config.engine_id to %s",
vllm_config.kv_transfer_config.engine_id,
)
self.dp_rank = dp_rank
self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()

View File

@@ -313,6 +313,13 @@ class CoreEngineActorManager:
dp_vllm_config.parallel_config.placement_group = pg
local_client = index < local_engine_count
if dp_size > 1 and dp_vllm_config.kv_transfer_config is not None:
# modify the engine_id and append the local_dp_rank to it to ensure
# that the kv_transfer_config is unique for each DP rank.
dp_vllm_config.kv_transfer_config.engine_id = (
f"{dp_vllm_config.kv_transfer_config.engine_id}_dp{local_index}"
)
# Ray XPU known issue: dpctl initializes the GPU runtime early, so
# setting device env vars in Ray actor's initialization method
# will not affect device selection. See: