[USAGE] Improve error handling for weight initialization in Unquantized… (#20321)

Signed-off-by: Rafael Marcelino Koike <rafael.koike@oracle.com>
Signed-off-by: Rafael Koike <koike.rafael@gmail.com>
This commit is contained in:
Rafael Marcelino Koike
2025-09-15 12:45:49 -04:00
committed by GitHub
parent 740f0647b1
commit b834b4cbf1
2 changed files with 43 additions and 8 deletions

View File

@@ -29,6 +29,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
# yapf: enable
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import GiB_bytes
logger = init_logger(__name__)
@@ -190,10 +191,27 @@ class UnquantizedLinearMethod(LinearMethodBase):
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype),
requires_grad=False)
# This method creates unquantized linear weights.
# The weights are not quantized, and they are not sharded.
# The amount of memory allocated for the weights is
# sum(output_partition_sizes) * input_size_per_partition.
try:
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype),
requires_grad=False)
except torch.cuda.OutOfMemoryError as e:
logger.error("Failed to create unquantized linear weights: %s", e)
if torch.cuda.is_available():
logger.debug("CUDA device: %s", torch.cuda.current_device())
logger.debug("Allocated: %.2f GiB",
torch.cuda.memory_allocated() / GiB_bytes)
logger.debug("Reserved: %.2f GiB",
torch.cuda.memory_reserved() / GiB_bytes)
raise RuntimeError(
"Failed to create unquantized linear weights. "
"This may be caused by insufficient memory to allocate "
"the weight.") from e
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)