perf: add __slots__ to KVCacheBlock (#36164)
Signed-off-by: cong-or <conchubhar.gannon@gmail.com>
This commit is contained in:
@@ -202,6 +202,18 @@ def test_kv_cache_block():
|
||||
assert block.block_hash is None
|
||||
|
||||
|
||||
def test_kv_cache_block_uses_slots():
|
||||
block = KVCacheBlock(block_id=0)
|
||||
|
||||
# Slots eliminate per-instance __dict__, saving ~264 bytes per block.
|
||||
# At 100K+ blocks this avoids tens of MB of overhead and GC pressure.
|
||||
assert not hasattr(block, "__dict__")
|
||||
|
||||
# Verify that slots actually prevent dynamic attribute assignment.
|
||||
with pytest.raises(AttributeError):
|
||||
block.unexpected_field = True
|
||||
|
||||
|
||||
def test_free_kv_cache_block_queue_initialization():
|
||||
# Test with a single block
|
||||
block = KVCacheBlock(block_id=0)
|
||||
|
||||
@@ -106,7 +106,7 @@ def init_none_hash(hash_fn: Callable[[Any], bytes]):
|
||||
NONE_HASH = BlockHash(hash_fn(hash_seed))
|
||||
|
||||
|
||||
-@dataclass
|
||||
+@dataclass(slots=True)
|
||||
class KVCacheBlock:
|
||||
"""KV-cache block metadata."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user