Bugfix for offloading+prefetch for GLM-4.7-FP8 (#37178)
Signed-off-by: Benjamin Merkel <benjamin.merkel@tngtech.com>
Co-authored-by: Benjamin Merkel <benjamin.merkel@tngtech.com>
This commit is contained in:
@@ -431,10 +431,32 @@ class _ModuleOffloader:
    Called after process_weights_after_loading to ensure _cpu_storage
    contains the final processed weights, not stale pre-loading data.

    Parameters whose underlying nn.Parameter was deleted by
    process_weights_after_loading (e.g. transient KV-cache scale params)
    are pruned from self._param_offloaders so they do not participate in
    buffer-pool allocation or prefetching.
    """
    for param_offloader in self._param_offloaders.values():
        param_offloader.sync_cpu_storage()
|
||||
    # Remove offloaders whose parameter was deleted during
    # process_weights_after_loading (e.g. k_scale / v_scale).
    deleted = [
        name
        for name, offloader in self._param_offloaders.items()
        if getattr(offloader, "_param_deleted", False)
    ]
    if deleted:
        logger.debug(
            "Pruning %d transient offloaded param(s) that were deleted "
            "by process_weights_after_loading: %s",
            len(deleted),
            deleted,
        )
    for name in deleted:
        del self._param_offloaders[name]
|
||||
|
||||
def get_param_infos(self) -> list[ParamInfo]:
|
||||
"""Get parameter metadata for buffer pool allocation.
|
||||
|
||||
@@ -590,6 +612,11 @@ class _CpuParamOffloader(_BaseParamOffloader):
    super().__init__(module, param_name)
    self._cpu_storage: torch.Tensor | None = None
    self._gpu_buffer: torch.Tensor | None = None  # Store reference to GPU buffer
    # Set to True if the underlying nn.Parameter was deleted by
    # process_weights_after_loading (e.g. transient KV-cache scale params
    # such as k_scale/v_scale created by BaseKVCacheMethod.create_weights
    # and deleted after copying into permanent _k_scale buffers).
    self._param_deleted: bool = False

    # Offload to CPU immediately to free GPU memory during model loading
    self._offload_to_cpu_internal()
|
||||
@@ -696,8 +723,22 @@ class _CpuParamOffloader(_BaseParamOffloader):
    1. process_weights_after_loading may transform weights (quantization)
    2. device_loading_context creates NEW CPU tensors when moving back
    3. Our old _cpu_storage would have pre-processed or stale data

    If the parameter no longer exists on the module (e.g. transient
    KV-cache scale parameters such as k_scale/v_scale that are created
    by BaseKVCacheMethod.create_weights() and then deleted by
    process_weights_after_loading() after copying their values into
    permanent _k_scale buffers), the offloader marks itself as deleted
    and skips the sync. The caller (_ModuleOffloader.sync_cpu_storage)
    is responsible for removing these stale entries.
    """
    try:
        self._update_cpu_storage_from_param()
    except AttributeError:
        # The parameter was deleted by process_weights_after_loading.
        # Drop the now-stale CPU storage so this offloader can be pruned.
        self._param_deleted = True
        self._cpu_storage = None
|
||||
|
||||
def post_init(self):
    """Intentionally a no-op.

    CPU offloading for this parameter already happened in
    offload_to_cpu / assign_static_buffer, so there is nothing
    left to do after initialization.
    """
|
||||
|
||||
Reference in New Issue
Block a user