Files
nvfp4-megamoe-kernel/vllm/patches/patch_indexer_cache.py

59 lines
2.1 KiB
Python

#!/usr/bin/env python3
"""Patch DeepseekV4IndexerCache on Blackwell: remove FlashMLA alignment.
Same as patch_swa_cache but for the indexer cache class.
"""
import re
import sys
import textwrap
# Upstream DeepseekV4IndexerCache.get_kv_cache_spec that this script rewrites.
# Only its whitespace-separated tokens are matched (see _token_pattern), so
# differences in indentation or line layout in the target file do not matter.
_OLD_SPEC = """\
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
    # head_dim already carries the fp8 scale padding
    # compress_ratio=1 for V3.2, >1 for DeepseekV4; both use the same cache layout.
    return MLAAttentionSpec(
        block_size=self.cache_config.block_size,
        num_kv_heads=1,
        head_size=self.head_dim,
        dtype=self.dtype,
        compress_ratio=self.compress_ratio,
        # DeepseekV4 aligns indexer pages to FlashMLA's 576B so they can pack with
        # the indexer's compressor state cache. V3.2 keeps the legacy layout.
        alignment=576,
    )"""

# Replacement method, written at column 0; it is re-indented to wherever the
# upstream method was found. The marker comment makes the patch idempotent.
_NEW_SPEC = """\
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
    # CLAWMINE_PATCH_INDEXER_CACHE: No FlashMLA alignment on Blackwell
    from vllm.platforms import current_platform
    # head_dim already carries the fp8 scale padding.
    # Off Blackwell, keep the 576B FlashMLA page alignment so indexer pages
    # can still pack with the indexer's compressor state cache.
    _cap = current_platform.get_device_capability()
    _is_blackwell = _cap is not None and _cap.major >= 10
    return MLAAttentionSpec(
        block_size=self.cache_config.block_size,
        num_kv_heads=1,
        head_size=self.head_dim,
        dtype=self.dtype,
        compress_ratio=self.compress_ratio,
        alignment=None if _is_blackwell else 576,
    )"""


def _token_pattern(snippet):
    """Compile a regex that matches *snippet* regardless of whitespace layout.

    Every whitespace run in *snippet* matches any non-empty whitespace run in
    the target text. The match must start at the beginning of a line, and that
    line's leading indentation is captured in the ``indent`` group.
    """
    body = r"\s+".join(re.escape(token) for token in snippet.split())
    return re.compile(r"^(?P<indent>[ \t]*)" + body, re.MULTILINE)


def patch(path):
    """Make the FlashMLA alignment in DeepseekV4IndexerCache Blackwell-aware.

    Rewrites ``get_kv_cache_spec`` in the Python source file at *path*, in
    place, so that ``alignment`` is ``None`` when the device capability's
    major version is >= 10 and 576 otherwise. Every occurrence of the upstream
    method is replaced and re-indented to the indentation it was found at.

    The patch is idempotent. A file that already contains the
    CLAWMINE_PATCH_INDEXER_CACHE marker is left untouched.

    Args:
        path: Path to the source file that defines the indexer cache class.

    Raises:
        SystemExit: With status 1 if the upstream method cannot be found.
    """
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()
    if "CLAWMINE_PATCH_INDEXER_CACHE" in content:
        print("Already patched, skipping")
        return

    def _reindented(match):
        # The method lives inside a class, so the replacement must inherit
        # the indentation of the matched ``def`` line.
        return textwrap.indent(_NEW_SPEC, match.group("indent"))

    # A callable replacement also avoids backslash-escape processing by re.sub.
    content, count = _token_pattern(_OLD_SPEC).subn(_reindented, content)
    if count == 0:
        print("ERROR: Could not find the code to patch in " + path, file=sys.stderr)
        sys.exit(1)
    with open(path, "w", encoding="utf-8") as f:
        f.write(content)
    print("Patched DeepseekV4IndexerCache for Blackwell")
if __name__ == "__main__":
    # Exactly one argument is expected: the source file to patch. Without this
    # check a missing argument surfaces as a bare IndexError traceback.
    if len(sys.argv) != 2:
        print(f"usage: {sys.argv[0]} <path-to-source-file>", file=sys.stderr)
        sys.exit(2)
    patch(sys.argv[1])