136 lines
6.1 KiB
Python
136 lines
6.1 KiB
Python
#!/usr/bin/env python3
|
|
"""Patch vLLM kv_cache_utils.py to handle DeepseekV4 SWA page sizes.
|
|
|
|
DeepseekV4 has three cache types:
- C128A (HCA): compress_ratio=128, very small page size
- C4A (CSA): compress_ratio=4, medium page size
- SWA: compress_ratio=1, large page size

The upstream code assumes SWA page sizes <= MLA page sizes and pads
SWA pages to match MLA. This breaks when SWA pages are LARGER than
MLA pages (which is always the case for DeepseekV4).

Our fix: when SWA pages exceed MLA pages, skip the padding and put each
SWA layer in its own separate cache group.
|
|
"""
|
|
import sys
|
|
|
|
# Marker embedded in the replacement snippet; its presence in the target file
# means the patch has already been applied.
_PATCH_MARKER = "CLAWMINE_PATCH_KV_CACHE"

# Upstream code to be replaced: asserts SWA pages <= MLA pages, then pads SWA
# page sizes up to the nearest MLA page size.
# NOTE(review): the match is an exact substring match, so the indentation in
# this literal must agree byte-for-byte with the installed vLLM
# kv_cache_utils.py (assumed here: 4-space function-body indent) -- confirm
# against the target vLLM version.
_OLD_SNIPPET = """    assert max(sm_page_sizes) <= max(all_page_sizes)

    # Unify page size by padding layers' page_size to the nearest larger page_size.
    # Compute candidate (nearest larger page_size) for each unique page size.
    size_to_candidate: dict[int, int] = {}
    for ps in sm_page_sizes:
        size_to_candidate[ps] = min(x for x in all_page_sizes if x >= ps)
    # Pad and collect layer names per page size.
    for layer_name, layer_spec in sm_spec.kv_cache_specs.items():
        current_size = layer_spec.page_size_bytes
        candidate = size_to_candidate[current_size]
        if current_size < candidate:
            object.__setattr__(layer_spec, "page_size_padded", candidate)
        layers_per_size[candidate].append(layer_name)
    # NOTE(yifan): for now, inside a UniformKV group, each page_size should
    # have the same number of layers. This also means we don't need to pad layers
    # inside a partial-full layer tuple.
    assert len(set(len(layers) for layers in layers_per_size.values())) == 1
    num_layers_per_size = len(next(iter(layers_per_size.values())))

    # Split layers inside each UniformKV group for aligned #(layers).
    # See `_get_kv_cache_groups_uniform_page_size` for more details.
    num_tuple_groups = cdiv(num_layers_per_size, num_layer_tuples)
    layer_tuples = list(zip(*layers_per_size.values()))
    for i in range(num_tuple_groups):
        group_layer_tuples = layer_tuples[i::num_tuple_groups]
        # Flatten tuples and build dict for from_specs
        group_layer_names = [
            name for layer_tuple in group_layer_tuples for name in layer_tuple
        ]
        group_layer_specs = {
            name: sm_spec.kv_cache_specs[name] for name in group_layer_names
        }
        sub_sm_spec = UniformTypeKVCacheSpecs.from_specs(group_layer_specs)
        assert sub_sm_spec is not None
        swa_mla_groups.append(
            KVCacheGroupSpec(
                layer_names=group_layer_names,
                kv_cache_spec=sub_sm_spec,
            )
        )"""

