[Attention] MLA - Flashinfer Ragged Prefill (#20034)
This commit is contained in:
committed by
GitHub
parent
922f316441
commit
5b032352cc
@@ -1,5 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
@@ -7,6 +9,11 @@ import torch
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
|
||||
ModelConfig, SchedulerConfig, VllmConfig)
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
||||
KVConnectorFactory)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector import ( # noqa
|
||||
SharedStorageConnector)
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec)
|
||||
@@ -187,3 +194,58 @@ def create_model_runner_output(
|
||||
finished_sending=finished_sending,
|
||||
finished_recving=finished_recving,
|
||||
)
|
||||
|
||||
|
||||
class TestSharedStorageConnector(SharedStorageConnector):
    """Recording proxy around a ``SharedStorageConnector`` for tests.

    Attribute access is forwarded to an inner ``SharedStorageConnector``
    instance. Every callable attribute fetched from the inner connector is
    returned wrapped, so that each call is counted in ``call_record`` and
    appended as one line to a per-connector event file.
    """

    def __init__(self, config: VllmConfig, role):
        """Build the inner connector and create an empty event file.

        Args:
            config: vLLM config; ``kv_transfer_config.kv_connector_extra_config``
                must contain a ``"name"`` entry, used to label this connector.
            role: Passed through unchanged to ``SharedStorageConnector``.
        """
        extra_config = config.kv_transfer_config.kv_connector_extra_config
        self.name = extra_config["name"]
        self._connector = SharedStorageConnector(config, role)
        self.call_record: dict[str, int] = defaultdict(int)

        # One event file per (connector name, role) pair, placed in the
        # system temp directory. ``self.role`` is resolved through
        # ``__getattribute__`` below.
        tmp_dir = tempfile.gettempdir()
        self._event_file = (
            tmp_dir + f"/connector_{self.name}-{self.role.name}_events.log")

        # Opening in "w" mode creates the file or truncates an old one.
        with open(self._event_file, "w"):
            pass

    def __getattribute__(self, name):
        """Resolve ``name``, preferring the inner connector's attributes.

        Names belonging to the proxy itself are looked up directly on this
        object. Any other name the inner connector has is taken from it;
        callables are returned wrapped in a recording function. Names the
        inner connector lacks fall back to normal lookup on this object.
        """
        # These must bypass the proxy logic to avoid infinite recursion.
        proxy_own_names = ("_connector", "call_record", "name", "_event_file",
                           "__class__", "__dict__", "__getattribute__",
                           "__init__")
        if name in proxy_own_names:
            return object.__getattribute__(self, name)

        if not hasattr(self._connector, name):
            return object.__getattribute__(self, name)

        target = getattr(self._connector, name)
        if not callable(target):
            return target

        # Intercept calls to the connector interface and write an event
        # for each one to a file, which can be read back in the main test proc.
        def wrapper(*args, **kwargs):
            self.call_record[name] += 1

            # Only positional ints and KVCacheBlocks are recorded.
            fields = [name]
            for positional in args:
                if isinstance(positional, int):
                    fields.append(str(positional))
                elif isinstance(positional, KVCacheBlocks):
                    block_counts = [len(grp) for grp in positional.blocks]
                    fields.append(f"num_blocks={block_counts}")

            # Best-effort logging: a failure to write the event is reported
            # on stdout but never prevents the underlying call.
            try:
                with open(self._event_file, "a") as event_log:
                    event_log.write(" ".join(fields) + "\n")
            except Exception as e:
                print(f"[ERROR] Could not log event {name} "
                      f"for {self.name}: {e}")

            return target(*args, **kwargs)

        return wrapper
|
||||
|
||||
|
||||
# Register the test connector with the factory under the name
# "TestSharedStorageConnector", passing this module's name and the class name.
KVConnectorFactory.register_connector(
    "TestSharedStorageConnector",
    __name__,
    TestSharedStorageConnector.__name__,
)
|
||||
|
||||
Reference in New Issue
Block a user