[BugFix] Avoid calling KV connector layer APIs when metadata is unset (#28253)
Signed-off-by: David Ben-David <davidb@pliops.com> Co-authored-by: David Ben-David <davidb@pliops.com> Co-authored-by: Mark McLoughlin <markmc@redhat.com>
This commit is contained in:
@@ -837,6 +837,8 @@ def wait_for_kv_layer_from_connector(layer_name: str):
|
|||||||
return
|
return
|
||||||
|
|
||||||
connector = get_kv_transfer_group()
|
connector = get_kv_transfer_group()
|
||||||
|
if not connector.has_connector_metadata():
|
||||||
|
return
|
||||||
|
|
||||||
forward_context: ForwardContext = get_forward_context()
|
forward_context: ForwardContext = get_forward_context()
|
||||||
attn_metadata = forward_context.attn_metadata
|
attn_metadata = forward_context.attn_metadata
|
||||||
@@ -854,6 +856,8 @@ def maybe_save_kv_layer_to_connector(
|
|||||||
return
|
return
|
||||||
|
|
||||||
connector = get_kv_transfer_group()
|
connector = get_kv_transfer_group()
|
||||||
|
if not connector.has_connector_metadata():
|
||||||
|
return
|
||||||
|
|
||||||
forward_context: ForwardContext = get_forward_context()
|
forward_context: ForwardContext = get_forward_context()
|
||||||
attn_metadata = forward_context.attn_metadata
|
attn_metadata = forward_context.attn_metadata
|
||||||
|
|||||||
@@ -204,11 +204,18 @@ class KVConnectorBase_V1(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
ConnectorMetadata: the connector metadata.
|
ConnectorMetadata: the connector metadata.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Should only be called while set to valid metadata.
|
# Should only be called while set to valid metadata.
|
||||||
assert self._connector_metadata is not None
|
assert self._connector_metadata is not None
|
||||||
return self._connector_metadata
|
return self._connector_metadata
|
||||||
|
|
||||||
|
def has_connector_metadata(self) -> bool:
|
||||||
|
"""Check whether the connector metadata is currently set.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if connector metadata exists, False otherwise.
|
||||||
|
"""
|
||||||
|
return self._connector_metadata is not None
|
||||||
|
|
||||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||||
"""
|
"""
|
||||||
Initialize with the KV caches. Useful for pre-registering the
|
Initialize with the KV caches. Useful for pre-registering the
|
||||||
|
|||||||
@@ -171,16 +171,22 @@ class MultiConnector(KVConnectorBase_V1):
|
|||||||
# We must override the base class method here because we need to bind
|
# We must override the base class method here because we need to bind
|
||||||
# the metadata to each connector in the order of the connectors in the
|
# the metadata to each connector in the order of the connectors in the
|
||||||
# MultiKVConnectorMetadata.
|
# MultiKVConnectorMetadata.
|
||||||
|
#
|
||||||
|
# Note: Call the base class method to ensure metadata is also set on the
|
||||||
|
# MultiConnector instance itself; otherwise, `has_connector_metadata()` will
|
||||||
|
# always return False.
|
||||||
def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
|
def bind_connector_metadata(self, connector_metadata: KVConnectorMetadata) -> None:
|
||||||
assert isinstance(connector_metadata, MultiKVConnectorMetadata)
|
assert isinstance(connector_metadata, MultiKVConnectorMetadata)
|
||||||
if connector_metadata.extra_async_saves:
|
if connector_metadata.extra_async_saves:
|
||||||
self._extra_async_saves.update(connector_metadata.extra_async_saves)
|
self._extra_async_saves.update(connector_metadata.extra_async_saves)
|
||||||
for c, cm in zip(self._connectors, connector_metadata.metadata):
|
for c, cm in zip(self._connectors, connector_metadata.metadata):
|
||||||
c.bind_connector_metadata(cm)
|
c.bind_connector_metadata(cm)
|
||||||
|
super().bind_connector_metadata(connector_metadata)
|
||||||
|
|
||||||
def clear_connector_metadata(self) -> None:
|
def clear_connector_metadata(self) -> None:
|
||||||
for c in self._connectors:
|
for c in self._connectors:
|
||||||
c.clear_connector_metadata()
|
c.clear_connector_metadata()
|
||||||
|
super().clear_connector_metadata()
|
||||||
|
|
||||||
def shutdown(self):
|
def shutdown(self):
|
||||||
exception: Exception | None = None
|
exception: Exception | None = None
|
||||||
|
|||||||
Reference in New Issue
Block a user