72 lines
2.7 KiB
Python
72 lines
2.7 KiB
Python
#!/usr/bin/env python3
|
|
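"""Patch DeepseekV4SWACache on Blackwell: drop the FlashMLA alignment and model_version.

On Blackwell (SM100+) FlashMLA does not work, so our own CSA/SDPA attention is
used instead. The SWA cache should therefore use the standard fp8 format (not
fp8_ds_mla) and no FlashMLA-specific 576-byte alignment.

Usage: python3 <this script> <path to the file defining DeepseekV4SWACache>

The script is idempotent: a file that already carries the patch marker is
left unchanged.
"""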
|
|
import re
import sys
import textwrap
|
|
|
|
# Marker embedded in the patched method; its presence means the file was
# already patched.
_MARKER = "CLAWMINE_PATCH_SWA_CACHE"

# The unpatched method. Only its whitespace-separated tokens are matched, so
# the target file's indentation and inter-token spacing do not matter.
_OLD_METHOD = """\
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
    return SlidingWindowMLASpec(
        block_size=self.block_size,
        num_kv_heads=1,
        head_size=self.head_dim,
        dtype=self.dtype,
        sliding_window=self.window_size,
        cache_dtype_str=self.cache_config.cache_dtype,
        alignment=576,  # NOTE: FlashMLA requires 576B alignment
        model_version="deepseek_v4",
    )"""

# Replacement method, written at zero indentation. It is re-indented to the
# indentation of the method it replaces.
# NOTE(review): like the original patch, this assumes SlidingWindowMLASpec
# accepts None for both `alignment` and `model_version` -- confirm in vLLM.
_NEW_METHOD = """\
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
    # CLAWMINE_PATCH_SWA_CACHE: On Blackwell, no FlashMLA = no 576B alignment
    # Use standard fp8 format (not fp8_ds_mla), no model_version override
    from vllm.platforms import current_platform

    _capability = current_platform.get_device_capability()
    _is_blackwell = _capability is not None and _capability.major >= 10
    return SlidingWindowMLASpec(
        block_size=self.block_size,
        num_kv_heads=1,
        head_size=self.head_dim,
        dtype=self.dtype,
        sliding_window=self.window_size,
        cache_dtype_str=self.cache_config.cache_dtype,
        # FlashMLA requires 576B alignment; no FlashMLA alignment on Blackwell.
        alignment=None if _is_blackwell else 576,
        # Don't use the 584B deepseek_v4 format on Blackwell.
        model_version=None if _is_blackwell else "deepseek_v4",
    )"""


def patch(path):
    """Rewrite ``get_kv_cache_spec`` in *path* so Blackwell skips FlashMLA settings.

    The file is left unchanged (and a message printed) if it already contains
    the patch marker. Otherwise every occurrence of the original method is
    replaced, keeping the indentation the method had in the file.

    Args:
        path: Path (``str`` or path-like) of the Python source file that
            defines ``DeepseekV4SWACache.get_kv_cache_spec``.

    Raises:
        SystemExit: With status 1 if the unpatched method is not found; the
            file is not modified in that case.
    """
    with open(path, encoding="utf-8") as f:
        content = f.read()

    if _MARKER in content:
        print("Already patched, skipping")
        return

    # Anchor at line start so the method's own indentation can be captured and
    # reused; tokens are separated by any run of whitespace (spaces/newlines).
    pattern = re.compile(
        r"^(?P<indent>[ \t]*)"
        + r"\s+".join(map(re.escape, _OLD_METHOD.split())),
        re.MULTILINE,
    )
    # A function replacement avoids re's backslash processing of the new text.
    content, count = pattern.subn(
        lambda m: textwrap.indent(_NEW_METHOD, m.group("indent")), content
    )
    if count == 0:
        print(f"ERROR: Could not find the code to patch in {path}", file=sys.stderr)
        sys.exit(1)

    with open(path, "w", encoding="utf-8") as f:
        f.write(content)
    print("Patched DeepseekV4SWACache for Blackwell")
|
|
|
|
if __name__ == "__main__":
    # Require at least one target file instead of failing with an IndexError
    # traceback; each given file is patched in order.
    if len(sys.argv) < 2:
        print(f"usage: {sys.argv[0]} FILE [FILE ...]", file=sys.stderr)
        sys.exit(2)
    for _target in sys.argv[1:]:
        patch(_target)
|