[MRV2] Fix for DS v3.2 (#38030)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
@@ -115,9 +115,12 @@ def _reshape_kv_cache(
|
||||
) -> dict[str, torch.Tensor]:
|
||||
kv_caches: dict[str, torch.Tensor] = {}
|
||||
for kv_cache_group_spec in kv_cache_config.kv_cache_groups:
|
||||
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
|
||||
assert isinstance(kv_cache_spec, AttentionSpec)
|
||||
for layer_name in kv_cache_group_spec.layer_names:
|
||||
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
|
||||
if isinstance(kv_cache_spec, UniformTypeKVCacheSpecs):
|
||||
kv_cache_spec = kv_cache_spec.kv_cache_specs[layer_name]
|
||||
assert isinstance(kv_cache_spec, AttentionSpec)
|
||||
|
||||
raw_tensor = kv_cache_raw_tensors[layer_name]
|
||||
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = raw_tensor.numel() // kv_cache_spec.page_size_bytes
|
||||
|
||||
Reference in New Issue
Block a user