49 lines
1.6 KiB
Python
49 lines
1.6 KiB
Python
#!/usr/bin/env python3
|
|
"""Patch DeepseekCompressor cache for Blackwell: remove FlashMLA alignment."""
|
|
import sys
|
|
|
|
# Marker embedded in the comment the patch inserts; finding it in a file means
# the patch has already been applied there.
_PATCH_MARKER = "CLAWMINE_PATCH_COMPRESSOR"

# The statement to rewrite, one entry per source line, with each line's
# whitespace collapsed (see _squash).  Matching on collapsed lines keeps the
# check strict about every token while tolerating differences in indentation
# and in the spacing before inline comments.
_TARGET_LINES = (
    "return SlidingWindowMLASpec( # only has one vector instead of K + V",
    "block_size=self.block_size,",
    "num_kv_heads=1,",
    "head_size=self.state_dim,",
    "dtype=self.dtype,",
    "sliding_window=self.sliding_window,",
    "alignment=576, # NOTE: FlashMLA requires 576B alignment",
    ")",
)

# Index of the ``alignment=576`` argument line inside _TARGET_LINES.
_ALIGNMENT_INDEX = 6


def _squash(line):
    """Return *line* stripped, with each inner whitespace run as one space."""
    return " ".join(line.split())


def _leading_whitespace(line):
    """Return the run of whitespace that *line* starts with."""
    return line[: len(line) - len(line.lstrip())]


def _find_targets(lines):
    """Return the start index of every occurrence of the target statement."""
    squashed = [_squash(line) for line in lines]
    wanted = list(_TARGET_LINES)
    width = len(wanted)
    return [
        start
        for start in range(len(squashed) - width + 1)
        if squashed[start:start + width] == wanted
    ]


def _patched_block(block):
    """Return the replacement lines for one matched statement.

    *block* holds the matched lines exactly as they appear in the file
    (line endings included).  Lines that do not change are passed through
    verbatim; inserted lines reuse the indentation found in the file.
    """
    indent = _leading_whitespace(block[0])
    arg_indent = _leading_whitespace(block[_ALIGNMENT_INDEX])
    prelude = [
        indent + "# " + _PATCH_MARKER + ": No FlashMLA alignment on Blackwell\n",
        indent + "from vllm.platforms import current_platform\n",
        indent + "_is_blackwell = (\n",
        arg_indent + "current_platform.get_device_capability() is not None\n",
        arg_indent + "and current_platform.get_device_capability().major >= 10\n",
        indent + ")\n",
    ]
    rewritten = list(block)
    rewritten[_ALIGNMENT_INDEX] = (
        arg_indent + "alignment=None if _is_blackwell else 576,\n"
    )
    return prelude + rewritten


def patch(path):
    """Rewrite the ``SlidingWindowMLASpec`` return statement in *path* in place.

    The hard-coded ``alignment=576`` argument is replaced by
    ``alignment=None if _is_blackwell else 576``, and code computing
    ``_is_blackwell`` (device capability major >= 10) is inserted directly
    above the ``return``.

    The call is idempotent: a file that already contains the patch marker is
    left untouched.

    Args:
        path: Path of the Python source file to modify.

    Raises:
        SystemExit: With status 1 when the expected statement is not found.
    """
    # Explicit encoding so the result does not depend on the locale default.
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()

    if _PATCH_MARKER in content:
        print("Already patched, skipping")
        return

    lines = content.splitlines(keepends=True)
    starts = _find_targets(lines)
    if not starts:
        print(
            "ERROR: Could not find the code to patch in " + path,
            file=sys.stderr,
        )
        sys.exit(1)

    width = len(_TARGET_LINES)
    # Work from the bottom up so that earlier start indices stay valid while
    # the list grows.
    for start in reversed(starts):
        lines[start:start + width] = _patched_block(lines[start:start + width])

    with open(path, "w", encoding="utf-8") as f:
        f.write("".join(lines))
    print("Patched DeepseekCompressor for Blackwell")
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: the first argument is the file to patch.
    # Without it, print a usage line instead of failing with an IndexError
    # traceback.  Any further arguments are ignored.
    if len(sys.argv) < 2:
        print("Usage: " + sys.argv[0] + " <path-to-source-file>", file=sys.stderr)
        sys.exit(2)
    patch(sys.argv[1])
|