[Attention] MLA - Flashinfer Ragged Prefill (#20034)

This commit is contained in:
Alexander Matveev
2025-07-10 23:17:47 -04:00
committed by GitHub
parent 922f316441
commit 5b032352cc
10 changed files with 421 additions and 214 deletions

View File

@@ -1,5 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
from collections import defaultdict
from typing import Any, Optional
import torch
@@ -7,6 +9,11 @@ import torch
from vllm import SamplingParams
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
ModelConfig, SchedulerConfig, VllmConfig)
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
SharedStorageConnector)
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
@@ -187,3 +194,58 @@ def create_model_runner_output(
finished_sending=finished_sending,
finished_recving=finished_recving,
)
class TestSharedStorageConnector(SharedStorageConnector):
def __init__(self, config: VllmConfig, role):
self.name = config.kv_transfer_config.kv_connector_extra_config["name"]
self._connector = SharedStorageConnector(config, role)
self.call_record: dict[str, int] = defaultdict(int)
# Use a unique temp file per connector
self._event_file = tempfile.gettempdir(
) + f"/connector_{self.name}-{self.role.name}_events.log"
# Start with an empty file
with open(self._event_file, "w") as _:
pass
def __getattribute__(self, name):
if name in ("_connector", "call_record", "name", "_event_file",
"__class__", "__dict__", "__getattribute__",
"__init__"): # avoid recursion
return object.__getattribute__(self, name)
if not hasattr(self._connector, name):
return object.__getattribute__(self, name)
attr = getattr(self._connector, name)
# Intercept calls to the connector interface and write an event
# for each one to a file, which can be read back in the main test proc.
if callable(attr):
def wrapper(*args, **kwargs):
self.call_record[name] += 1
# Include args that we're interested in
to_log = [name]
for arg in args:
if isinstance(arg, int):
to_log.append(str(arg))
elif isinstance(arg, KVCacheBlocks):
to_log.append(
f"num_blocks={[len(b) for b in arg.blocks]}")
# Log the event as a line to the file
try:
with open(self._event_file, "a") as f:
f.write(' '.join(to_log) + "\n")
except Exception as e:
print(f"[ERROR] Could not log event {name} "
f"for {self.name}: {e}")
return attr(*args, **kwargs)
return wrapper
return attr
KVConnectorFactory.register_connector("TestSharedStorageConnector", __name__,
TestSharedStorageConnector.__name__)