#!/usr/bin/env python3
"""Patch DeepseekCompressor cache for Blackwell: remove FlashMLA alignment.

Rewrites the ``return SlidingWindowMLASpec(...)`` statement of the
DeepseekCompressor KV-cache spec in the given source file so that the
576-byte FlashMLA alignment is only requested when the device compute
capability major version is below 10 (or unknown).

Usage::

    python3 <this script> PATH_TO_SOURCE_FILE

The patch is idempotent: a file that already contains the
``CLAWMINE_PATCH_COMPRESSOR`` marker is left untouched.  Exit status is 0 on
success or when already patched, 1 when the target statement is not found,
and 2 when no path argument is given.
"""

import re
import sys
import textwrap

# Marker embedded in the injected code; its presence means "already patched".
MARKER = "CLAWMINE_PATCH_COMPRESSOR"

# Upstream statement to replace, written without its enclosing indentation.
# It is matched whitespace-insensitively (see _OLD_PATTERN), so differences in
# indentation or comment spacing in the target file do not break the patch.
_OLD = """\
return SlidingWindowMLASpec(  # only has one vector instead of K + V
    block_size=self.block_size,
    num_kv_heads=1,
    head_size=self.state_dim,
    dtype=self.dtype,
    sliding_window=self.sliding_window,
    alignment=576,  # NOTE: FlashMLA requires 576B alignment
)"""

# Replacement, also unindented; it is re-indented to match the original
# ``return`` line.  The capability is queried once; when it is None the
# FlashMLA alignment of 576 is kept.
_NEW = """\
# CLAWMINE_PATCH_COMPRESSOR: No FlashMLA alignment on Blackwell
from vllm.platforms import current_platform
_capability = current_platform.get_device_capability()
_is_blackwell = _capability is not None and _capability.major >= 10
return SlidingWindowMLASpec(  # only has one vector instead of K + V
    block_size=self.block_size,
    num_kv_heads=1,
    head_size=self.state_dim,
    dtype=self.dtype,
    sliding_window=self.sliding_window,
    alignment=None if _is_blackwell else 576,
)"""

# Every whitespace-separated token of _OLD must appear in order, separated by
# at least one whitespace character.  The leading indentation of the
# ``return`` line is captured so the replacement can reuse it.
_OLD_PATTERN = re.compile(
    r"^(?P<indent>[ \t]*)"
    + r"\s+".join(re.escape(token) for token in _OLD.split()),
    re.MULTILINE,
)


def patch_source(content: str) -> "tuple[str, int]":
    """Return *content* with every target statement replaced, plus the count.

    Each injected block is indented like the ``return`` line it replaces.
    A count of 0 means the statement was not found and *content* is returned
    unchanged.
    """

    def _replacement(match):
        # A callable (not a template string) so backslashes in _NEW are
        # never interpreted as group references.
        return textwrap.indent(_NEW, match.group("indent"))

    return _OLD_PATTERN.subn(_replacement, content)


def patch(path):
    """Patch the DeepseekCompressor KV-cache spec in the file at *path* in place.

    Prints a status line and returns normally when the file is patched or was
    already patched.  Prints an error to stderr and exits with status 1 when
    the target statement cannot be found (the file is then left unchanged).
    """
    with open(path, "r", encoding="utf-8") as f:
        content = f.read()

    if MARKER in content:
        print("Already patched, skipping")
        return

    patched, count = patch_source(content)
    if count == 0:
        print(
            f"ERROR: Could not find the code to patch in {path}",
            file=sys.stderr,
        )
        sys.exit(1)

    with open(path, "w", encoding="utf-8") as f:
        f.write(patched)
    print("Patched DeepseekCompressor for Blackwell")


def main(argv=None):
    """Command-line entry point; *argv* excludes the program name.

    Returns the process exit status (2 on a missing path argument).
    """
    args = sys.argv[1:] if argv is None else argv
    if not args:
        print(f"usage: {sys.argv[0]} PATH", file=sys.stderr)
        return 2
    patch(args[0])
    return 0


if __name__ == "__main__":
    sys.exit(main())