Integrate Marlin Kernels for Int4 GPTQ inference (#2497)
Co-authored-by: Robert Shaw <114415538+rib-2@users.noreply.github.com> Co-authored-by: alexm <alexm@neuralmagic.com>
This commit is contained in:
@@ -17,6 +17,14 @@ from vllm.logger import init_logger
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def adjust_marlin_shard(param, shard_size, shard_offset):
|
||||
marlin_tile_size = getattr(param, "marlin_tile_size", None)
|
||||
if marlin_tile_size is None:
|
||||
return shard_size, shard_offset
|
||||
|
||||
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
|
||||
|
||||
|
||||
class LinearMethodBase(ABC):
|
||||
"""Base class for different (maybe quantized) linear methods."""
|
||||
|
||||
@@ -276,6 +284,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
if packed_dim == output_dim:
|
||||
shard_size = shard_size // param.pack_factor
|
||||
shard_offset = shard_offset // param.pack_factor
|
||||
|
||||
# If marlin, we need to adjust the offset and size to account for the tiling.
|
||||
shard_size, shard_offset = adjust_marlin_shard(
|
||||
param, shard_size, shard_offset)
|
||||
|
||||
loaded_weight_shard = loaded_weight.narrow(
|
||||
output_dim, shard_offset, shard_size)
|
||||
self.weight_loader(param, loaded_weight_shard, shard_id)
|
||||
@@ -293,6 +306,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
if packed_dim == output_dim:
|
||||
shard_size = shard_size // param.pack_factor
|
||||
shard_offset = shard_offset // param.pack_factor
|
||||
|
||||
# If marlin, we need to adjust the offset and size to account for the tiling.
|
||||
shard_size, shard_offset = adjust_marlin_shard(
|
||||
param, shard_size, shard_offset)
|
||||
|
||||
param_data = param_data.narrow(output_dim, shard_offset,
|
||||
shard_size)
|
||||
start_idx = tp_rank * shard_size
|
||||
@@ -372,6 +390,7 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
loaded_shard_id: Optional[str] = None):
|
||||
param_data = param.data
|
||||
output_dim = getattr(param, "output_dim", None)
|
||||
|
||||
if loaded_shard_id is None:
|
||||
# Loaded weight is already packed.
|
||||
if output_dim is None:
|
||||
@@ -393,6 +412,11 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
if packed_dim == output_dim:
|
||||
shard_size = shard_size // param.pack_factor
|
||||
shard_offset = shard_offset // param.pack_factor
|
||||
|
||||
# If marlin, we need to adjust the offset and size to account for the tiling.
|
||||
shard_size, shard_offset = adjust_marlin_shard(
|
||||
param, shard_size, shard_offset)
|
||||
|
||||
loaded_weight_shard = loaded_weight.narrow(
|
||||
output_dim, shard_offset, shard_size)
|
||||
self.weight_loader(param, loaded_weight_shard, shard_id)
|
||||
@@ -417,6 +441,11 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
if packed_dim == output_dim:
|
||||
shard_size = shard_size // param.pack_factor
|
||||
shard_offset = shard_offset // param.pack_factor
|
||||
|
||||
# If marlin, we need to adjust the offset and size to account for the tiling.
|
||||
shard_size, shard_offset = adjust_marlin_shard(
|
||||
param, shard_size, shard_offset)
|
||||
|
||||
param_data = param_data.narrow(output_dim, shard_offset,
|
||||
shard_size)
|
||||
if loaded_shard_id == "q":
|
||||
|
||||
Reference in New Issue
Block a user