Integrate Marlin Kernels for Int4 GPTQ inference (#2497)

Co-authored-by: Robert Shaw <114415538+rib-2@users.noreply.github.com>
Co-authored-by: alexm <alexm@neuralmagic.com>
This commit is contained in:
Robert Shaw
2024-03-01 14:47:51 -06:00
committed by GitHub
parent 90fbf12540
commit c0c2335ce0
12 changed files with 1752 additions and 6 deletions

View File

@@ -17,6 +17,14 @@ from vllm.logger import init_logger
logger = init_logger(__name__)
def adjust_marlin_shard(param, shard_size, shard_offset):
marlin_tile_size = getattr(param, "marlin_tile_size", None)
if marlin_tile_size is None:
return shard_size, shard_offset
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
class LinearMethodBase(ABC):
"""Base class for different (maybe quantized) linear methods."""
@@ -276,6 +284,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
# If marlin, we need to adjust the offset and size to account for the tiling.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size)
self.weight_loader(param, loaded_weight_shard, shard_id)
@@ -293,6 +306,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
# If marlin, we need to adjust the offset and size to account for the tiling.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
start_idx = tp_rank * shard_size
@@ -372,6 +390,7 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_shard_id: Optional[str] = None):
param_data = param.data
output_dim = getattr(param, "output_dim", None)
if loaded_shard_id is None:
# Loaded weight is already packed.
if output_dim is None:
@@ -393,6 +412,11 @@ class QKVParallelLinear(ColumnParallelLinear):
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
# If marlin, we need to adjust the offset and size to account for the tiling.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size)
self.weight_loader(param, loaded_weight_shard, shard_id)
@@ -417,6 +441,11 @@ class QKVParallelLinear(ColumnParallelLinear):
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
# If marlin, we need to adjust the offset and size to account for the tiling.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
if loaded_shard_id == "q":