# Replacement code: keeps the original padding path when it is applicable and
# otherwise emits one cache group per SWA layer without padding.
# NOTE(review): in the `else` branch a layer whose `from_specs(...)` result is
# None is skipped without any diagnostic -- confirm that is intended.
_NEW_SNIPPET = """    # CLAWMINE_PATCH_KV_CACHE: Handle DeepseekV4 where SWA page sizes
    # can be larger than MLA page sizes. Two cases:
    # 1. All SWA pages <= some MLA page: original padding logic
    # 2. Some SWA pages > all MLA pages: separate cache group, no padding
    max_mla_page = max(all_page_sizes)
    can_pad = max(sm_page_sizes) <= max_mla_page

    if can_pad:
        # Original logic: pad SWA pages to nearest MLA page
        size_to_candidate: dict[int, int] = {}
        for ps in sm_page_sizes:
            size_to_candidate[ps] = min(x for x in all_page_sizes if x >= ps)
        for layer_name, layer_spec in sm_spec.kv_cache_specs.items():
            current_size = layer_spec.page_size_bytes
            candidate = size_to_candidate[current_size]
            if current_size < candidate:
                object.__setattr__(layer_spec, "page_size_padded", candidate)
            layers_per_size[candidate].append(layer_name)
        assert len(set(len(layers) for layers in layers_per_size.values())) == 1
        num_layers_per_size = len(next(iter(layers_per_size.values())))
        num_tuple_groups = cdiv(num_layers_per_size, num_layer_tuples)
        layer_tuples = list(zip(*layers_per_size.values()))
        for i in range(num_tuple_groups):
            group_layer_tuples = layer_tuples[i::num_tuple_groups]
            group_layer_names = [
                name for layer_tuple in group_layer_tuples for name in layer_tuple
            ]
            group_layer_specs = {
                name: sm_spec.kv_cache_specs[name] for name in group_layer_names
            }
            sub_sm_spec = UniformTypeKVCacheSpecs.from_specs(group_layer_specs)
            assert sub_sm_spec is not None
            swa_mla_groups.append(
                KVCacheGroupSpec(
                    layer_names=group_layer_names,
                    kv_cache_spec=sub_sm_spec,
                )
            )
    else:
        # SWA pages are larger than MLA pages.
        # Put each SWA layer in its own cache group (no padding needed).
        # This is the DeepseekV4 Blackwell case where compress_ratio=1
        # layers have much larger pages than compressed layers.
        for layer_name, layer_spec in sm_spec.kv_cache_specs.items():
            group_layer_specs = {layer_name: layer_spec}
            sub_sm_spec = UniformTypeKVCacheSpecs.from_specs(group_layer_specs)
            if sub_sm_spec is not None:
                swa_mla_groups.append(
                    KVCacheGroupSpec(
                        layer_names=[layer_name],
                        kv_cache_spec=sub_sm_spec,
                    )
                )"""


def patch(path: str) -> None:
    """Rewrite the file at *path* in place, swapping the old snippet for the new.

    The file is read and written as UTF-8 so the result does not depend on the
    locale's default encoding.

    Args:
        path: Filesystem path of the ``kv_cache_utils.py`` file to modify.

    Behaviour:
        * If the file already contains the patch marker, a notice is printed
          and the file is left untouched (the operation is idempotent).
        * If the expected upstream snippet is absent, an error is printed to
          stderr and the process exits with status 1; the file is not written.
        * Otherwise every occurrence of the old snippet is replaced by the new
          one and the file is overwritten.

    Raises:
        SystemExit: with code 1 when the snippet to replace cannot be found.
        OSError: propagated from ``open`` if the file cannot be read/written.
    """
    with open(path, "r", encoding="utf-8") as src:
        content = src.read()

    # Idempotency guard: the marker only exists inside the replacement code.
    if _PATCH_MARKER in content:
        print("Already patched, skipping")
        return

    # Refuse to write anything unless the exact upstream code is present.
    if _OLD_SNIPPET not in content:
        print("ERROR: Could not find the code to patch", file=sys.stderr)
        sys.exit(1)

    with open(path, "w", encoding="utf-8") as dst:
        dst.write(content.replace(_OLD_SNIPPET, _NEW_SNIPPET))
    print("Patched kv_cache_utils.py for DeepseekV4 SWA page sizes")
|
|
|
|
if __name__ == "__main__":
    # The script takes the path to kv_cache_utils.py as its first argument.
    # Fail with a usage message instead of an IndexError traceback when it is
    # missing; any additional arguments are ignored, as before.
    if len(sys.argv) < 2:
        print(f"Usage: {sys.argv[0]} <path/to/kv_cache_utils.py>", file=sys.stderr)
        sys.exit(2)
    patch(sys.argv[1])
|