Files
nvfp4-megamoe-kernel/vllm/patches/patch_compressor_cache.py

49 lines
1.6 KiB
Python

#!/usr/bin/env python3
"""Patch DeepseekCompressor cache for Blackwell: remove FlashMLA alignment."""
import sys
def patch(path):
    """Rewrite the compressor cache spec in ``path`` to drop FlashMLA alignment.

    The file is read in full, the hard-coded ``alignment=576`` spec is located
    by exact substring match, and it is replaced with code that decides the
    alignment at runtime: ``None`` when the device capability major version is
    10 or higher (what the injected code calls Blackwell), ``576`` otherwise.

    The call is idempotent: a file that already contains the
    ``CLAWMINE_PATCH_COMPRESSOR`` marker is left untouched.

    Args:
        path: Location of the Python source file to rewrite in place.

    Raises:
        SystemExit: With exit status 1 when the expected snippet is absent
            from the file; the file is not modified in that case.
    """
    # Read as UTF-8 explicitly so the result does not depend on the locale of
    # the machine (or container) running the patch.
    with open(path, 'r', encoding='utf-8') as f:
        content = f.read()

    # The marker is part of the injected code, so finding it means an earlier
    # run has already rewritten this file.
    if "CLAWMINE_PATCH_COMPRESSOR" in content:
        print("Already patched, skipping")
        return

    # NOTE(review): matching is an exact substring test, so the whitespace
    # inside these two literals must agree byte-for-byte with the target
    # file -- confirm the indentation against the upstream source.
    old = """ return SlidingWindowMLASpec( # only has one vector instead of K + V
block_size=self.block_size,
num_kv_heads=1,
head_size=self.state_dim,
dtype=self.dtype,
sliding_window=self.sliding_window,
alignment=576, # NOTE: FlashMLA requires 576B alignment
)"""
    new = """ # CLAWMINE_PATCH_COMPRESSOR: No FlashMLA alignment on Blackwell
from vllm.platforms import current_platform
_is_blackwell = (
current_platform.get_device_capability() is not None
and current_platform.get_device_capability().major >= 10
)
return SlidingWindowMLASpec( # only has one vector instead of K + V
block_size=self.block_size,
num_kv_heads=1,
head_size=self.state_dim,
dtype=self.dtype,
sliding_window=self.sliding_window,
alignment=None if _is_blackwell else 576,
)"""

    if old not in content:
        # Diagnostics belong on stderr; str() keeps the message working for
        # path-like objects as well as plain strings.
        print("ERROR: Could not find the code to patch in " + str(path),
              file=sys.stderr)
        sys.exit(1)

    content = content.replace(old, new)
    with open(path, 'w', encoding='utf-8') as f:
        f.write(content)
    print("Patched DeepseekCompressor for Blackwell")
if __name__ == "__main__":
    # The script takes one positional argument: the path of the source file
    # to rewrite. Report a usage line (on stderr, exit status 1) instead of
    # an IndexError traceback when it is missing.
    if len(sys.argv) < 2:
        sys.exit("Usage: " + sys.argv[0] + " <path-to-source-file>")
    patch(sys.argv[1])