[TPU][Quantization] TPU W8A8 (#11785)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@@ -6,6 +6,7 @@ from torch.nn import Parameter

from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import init_logger
from vllm.model_executor.utils import _make_synced_weight_loader

__all__ = [
    "BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
@@ -37,6 +38,18 @@ class BasevLLMParameter(Parameter):
        :returns: a torch.nn.parameter
        """

        # During weight loading, we often do something like:
        # narrowed_tensor = param.data.narrow(0, offset, len)
        # narrowed_tensor.copy_(real_weight)
        # expecting narrowed_tensor and param.data to share the same storage.
        # However, on TPUs, narrowed_tensor will lazily propagate to the base
        # tensor, which is param.data, leading to redundant memory usage.
        # This sometimes causes OOM errors during model loading. To avoid this,
        # we sync the param tensor after its weight loader is called.
        from vllm.platforms import current_platform
        if current_platform.is_tpu():
            weight_loader = _make_synced_weight_loader(weight_loader)

        self._weight_loader = weight_loader

    @property
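
Note: _make_synced_weight_loader is imported from vllm/model_executor/utils.py, but its body is not part of this diff. A minimal sketch of the idea (not the actual vLLM implementation) is a wrapper that runs the original loader and then forces the pending lazy XLA operations to materialize; the use of torch_xla's mark_step() below is an assumption about how that sync could be done.

import torch_xla.core.xla_model as xm


def _make_synced_weight_loader(original_weight_loader):
    """Sketch: wrap a weight loader so lazy XLA ops are flushed after each load."""

    def _synced_weight_loader(param, *args, **kwargs):
        original_weight_loader(param, *args, **kwargs)
        # Force pending lazy operations to execute so param.data holds the
        # materialized weights and the temporaries from narrow()/copy_() can
        # be freed instead of accumulating in the lazy graph.
        xm.mark_step()

    return _synced_weight_loader

Because the wrapping happens only under current_platform.is_tpu(), GPU and CPU backends keep the original loader and the zero-copy narrow()/copy_() path unchanged.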